# Patch summary for scripts/vsmlrt.py (four hunks; the hunk headers name
# `class ORT_CUDA:` and `def _inference(` as the enclosing context).
#
# 1. Bumps `__version__` from "3.20.6" to "3.20.7".
# 2. Adds a `prefer_nhwc: bool = False` field to `ORT_CUDA`, placed after
#    `fp16_blacklist_ops` and before the "internal backend attributes" section.
# 3. In the `isinstance(backend, Backend.ORT_CUDA)` branch, builds an initially
#    empty `kwargs` dict before calling `core.ort.Model`:
#      - reads "onnxruntime_version" from `core.ort.Version()` (falling back to
#        b"0.0.0" when the key is missing) and splits it on b'.';
#      - any result that does not have exactly three components is treated as
#        version (0, 0, 0); otherwise each component is converted with `int`;
#      - `kwargs["prefer_nhwc"] = backend.prefer_nhwc` is set only when the
#        parsed version is >= (1, 18, 0).
#    Presumably the gate exists because older builds of the ort plugin do not
#    accept a `prefer_nhwc` argument -- TODO confirm against the plugin.
# 4. Appends `**kwargs` to the `core.ort.Model(...)` call (and adds the trailing
#    comma after `fp16_blacklist_ops=...` that this requires).
#
# NOTE(review): the length check only guards the component count. A version
# string with three dot-separated parts where one is not purely numeric
# (e.g. a "-suffix" on the last part) would make `int` raise ValueError. The
# OV_CPU line visible at the end of the last hunk strips a b'-' suffix before
# splitting; confirm whether onnxruntime can ever report such a suffix.
#
# NOTE(review): this copy of the patch appears to have had its line breaks and
# indentation collapsed onto a single line; it would need the original
# formatting restored before it could be applied.
diff --git a/scripts/vsmlrt.py b/scripts/vsmlrt.py index c5f3e52..ef47448 100644 --- a/scripts/vsmlrt.py +++ b/scripts/vsmlrt.py @@ -1,4 +1,4 @@ -__version__ = "3.20.6" +__version__ = "3.20.7" __all__ = [ "Backend", "BackendV2", @@ -87,6 +87,7 @@ class ORT_CUDA: fp16: bool = False use_cuda_graph: bool = False # preview, not supported by all models fp16_blacklist_ops: typing.Optional[typing.Sequence[str]] = None + prefer_nhwc: bool = False # internal backend attributes supports_onnx_serialization: bool = True @@ -2032,6 +2033,17 @@ def _inference( fp16_blacklist_ops=backend.fp16_blacklist_ops ) elif isinstance(backend, Backend.ORT_CUDA): + kwargs = dict() + + version_list = core.ort.Version().get("onnxruntime_version", b"0.0.0").split(b'.') + if len(version_list) != 3: + version = (0, 0, 0) + else: + version = tuple(map(int, version_list)) + + if version >= (1, 18, 0): + kwargs["prefer_nhwc"] = backend.prefer_nhwc + clip = core.ort.Model( clips, network_path, overlap=overlap, tilesize=tilesize, @@ -2043,7 +2055,8 @@ def _inference( fp16=backend.fp16, path_is_serialization=path_is_serialization, use_cuda_graph=backend.use_cuda_graph, - fp16_blacklist_ops=backend.fp16_blacklist_ops + fp16_blacklist_ops=backend.fp16_blacklist_ops, + **kwargs ) elif isinstance(backend, Backend.OV_CPU): version = tuple(map(int, core.ov.Version().get("openvino_version", b"0.0.0").split(b'-')[0].split(b'.')))