Skip to content

Commit

Permalink
Merge pull request #306 from yzx9/main
Browse files Browse the repository at this point in the history
support non-square images
  • Loading branch information
lucidrains authored Apr 13, 2024
2 parents d0c68fc + e09d640 commit 9f8f5bf
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,9 @@ def __init__(
self.channels = self.model.channels
self.self_condition = self.model.self_condition

if isinstance(image_size, int):
image_size = (image_size, image_size)
assert isinstance(image_size, (tuple, list)) and len(image_size) == 2, 'image size must be a integer or a tuple/list of two integers'
self.image_size = image_size

self.objective = objective
Expand Down Expand Up @@ -729,9 +732,9 @@ def ddim_sample(self, shape, return_all_timesteps = False):

@torch.inference_mode()
def sample(self, batch_size = 16, return_all_timesteps = False):
image_size, channels = self.image_size, self.channels
(h, w), channels = self.image_size, self.channels
sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample
return sample_fn((batch_size, channels, image_size, image_size), return_all_timesteps = return_all_timesteps)
return sample_fn((batch_size, channels, h, w), return_all_timesteps = return_all_timesteps)

@torch.inference_mode()
def interpolate(self, x1, x2, t = None, lam = 0.5):
Expand Down Expand Up @@ -811,7 +814,7 @@ def p_losses(self, x_start, t, noise = None, offset_noise_strength = None):

def forward(self, img, *args, **kwargs):
b, c, h, w, device, img_size, = *img.shape, img.device, self.image_size
assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
assert h == img_size[0] and w == img_size[1], f'height and width of image must be {img_size}'
t = torch.randint(0, self.num_timesteps, (b,), device=device).long()

img = self.normalize(img)
Expand Down

0 comments on commit 9f8f5bf

Please sign in to comment.