diff --git a/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py b/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
index 248a59ace..1ce813cbe 100644
--- a/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+++ b/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
@@ -496,6 +496,9 @@ def __init__(
         self.channels = self.model.channels
         self.self_condition = self.model.self_condition
 
+        if isinstance(image_size, int):
+            image_size = (image_size, image_size)
+        assert isinstance(image_size, (tuple, list)) and len(image_size) == 2, 'image size must be an integer or a tuple/list of two integers'
         self.image_size = image_size
 
         self.objective = objective
@@ -729,9 +732,9 @@ def ddim_sample(self, shape, return_all_timesteps = False):
 
     @torch.inference_mode()
     def sample(self, batch_size = 16, return_all_timesteps = False):
-        image_size, channels = self.image_size, self.channels
+        (h, w), channels = self.image_size, self.channels
         sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample
-        return sample_fn((batch_size, channels, image_size, image_size), return_all_timesteps = return_all_timesteps)
+        return sample_fn((batch_size, channels, h, w), return_all_timesteps = return_all_timesteps)
 
     @torch.inference_mode()
     def interpolate(self, x1, x2, t = None, lam = 0.5):
@@ -811,7 +814,7 @@ def p_losses(self, x_start, t, noise = None, offset_noise_strength = None):
 
     def forward(self, img, *args, **kwargs):
         b, c, h, w, device, img_size, = *img.shape, img.device, self.image_size
-        assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
+        assert h == img_size[0] and w == img_size[1], f'height and width of image must be {img_size}'
         t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
 
         img = self.normalize(img)
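
For context, a minimal usage sketch of the rectangular image_size support added above. The class and model names (GaussianDiffusion, Unet) are not visible in these hunks and are assumed from the rest of the library; the exact dimensions below are illustrative only:

```python
import torch
from denoising_diffusion_pytorch import Unet, GaussianDiffusion

model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8)
)

# image_size may now be a single int or a (height, width) pair;
# an int is expanded to (image_size, image_size) in __init__
diffusion = GaussianDiffusion(
    model,
    image_size = (128, 96),
    timesteps = 1000
)

# training images must match (h, w) exactly, per the assert in forward()
training_images = torch.rand(8, 3, 128, 96)
loss = diffusion(training_images)
loss.backward()

# sampling now uses (h, w) separately -> shape (4, 3, 128, 96)
sampled_images = diffusion.sample(batch_size = 4)
```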