Skip to content

Commit

Permalink
Fixes ipadapter on CPU
Browse files Browse the repository at this point in the history
Fixes ipadapter on CPU
  • Loading branch information
Nuullll authored Mar 23, 2024
2 parents 4729cd9 + 0e44576 commit 5b87959
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
4 changes: 4 additions & 0 deletions ipex_hijack/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import functools
import torch


def log(msg):
Expand All @@ -21,5 +22,8 @@ def asfp16(v):
return v.to(torch.half) if v is not None else None


def asfp32(v):
return v.to(torch.float) if v is not None else None

def astype(v, type):
return v.to(type) if v is not None else None
7 changes: 6 additions & 1 deletion scripts/ipex_enhance.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from modules import script_callbacks, devices
from modules.sd_hijack_utils import CondFunc
from ipex_hijack import log, asfp16
from ipex_hijack import log, asfp16, asfp32
from ipex_hijack.controlnet import apply_controlnet_hijacks


Expand Down Expand Up @@ -43,6 +43,11 @@ def apply_general_hijacks():
lambda orig_func, input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False: orig_func(input.cpu(), size, scale_factor, mode, align_corners, recompute_scale_factor, antialias).to(input.device),
lambda orig_func, input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False: input.device.type == 'xpu' and align_corners)

# Fixes ipadapter on CPU
CondFunc('torch.nn.functional.linear',
lambda orig_func, input, weight, bias: orig_func(input.float(), asfp32(weight), asfp32(bias)).half(),
lambda orig_func, input, weight, bias: input.device.type == 'cpu' and input.dtype == torch.half)

log("Registered hijacks for IPEX")


Expand Down

0 comments on commit 5b87959

Please sign in to comment.