Skip to content

Commit

Permalink
Handle None arg
Browse files Browse the repository at this point in the history
Handle None arg
  • Loading branch information
Nuullll authored Mar 23, 2024
2 parents f37d560 + 9c27a50 commit 4729cd9
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
8 changes: 8 additions & 0 deletions ipex_hijack/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,11 @@ def wrapper(*args, **kwargs):

return wrapper
return decorator


def asfp16(v):
return v.to(torch.half) if v is not None else None


def astype(v, type):
return v.to(type) if v is not None else None
4 changes: 2 additions & 2 deletions scripts/ipex_enhance.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from modules import script_callbacks, devices
from modules.sd_hijack_utils import CondFunc
from ipex_hijack import log
from ipex_hijack import log, asfp16
from ipex_hijack.controlnet import apply_controlnet_hijacks


Expand Down Expand Up @@ -34,7 +34,7 @@ def apply_general_hijacks():
# IPEX: incorrect batch_norm result with XPU fp32, downcast to fp16 instead
# TODO: file an issue to IPEX
CondFunc('torch.nn.functional.batch_norm',
lambda orig_func, input, running_mean, running_var, weight=None, bias=None, training=False, momentum=0.1, eps=1e-05: orig_func(input.half(), running_mean.half(), running_var.half(), weight=weight.half() if weight is not None else None, bias=bias.half() if bias is not None else None, training=training, momentum=momentum, eps=eps).to(input.dtype),
lambda orig_func, input, running_mean, running_var, weight=None, bias=None, training=False, momentum=0.1, eps=1e-05: orig_func(input.half(), asfp16(running_mean), asfp16(running_var), weight=asfp16(weight), bias=asfp16(bias), training=training, momentum=momentum, eps=eps).to(input.dtype),
lambda orig_func, input, running_mean, running_var, weight=None, bias=None, training=False, momentum=0.1, eps=1e-05: input.device.type == 'xpu' and input.dtype == torch.float)

# IPEX: incorrect interpolate result with XPU when align_corner=True, move to cpu instead
Expand Down

0 comments on commit 4729cd9

Please sign in to comment.