diff --git a/ipex_hijack/__init__.py b/ipex_hijack/__init__.py index 9d61052..ae58f0e 100644 --- a/ipex_hijack/__init__.py +++ b/ipex_hijack/__init__.py @@ -1,4 +1,5 @@ import functools +import torch def log(msg): @@ -21,5 +22,8 @@ def asfp16(v): return v.to(torch.half) if v is not None else None +def asfp32(v): + return v.to(torch.float) if v is not None else None + def astype(v, type): return v.to(type) if v is not None else None \ No newline at end of file diff --git a/scripts/ipex_enhance.py b/scripts/ipex_enhance.py index 9f3b2a6..171623a 100644 --- a/scripts/ipex_enhance.py +++ b/scripts/ipex_enhance.py @@ -1,6 +1,6 @@ from modules import script_callbacks, devices from modules.sd_hijack_utils import CondFunc -from ipex_hijack import log, asfp16 +from ipex_hijack import log, asfp16, asfp32 from ipex_hijack.controlnet import apply_controlnet_hijacks @@ -43,6 +43,11 @@ def apply_general_hijacks(): lambda orig_func, input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False: orig_func(input.cpu(), size, scale_factor, mode, align_corners, recompute_scale_factor, antialias).to(input.device), lambda orig_func, input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False: input.device.type == 'xpu' and align_corners) + # Fixes ipadapter on CPU + CondFunc('torch.nn.functional.linear', + lambda orig_func, input, weight, bias: orig_func(input.float(), asfp32(weight), asfp32(bias)).half(), + lambda orig_func, input, weight, bias: input.device.type == 'cpu' and input.dtype == torch.half) + log("Registered hijacks for IPEX")