diff --git a/pyvisgen/simulation/observation.py b/pyvisgen/simulation/observation.py index 814aad4..9fb8c9e 100644 --- a/pyvisgen/simulation/observation.py +++ b/pyvisgen/simulation/observation.py @@ -64,7 +64,6 @@ def get_valid_subset(self, num_baselines, device): date = (torch.from_numpy(t[:-1][mask] + t[1:][mask]) / 2).to(device) return ValidBaselineSubset( - baseline_nums, u_start, u_stop, u_valid, @@ -74,13 +73,13 @@ def get_valid_subset(self, num_baselines, device): w_start, w_stop, w_valid, + baseline_nums, date, ) @dataclass() class ValidBaselineSubset: - baseline_nums: torch.tensor u_start: torch.tensor u_stop: torch.tensor u_valid: torch.tensor @@ -90,6 +89,7 @@ class ValidBaselineSubset: w_start: torch.tensor w_stop: torch.tensor w_valid: torch.tensor + baseline_nums: torch.tensor date: torch.tensor def __getitem__(self, i): @@ -350,7 +350,7 @@ def create_rd_grid(self): Returns ------- - 3d array + rd_grid : 3d array Returns a 3d array with every pixel containing a RA and Dec value """ # transform to rad @@ -370,9 +370,10 @@ def create_rd_grid(self): - self.img_size / 2 ) * res + dec - _, R = torch.meshgrid((r, r), indexing="ij") - D, _ = torch.meshgrid((d, d), indexing="ij") + R, _ = torch.meshgrid((r, r), indexing="ij") + _, D = torch.meshgrid((d, d), indexing="ij") rd_grid = torch.cat([R[..., None], D[..., None]], dim=2) + return rd_grid def create_lm_grid(self): @@ -387,17 +388,17 @@ def create_lm_grid(self): Returns ------- - 3d array + lm_grid : 3d array Returns a 3d array with every pixel containing a l and m value """ dec = torch.deg2rad(self.dec) lm_grid = torch.zeros(self.rd.shape, device=self.device, dtype=torch.float64) - lm_grid[:, :, 0] = (torch.cos(self.rd[..., 1]) * torch.sin(self.rd[..., 0])).T - lm_grid[:, :, 1] = ( - torch.sin(self.rd[..., 1]) * torch.cos(dec) - - torch.cos(self.rd[..., 1]) * torch.sin(dec) * torch.cos(self.rd[..., 0]) - ).T + lm_grid[..., 0] = torch.cos(self.rd[..., 1]) * torch.sin(self.rd[..., 0]) + lm_grid[..., 1] = torch.sin(self.rd[..., 1]) * torch.cos(dec) - torch.cos( + self.rd[..., 1] + ) * torch.sin(dec) * torch.cos(self.rd[..., 0]) + return lm_grid def get_baselines(self, times):