Skip to content

Commit

Permalink
Removed workers tests. max_workers > 1 is still not consistently faster.…
Browse files Browse the repository at this point in the history
… Only sampling is (except on macOS); training is not.
  • Loading branch information
bnb32 committed Jan 2, 2025
1 parent d516603 commit bd493b1
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 159 deletions.
2 changes: 1 addition & 1 deletion sup3r/preprocessing/batch_queues/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def __iter__(self):
def get_batch(self) -> DsetTuple:
"""Get batch from queue or directly from a ``Sampler`` through
``sample_batch``."""
if self.mode == 'eager' or self.queue_cap == 0:
if self.mode == 'eager' or self.queue_cap == 0 or self.queue_len == 0:
return self.sample_batch()
return self.queue.dequeue()

Expand Down
194 changes: 110 additions & 84 deletions tests/batch_handlers/test_bh_general.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
"""Smoke tests for batcher objects. Just make sure things run without errors"""

import copy
import os
import time
from tempfile import TemporaryDirectory

import numpy as np
import pytest
from scipy.ndimage import gaussian_filter

from sup3r.preprocessing import (
BatchHandler,
)
from sup3r.preprocessing import BatchHandler, DataHandler
from sup3r.preprocessing.base import Container
from sup3r.utilities.pytest.helpers import (
BatchHandlerTesterFactory,
Expand All @@ -35,117 +35,143 @@ def test_batch_sampling_workers():
max_workers = 1. This does not include enqueueing and dequeueing."""

timer = Timer()
ds = DummyData((200, 200, 2000), ['u_100m', 'v_100m'])
sample_shape = (20, 20, 30)
sample_shape = (100, 100, 30)
chunk_shape = (
2 * sample_shape[0],
2 * sample_shape[1],
2 * sample_shape[-1],
)
n_obs = 10
max_workers = 10
n_batches = 10
n_batches = 50
n_epochs = 3
chunks = dict(zip(['south_north', 'west_east', 'time'], chunk_shape))

ds = ds.chunk(dict(zip(['south_north', 'west_east', 'time'], chunk_shape)))
with TemporaryDirectory() as td:
ds = DummyData((200, 200, 2000), ['u_100m', 'v_100m'])
ds.to_netcdf(os.path.join(td, 'test.nc'))
ds = DataHandler(os.path.join(td, 'test.nc'), chunks=chunks)

batcher = BatchHandler(
[ds],
n_batches=n_batches,
batch_size=n_obs,
sample_shape=sample_shape,
max_workers=max_workers,
means={'u_100m': 0, 'v_100m': 0},
stds={'u_100m': 1, 'v_100m': 1},
)
timer.start()
for _ in range(n_epochs):
batches = batcher.sample_batches(n_batches)
_ = [batch.result() for batch in batches]
timer.stop()
parallel_time = timer.elapsed / (n_batches * n_epochs)
batcher.stop()

batcher = BatchHandler(
[ds],
n_batches=n_batches,
batch_size=n_obs,
sample_shape=sample_shape,
max_workers=1,
means={'u_100m': 0, 'v_100m': 0},
stds={'u_100m': 1, 'v_100m': 1},
)
timer.start()
for _ in range(n_epochs):
_ = batcher.sample_batches(n_batches)
timer.stop()
serial_time = timer.elapsed / (n_batches * n_epochs)
batcher.stop()
batcher = BatchHandler(
[ds],
n_batches=n_batches,
batch_size=n_obs,
sample_shape=sample_shape,
max_workers=max_workers,
means={'u_100m': 0, 'v_100m': 0},
stds={'u_100m': 1, 'v_100m': 1},
)
timer.start()
queue_time = 0
for _ in range(n_epochs):
batches = batcher.sample_batches(n_batches)
batches = [batch.result() for batch in batches]
queue_start = time.time()
for batch in batches:
batcher.queue.enqueue(batch)
_ = batcher.queue.dequeue()
queue_time += (time.time() - queue_start)
timer.stop()
parallel_time = timer.elapsed / (n_batches * n_epochs)
parallel_queue_time = queue_time / (n_batches * n_epochs)
batcher.stop()

print(
'Elapsed (serial / parallel): {} / {}'.format(
serial_time, parallel_time
batcher = BatchHandler(
[ds],
n_batches=n_batches,
batch_size=n_obs,
sample_shape=sample_shape,
max_workers=1,
means={'u_100m': 0, 'v_100m': 0},
stds={'u_100m': 1, 'v_100m': 1},
)
)
assert serial_time > parallel_time
timer.start()
queue_time = 0
for _ in range(n_epochs):
batches = batcher.sample_batches(n_batches)
queue_start = time.time()
for batch in batches:
batcher.queue.enqueue(batch)
_ = batcher.queue.dequeue()
queue_time += time.time() - queue_start
timer.stop()
serial_time = timer.elapsed / (n_batches * n_epochs)
serial_queue_time = queue_time / (n_batches * n_epochs)
batcher.stop()

print(
'Elapsed total time (serial / parallel): {} / {}'.format(
serial_time, parallel_time
)
)
print(
'Elapsed queue time (serial / parallel): {} / {}'.format(
serial_queue_time, parallel_queue_time
)
)
assert serial_time > parallel_time


def test_batch_queue_workers():
"""Check that it is faster to queue batches with max_workers > 1 than with
max_workers = 1."""

timer = Timer()
ds = DummyData((200, 200, 2000), ['u_100m', 'v_100m'])
sample_shape = (20, 20, 30)
sample_shape = (100, 100, 30)
chunk_shape = (
2 * sample_shape[0],
2 * sample_shape[1],
2 * sample_shape[-1],
)
n_obs = 10
max_workers = 10
n_batches = 10
n_batches = 50
n_epochs = 3
ds = ds.chunk(dict(zip(['south_north', 'west_east', 'time'], chunk_shape)))
chunks = dict(zip(['south_north', 'west_east', 'time'], chunk_shape))

batcher = BatchHandler(
[ds],
n_batches=n_batches,
batch_size=n_obs,
sample_shape=sample_shape,
max_workers=max_workers,
means={'u_100m': 0, 'v_100m': 0},
stds={'u_100m': 1, 'v_100m': 1},
)
timer.start()
for _ in range(n_epochs):
_ = list(batcher)
timer.stop()
parallel_time = timer.elapsed / (n_batches * n_epochs)
batcher.stop()
with TemporaryDirectory() as td:
ds = DummyData((200, 200, 2000), ['u_100m', 'v_100m'])
ds.to_netcdf(os.path.join(td, 'test.nc'))
ds = DataHandler(os.path.join(td, 'test.nc'), chunks=chunks)

batcher = BatchHandler(
[ds],
n_batches=n_batches,
batch_size=n_obs,
sample_shape=sample_shape,
max_workers=1,
means={'u_100m': 0, 'v_100m': 0},
stds={'u_100m': 1, 'v_100m': 1},
)
timer.start()
for _ in range(n_epochs):
_ = list(batcher)
timer.stop()
serial_time = timer.elapsed / (n_batches * n_epochs)
batcher.stop()
batcher = BatchHandler(
[ds],
n_batches=n_batches,
batch_size=n_obs,
sample_shape=sample_shape,
max_workers=max_workers,
means={'u_100m': 0, 'v_100m': 0},
stds={'u_100m': 1, 'v_100m': 1},
)
timer.start()
for _ in range(n_epochs):
_ = list(batcher)
timer.stop()
parallel_time = timer.elapsed / (n_batches * n_epochs)
batcher.stop()

print(
'Elapsed (serial / parallel): {} / {}'.format(
serial_time, parallel_time
batcher = BatchHandler(
[ds],
n_batches=n_batches,
batch_size=n_obs,
sample_shape=sample_shape,
max_workers=1,
means={'u_100m': 0, 'v_100m': 0},
stds={'u_100m': 1, 'v_100m': 1},
)
)
assert serial_time > parallel_time
timer.start()
for _ in range(n_epochs):
_ = list(batcher)
timer.stop()
serial_time = timer.elapsed / (n_batches * n_epochs)
batcher.stop()

print(
'Elapsed (serial / parallel): {} / {}'.format(
serial_time, parallel_time
)
)
assert serial_time > parallel_time


def test_eager_vs_lazy():
Expand Down
74 changes: 0 additions & 74 deletions tests/training/test_train_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

from sup3r.models import Sup3rGan
from sup3r.preprocessing import BatchHandler, DataHandler
from sup3r.utilities.utilities import Timer

TARGET_COORD = (39.01, -105.15)
FEATURES = ['u_100m', 'v_100m']
Expand Down Expand Up @@ -169,79 +168,6 @@ def test_train(fp_gen, fp_disc, s_enhance, t_enhance, sample_shape, n_epoch=8):
batch_handler.stop()


def test_train_workers(n_epoch=10):
"""Test that model training with max_workers > 1 for the batch queue is
faster than for max_workers = 1."""

lr = 5e-5
train_handler, val_handler = _get_handlers()
timer = Timer()
n_batches = 40
batch_size = 40

Sup3rGan.seed()
model = Sup3rGan(
pytest.S_FP_GEN,
pytest.S_FP_DISC,
learning_rate=lr,
loss='MeanAbsoluteError',
)

with tempfile.TemporaryDirectory() as td:
batch_handler = BatchHandler(
train_containers=[train_handler],
val_containers=[val_handler],
sample_shape=(10, 10, 1),
batch_size=batch_size,
s_enhance=2,
t_enhance=1,
n_batches=n_batches,
means={'u_100m': 0, 'v_100m': 0},
stds={'u_100m': 1, 'v_100m': 1},
max_workers=5,
)

model_kwargs = {
'input_resolution': {'spatial': '30km', 'temporal': '60min'},
'n_epoch': n_epoch,
'weight_gen_advers': 0.0,
'train_gen': True,
'train_disc': False,
'checkpoint_int': 10,
'out_dir': os.path.join(td, 'test_{epoch}'),
}

timer.start()
model.train(batch_handler, **model_kwargs)
timer.stop()
parallel_time = timer.elapsed / n_epoch

batch_handler = BatchHandler(
train_containers=[train_handler],
val_containers=[val_handler],
sample_shape=(10, 10, 1),
batch_size=batch_size,
s_enhance=2,
t_enhance=1,
n_batches=n_batches,
means={'u_100m': 0, 'v_100m': 0},
stds={'u_100m': 1, 'v_100m': 1},
max_workers=1,
)

timer.start()
model.train(batch_handler, **model_kwargs)
timer.stop()
serial_time = timer.elapsed / n_epoch

print(
'Elapsed (parallel / serial): {} / {}'.format(
parallel_time, serial_time
)
)
assert parallel_time < serial_time


def test_train_st_weight_update(n_epoch=2):
"""Test basic spatiotemporal model training with discriminators and
adversarial loss updating."""
Expand Down

0 comments on commit bd493b1

Please sign in to comment.