Skip to content

Commit

Permalink
Merge pull request #44 from radionets-project/fix_tests
Browse files Browse the repository at this point in the history
Fix tests
  • Loading branch information
tgross03 authored Nov 28, 2024
2 parents 05c9809 + d38601d commit 94fe38f
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 11 deletions.
1 change: 1 addition & 0 deletions docs/changes/44.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
- Fixed random number drawing in tests by changing the location of the seed override
29 changes: 19 additions & 10 deletions pyvisgen/simulation/data_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,6 @@ def simulate_data_set(config, slurm=False, job_id=None, n=None):
out_path.mkdir(parents=True, exist_ok=True)
data = load_bundles(conf["in_path"])

if conf["seed"] is not None:
np.random.seed(conf["seed"])
torch.manual_seed(conf["seed"])

if slurm:
job_id = int(job_id + n * 500)
out = out_path / Path("vis_" + str(job_id) + ".fits")
Expand Down Expand Up @@ -121,6 +117,14 @@ def create_sampling_rc(conf):
dict
contains the observation parameters
"""

global rng

if conf["seed"] is not None:
rng = np.random.default_rng(conf["seed"])
else:
rng = np.random.default_rng()

samp_ops = draw_sampling_opts(conf)
array_layout = layouts.get_array_layout(conf["layout"][0])
half_telescopes = array_layout.x.shape[0] // 2
Expand All @@ -145,27 +149,32 @@ def draw_sampling_opts(conf):
dict
contains randomly drawn observation options
"""

if "rng" not in globals():
global rng
rng = np.random.default_rng(conf["seed"])

angles_ra = np.arange(
conf["fov_center_ra"][0][0], conf["fov_center_ra"][0][1], step=0.1
)
fov_center_ra = np.random.choice(angles_ra)
fov_center_ra = rng.choice(angles_ra)

angles_dec = np.arange(
conf["fov_center_dec"][0][0], conf["fov_center_dec"][0][1], step=0.1
)
fov_center_dec = np.random.choice(angles_dec)
fov_center_dec = rng.choice(angles_dec)
start_time_l = datetime.strptime(conf["scan_start"][0], "%d-%m-%Y %H:%M:%S")
start_time_h = datetime.strptime(conf["scan_start"][1], "%d-%m-%Y %H:%M:%S")
start_times = pd.date_range(start_time_l, start_time_h, freq="1h").strftime(
"%d-%m-%Y %H:%M:%S"
)
scan_start = np.random.choice(
scan_start = rng.choice(
[datetime.strptime(time, "%d-%m-%Y %H:%M:%S") for time in start_times]
)
scan_duration = np.random.randint(
conf["scan_duration"][0], conf["scan_duration"][1]
scan_duration = int(
rng.integers(conf["scan_duration"][0], conf["scan_duration"][1])
)
num_scans = np.random.randint(conf["num_scans"][0], conf["num_scans"][1])
num_scans = int(rng.integers(conf["num_scans"][0], conf["num_scans"][1]))
opts = np.array(
[
conf["mode"],
Expand Down
2 changes: 1 addition & 1 deletion tests/test_conf.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[sampling_options]
mode = "full"
device = "cpu"
seed = 1337
seed = 42
layout = "vla"
img_size = 128
fov_center_ra = [90, 140]
Expand Down
18 changes: 18 additions & 0 deletions tests/test_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,18 @@ def test_create_sampling_rc():
test_opts(samp_ops)


def test_create_sampling_rc_no_seed():
from pyvisgen.simulation.data_set import create_sampling_rc, test_opts

mod_conf = conf.copy()
mod_conf["seed"] = None

samp_ops = create_sampling_rc(mod_conf)
assert len(samp_ops) == 17

test_opts(samp_ops)


def test_vis_loop():
import torch

Expand Down Expand Up @@ -64,3 +76,9 @@ def test_vis_loop():
out = out_path / Path("vis_0.fits")
hdu_list = writer.create_hdu_list(vis_data, obs)
hdu_list.writeto(out, overwrite=True)


def test_simulate_data_set_no_slurm():
from pyvisgen.simulation.data_set import simulate_data_set

simulate_data_set(config)

0 comments on commit 94fe38f

Please sign in to comment.