Skip to content

Commit

Permalink
Merge pull request #48 from radionets-project/auto_batchsize
Browse files Browse the repository at this point in the history
Auto batchsize
  • Loading branch information
aknierim authored Jan 21, 2025
2 parents cbac833 + 63d2e84 commit 7325d26
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 17 deletions.
1 change: 1 addition & 0 deletions docs/changes/48.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
- Added optional automatic batch size scaling in ``vis_loop``
2 changes: 2 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,5 @@ dependencies:
- pytest
- pytest-cov
- pytest-runner
- pip:
- toma
25 changes: 13 additions & 12 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,25 +28,26 @@ classifiers = [
requires-python = ">=3.10"

dependencies = [
"numpy",
"astroplan",
"astropy<=6.1.0",
"torch",
"matplotlib",
"click",
"h5py",
"ipython",
"scipy",
"jupyter",
"matplotlib",
"natsort",
"numexpr",
"numpy",
"pandas",
"toml",
"pre-commit",
"pytest",
"pytest-cov",
"jupyter",
"astroplan",
"scipy",
"toma",
"toml",
"torch",
"torch",
"tqdm",
"numexpr",
"click",
"h5py",
"natsort",
"pre-commit",
]

[project.scripts]
Expand Down
86 changes: 81 additions & 5 deletions pyvisgen/simulation/visibility.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from dataclasses import dataclass, fields

import torch
from tqdm import tqdm
import toma
from tqdm.autonotebook import tqdm

import pyvisgen.simulation.scan as scan

Expand Down Expand Up @@ -44,12 +45,18 @@ def vis_loop(
num_threads=10,
noisy=True,
mode="full",
batch_size=100,
batch_size="auto",
show_progress=False,
):
torch.set_num_threads(num_threads)
torch._dynamo.config.suppress_errors = True

if not (
isinstance(batch_size, int)
or (isinstance(batch_size, str) and batch_size == "auto")
):
raise ValueError("Expected batch_size to be 'auto' or of type int")

SI = torch.flip(SI, dims=[1])

# define unpolarized sky distribution
Expand Down Expand Up @@ -104,10 +111,78 @@ def vis_loop(
else:
raise ValueError("Unsupported mode!")

batches = torch.arange(bas[:].shape[1]).split(batch_size)
if batch_size == "auto":
batch_size = bas[:].shape[1]

visibilities = toma.explicit.batch(
_batch_loop,
batch_size,
visibilities,
vis_num,
obs,
B,
bas,
lm,
rd,
noisy,
show_progress,
)

return visibilities

if show_progress:
batches = tqdm(batches)

def _batch_loop(
batch_size: int,
visibilities,
vis_num: int,
obs,
B: torch.tensor,
bas,
lm: torch.tensor,
rd: torch.tensor,
noisy: bool | float,
show_progress: bool,
):
"""Main simulation loop of pyvisgen. Computes visibilities
batchwise.
Parameters
----------
batch_size : int
Batch size for loop over Baselines dataclass object.
visibilities : Visibilities
Visibilities dataclass object.
vis_num : int
Number of visibilities.
obs : Observation
Observation class object.
B : torch.tensor
Stokes matrix containing stokes visibilities.
bas : Baselines
Baselines dataclass object.
lm : torch.tensor
lm grid.
rd : torch.tensor
rd grid.
noisy : float or bool
Simulate noise as SEFD with given value. If set to False,
no noise is simulated.
show_progress :
If True, show a progress bar tracking the loop.
Returns
-------
visibilities : Visibilities
Visibilities dataclass object.
"""
batches = torch.arange(bas[:].shape[1]).split(batch_size)
batches = tqdm(
batches,
position=0,
disable=not show_progress,
desc="Computing visibilities",
postfix=f"Batch size: {batch_size}",
)

for p in batches:
bas_p = bas[:][:, p]
Expand Down Expand Up @@ -155,6 +230,7 @@ def vis_loop(

visibilities.add(vis)
del int_values

return visibilities


Expand Down
67 changes: 67 additions & 0 deletions tests/test_simulation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pathlib import Path

import torch
from numpy.testing import assert_raises

from pyvisgen.utils.config import read_data_set_conf

Expand Down Expand Up @@ -78,6 +79,72 @@ def test_vis_loop():
hdu_list.writeto(out, overwrite=True)


def test_vis_loop_batch_size_auto():
    """Run ``vis_loop`` with ``batch_size="auto"`` and check the result dtypes."""
    import torch

    from pyvisgen.simulation.data_set import create_observation
    from pyvisgen.simulation.visibility import vis_loop
    from pyvisgen.utils.data import load_bundles, open_bundles

    bundle_paths = load_bundles(conf["in_path"])
    observation = create_observation(conf)
    opened = open_bundles(bundle_paths[0])
    # First entry of the opened bundle, with a leading axis of size 1 added.
    sky = torch.tensor(opened[0])[None]

    vis_data = vis_loop(
        observation,
        sky,
        noisy=conf["noisy"],
        mode=conf["mode"],
        batch_size="auto",
    )
    first = vis_data[0]

    # All four Stokes components are expected to be complex128.
    for stokes in ("SI", "SQ", "SU", "SV"):
        assert getattr(first, stokes)[0].dtype == torch.complex128

    assert first.num.dtype == torch.float32
    assert first.base_num.dtype == torch.float64

    for coord in ("u", "v", "w"):
        assert torch.is_tensor(getattr(first, coord))

    assert first.date.dtype == torch.float64


def test_vis_loop_batch_size_invalid():
    """Check that ``vis_loop`` rejects unsupported ``batch_size`` values.

    A string other than ``"auto"`` and a float are each expected to raise
    ``ValueError``.
    """
    import torch

    from pyvisgen.simulation.data_set import create_observation
    from pyvisgen.simulation.visibility import vis_loop
    from pyvisgen.utils.data import load_bundles, open_bundles

    bundles = load_bundles(conf["in_path"])
    obs = create_observation(conf)
    data = open_bundles(bundles[0])
    SI = torch.tensor(data[0])[None]

    # Same call for every invalid value; only batch_size differs.
    for invalid_batch_size in ("abc", 20.0):
        assert_raises(
            ValueError,
            vis_loop,
            obs,
            SI,
            noisy=conf["noisy"],
            mode=conf["mode"],
            batch_size=invalid_batch_size,
        )


def test_simulate_data_set_no_slurm():
from pyvisgen.simulation.data_set import simulate_data_set

Expand Down

0 comments on commit 7325d26

Please sign in to comment.