From 9c27a505b9314f956a26e909c32041f897a78d33 Mon Sep 17 00:00:00 2001 From: Nuullll Date: Sat, 23 Mar 2024 11:48:30 +0800 Subject: [PATCH] Handle None arg --- ipex_hijack/__init__.py | 8 ++++++++ scripts/ipex_enhance.py | 4 ++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/ipex_hijack/__init__.py b/ipex_hijack/__init__.py index e0472f9..9d61052 100644 --- a/ipex_hijack/__init__.py +++ b/ipex_hijack/__init__.py @@ -15,3 +15,11 @@ def wrapper(*args, **kwargs): return wrapper return decorator + + +def asfp16(v): + return v.to(torch.half) if v is not None else None + + +def astype(v, type): + return v.to(type) if v is not None else None \ No newline at end of file diff --git a/scripts/ipex_enhance.py b/scripts/ipex_enhance.py index a9b371a..9f3b2a6 100644 --- a/scripts/ipex_enhance.py +++ b/scripts/ipex_enhance.py @@ -1,6 +1,6 @@ from modules import script_callbacks, devices from modules.sd_hijack_utils import CondFunc -from ipex_hijack import log +from ipex_hijack import log, asfp16 from ipex_hijack.controlnet import apply_controlnet_hijacks @@ -34,7 +34,7 @@ def apply_general_hijacks(): # IPEX: incorrect batch_norm result with XPU fp32, downcast to fp16 instead # TODO: file an issue to IPEX CondFunc('torch.nn.functional.batch_norm', - lambda orig_func, input, running_mean, running_var, weight=None, bias=None, training=False, momentum=0.1, eps=1e-05: orig_func(input.half(), running_mean.half(), running_var.half(), weight=weight.half() if weight is not None else None, bias=bias.half() if bias is not None else None, training=training, momentum=momentum, eps=eps).to(input.dtype), + lambda orig_func, input, running_mean, running_var, weight=None, bias=None, training=False, momentum=0.1, eps=1e-05: orig_func(input.half(), asfp16(running_mean), asfp16(running_var), weight=asfp16(weight), bias=asfp16(bias), training=training, momentum=momentum, eps=eps).to(input.dtype), lambda orig_func, input, running_mean, running_var, weight=None, bias=None, training=False, momentum=0.1, eps=1e-05: input.device.type == 'xpu' and input.dtype == torch.float) # IPEX: incorrect interpolate result with XPU when align_corner=True, move to cpu instead