Skip to content

Commit

Permalink
Change random number generation to default_rng, Change location of se…
Browse files Browse the repository at this point in the history
…ed init
  • Loading branch information
tgross03 committed Nov 14, 2024
1 parent 05c9809 commit d950f25
Showing 1 changed file with 22 additions and 10 deletions.
32 changes: 22 additions & 10 deletions pyvisgen/simulation/data_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,9 @@ def simulate_data_set(config, slurm=False, job_id=None, n=None):
out_path.mkdir(parents=True, exist_ok=True)
data = load_bundles(conf["in_path"])

if conf["seed"] is not None:
np.random.seed(conf["seed"])
torch.manual_seed(conf["seed"])

# if conf["seed"] is not None:
# global rng
# rng = np.random.default_rng(conf["seed"])
if slurm:
job_id = int(job_id + n * 500)
out = out_path / Path("vis_" + str(job_id) + ".fits")
Expand Down Expand Up @@ -121,6 +120,14 @@ def create_sampling_rc(conf):
dict
contains the observation parameters
"""

global rng

if conf["seed"] is not None:
rng = np.random.default_rng(conf["seed"])
else:
rng = np.random.default_rng()

samp_ops = draw_sampling_opts(conf)
array_layout = layouts.get_array_layout(conf["layout"][0])
half_telescopes = array_layout.x.shape[0] // 2
Expand All @@ -145,27 +152,32 @@ def draw_sampling_opts(conf):
dict
contains randomly drawn observation options
"""

if "rng" not in globals():
global rng
rng = np.random.default_rng(conf["seed"])

angles_ra = np.arange(
conf["fov_center_ra"][0][0], conf["fov_center_ra"][0][1], step=0.1
)
fov_center_ra = np.random.choice(angles_ra)
fov_center_ra = rng.choice(angles_ra)

angles_dec = np.arange(
conf["fov_center_dec"][0][0], conf["fov_center_dec"][0][1], step=0.1
)
fov_center_dec = np.random.choice(angles_dec)
fov_center_dec = rng.choice(angles_dec)
start_time_l = datetime.strptime(conf["scan_start"][0], "%d-%m-%Y %H:%M:%S")
start_time_h = datetime.strptime(conf["scan_start"][1], "%d-%m-%Y %H:%M:%S")
start_times = pd.date_range(start_time_l, start_time_h, freq="1h").strftime(
"%d-%m-%Y %H:%M:%S"
)
scan_start = np.random.choice(
scan_start = rng.choice(
[datetime.strptime(time, "%d-%m-%Y %H:%M:%S") for time in start_times]
)
scan_duration = np.random.randint(
conf["scan_duration"][0], conf["scan_duration"][1]
scan_duration = int(
rng.integers(conf["scan_duration"][0], conf["scan_duration"][1])
)
num_scans = np.random.randint(conf["num_scans"][0], conf["num_scans"][1])
num_scans = int(rng.integers(conf["num_scans"][0], conf["num_scans"][1]))
opts = np.array(
[
conf["mode"],
Expand Down

0 comments on commit d950f25

Please sign in to comment.