From 7897998dc091e3f37ae34d7e047752bd7e7409c5 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Thu, 2 Jan 2025 12:00:18 -0700 Subject: [PATCH] removed workers tests. max_workers > 1 still not consistently faster. just sampling is, except for macos, but training is not. --- sup3r/preprocessing/batch_queues/abstract.py | 2 +- tests/batch_handlers/test_bh_general.py | 168 ++++++++----------- tests/training/test_train_gan.py | 74 -------- 3 files changed, 67 insertions(+), 177 deletions(-) diff --git a/sup3r/preprocessing/batch_queues/abstract.py b/sup3r/preprocessing/batch_queues/abstract.py index c6a9200d1..c0d418ddc 100644 --- a/sup3r/preprocessing/batch_queues/abstract.py +++ b/sup3r/preprocessing/batch_queues/abstract.py @@ -234,7 +234,7 @@ def __iter__(self): def get_batch(self) -> DsetTuple: """Get batch from queue or directly from a ``Sampler`` through ``sample_batch``.""" - if self.mode == 'eager' or self.queue_cap == 0: + if self.mode == 'eager' or self.queue_cap == 0 or self.queue_len == 0: return self.sample_batch() return self.queue.dequeue() diff --git a/tests/batch_handlers/test_bh_general.py b/tests/batch_handlers/test_bh_general.py index 1a6504ef5..616eb1833 100644 --- a/tests/batch_handlers/test_bh_general.py +++ b/tests/batch_handlers/test_bh_general.py @@ -1,14 +1,14 @@ """Smoke tests for batcher objects. Just make sure things run without errors""" - import copy +import os +import time +from tempfile import TemporaryDirectory import numpy as np import pytest from scipy.ndimage import gaussian_filter -from sup3r.preprocessing import ( - BatchHandler, -) +from sup3r.preprocessing import BatchHandler, DataHandler from sup3r.preprocessing.base import Container from sup3r.utilities.pytest.helpers import ( BatchHandlerTesterFactory, @@ -35,8 +35,7 @@ def test_batch_sampling_workers(): max_workers = 1. 
This does not include enqueueing and dequeueing.""" timer = Timer() - ds = DummyData((200, 200, 2000), ['u_100m', 'v_100m']) - sample_shape = (20, 20, 30) + sample_shape = (100, 100, 30) chunk_shape = ( 2 * sample_shape[0], 2 * sample_shape[1], @@ -44,108 +43,73 @@ def test_batch_sampling_workers(): ) n_obs = 10 max_workers = 10 - n_batches = 10 + n_batches = 50 n_epochs = 3 + chunks = dict(zip(['south_north', 'west_east', 'time'], chunk_shape)) - ds = ds.chunk(dict(zip(['south_north', 'west_east', 'time'], chunk_shape))) - - batcher = BatchHandler( - [ds], - n_batches=n_batches, - batch_size=n_obs, - sample_shape=sample_shape, - max_workers=max_workers, - means={'u_100m': 0, 'v_100m': 0}, - stds={'u_100m': 1, 'v_100m': 1}, - ) - timer.start() - for _ in range(n_epochs): - batches = batcher.sample_batches(n_batches) - _ = [batch.result() for batch in batches] - timer.stop() - parallel_time = timer.elapsed / (n_batches * n_epochs) - batcher.stop() - - batcher = BatchHandler( - [ds], - n_batches=n_batches, - batch_size=n_obs, - sample_shape=sample_shape, - max_workers=1, - means={'u_100m': 0, 'v_100m': 0}, - stds={'u_100m': 1, 'v_100m': 1}, - ) - timer.start() - for _ in range(n_epochs): - _ = batcher.sample_batches(n_batches) - timer.stop() - serial_time = timer.elapsed / (n_batches * n_epochs) - batcher.stop() + with TemporaryDirectory() as td: + ds = DummyData((200, 200, 2000), ['u_100m', 'v_100m']) + ds.to_netcdf(os.path.join(td, 'test.nc')) + ds = DataHandler(os.path.join(td, 'test.nc'), chunks=chunks) - print( - 'Elapsed (serial / parallel): {} / {}'.format( - serial_time, parallel_time + batcher = BatchHandler( + [ds], + n_batches=n_batches, + batch_size=n_obs, + sample_shape=sample_shape, + max_workers=max_workers, + means={'u_100m': 0, 'v_100m': 0}, + stds={'u_100m': 1, 'v_100m': 1}, ) - ) - assert serial_time > parallel_time - - -def test_batch_queue_workers(): - """Check that it is faster to queue batches with max_workers > 1 than with - max_workers = 1.""" + timer.start() + queue_time = 0 + for _ in range(n_epochs): + batches = batcher.sample_batches(n_batches) + batches = [batch.result() for batch in batches] + queue_start = time.time() + for batch in batches: + batcher.queue.enqueue(batch) + _ = batcher.queue.dequeue() + queue_time += (time.time() - queue_start) + timer.stop() + parallel_time = timer.elapsed / (n_batches * n_epochs) + parallel_queue_time = queue_time / (n_batches * n_epochs) + batcher.stop() - timer = Timer() - ds = DummyData((200, 200, 2000), ['u_100m', 'v_100m']) - sample_shape = (20, 20, 30) - chunk_shape = ( - 2 * sample_shape[0], - 2 * sample_shape[1], - 2 * sample_shape[-1], - ) - n_obs = 10 - max_workers = 10 - n_batches = 10 - n_epochs = 3 - ds = ds.chunk(dict(zip(['south_north', 'west_east', 'time'], chunk_shape))) - - batcher = BatchHandler( - [ds], - n_batches=n_batches, - batch_size=n_obs, - sample_shape=sample_shape, - max_workers=max_workers, - means={'u_100m': 0, 'v_100m': 0}, - stds={'u_100m': 1, 'v_100m': 1}, - ) - timer.start() - for _ in range(n_epochs): - _ = list(batcher) - timer.stop() - parallel_time = timer.elapsed / (n_batches * n_epochs) - batcher.stop() - - batcher = BatchHandler( - [ds], - n_batches=n_batches, - batch_size=n_obs, - sample_shape=sample_shape, - max_workers=1, - means={'u_100m': 0, 'v_100m': 0}, - stds={'u_100m': 1, 'v_100m': 1}, - ) - timer.start() - for _ in range(n_epochs): - _ = list(batcher) - timer.stop() - serial_time = timer.elapsed / (n_batches * n_epochs) - batcher.stop() - - print( - 'Elapsed (serial / 
parallel): {} / {}'.format( - serial_time, parallel_time + batcher = BatchHandler( + [ds], + n_batches=n_batches, + batch_size=n_obs, + sample_shape=sample_shape, + max_workers=1, + means={'u_100m': 0, 'v_100m': 0}, + stds={'u_100m': 1, 'v_100m': 1}, ) - ) - assert serial_time > parallel_time + timer.start() + queue_time = 0 + for _ in range(n_epochs): + batches = batcher.sample_batches(n_batches) + queue_start = time.time() + for batch in batches: + batcher.queue.enqueue(batch) + _ = batcher.queue.dequeue() + queue_time += time.time() - queue_start + timer.stop() + serial_time = timer.elapsed / (n_batches * n_epochs) + serial_queue_time = queue_time / (n_batches * n_epochs) + batcher.stop() + + print( + 'Elapsed total time (serial / parallel): {} / {}'.format( + serial_time, parallel_time + ) + ) + print( + 'Elapsed queue time (serial / parallel): {} / {}'.format( + serial_queue_time, parallel_queue_time + ) + ) + assert serial_time > parallel_time def test_eager_vs_lazy(): diff --git a/tests/training/test_train_gan.py b/tests/training/test_train_gan.py index 4bf974b8d..954bc3cb2 100644 --- a/tests/training/test_train_gan.py +++ b/tests/training/test_train_gan.py @@ -11,7 +11,6 @@ from sup3r.models import Sup3rGan from sup3r.preprocessing import BatchHandler, DataHandler -from sup3r.utilities.utilities import Timer TARGET_COORD = (39.01, -105.15) FEATURES = ['u_100m', 'v_100m'] @@ -169,79 +168,6 @@ def test_train(fp_gen, fp_disc, s_enhance, t_enhance, sample_shape, n_epoch=8): batch_handler.stop() -def test_train_workers(n_epoch=10): - """Test that model training with max_workers > 1 for the batch queue is - faster than for max_workers = 1.""" - - lr = 5e-5 - train_handler, val_handler = _get_handlers() - timer = Timer() - n_batches = 40 - batch_size = 40 - - Sup3rGan.seed() - model = Sup3rGan( - pytest.S_FP_GEN, - pytest.S_FP_DISC, - learning_rate=lr, - loss='MeanAbsoluteError', - ) - - with tempfile.TemporaryDirectory() as td: - batch_handler = BatchHandler( - train_containers=[train_handler], - val_containers=[val_handler], - sample_shape=(10, 10, 1), - batch_size=batch_size, - s_enhance=2, - t_enhance=1, - n_batches=n_batches, - means={'u_100m': 0, 'v_100m': 0}, - stds={'u_100m': 1, 'v_100m': 1}, - max_workers=5, - ) - - model_kwargs = { - 'input_resolution': {'spatial': '30km', 'temporal': '60min'}, - 'n_epoch': n_epoch, - 'weight_gen_advers': 0.0, - 'train_gen': True, - 'train_disc': False, - 'checkpoint_int': 10, - 'out_dir': os.path.join(td, 'test_{epoch}'), - } - - timer.start() - model.train(batch_handler, **model_kwargs) - timer.stop() - parallel_time = timer.elapsed / n_epoch - - batch_handler = BatchHandler( - train_containers=[train_handler], - val_containers=[val_handler], - sample_shape=(10, 10, 1), - batch_size=batch_size, - s_enhance=2, - t_enhance=1, - n_batches=n_batches, - means={'u_100m': 0, 'v_100m': 0}, - stds={'u_100m': 1, 'v_100m': 1}, - max_workers=1, - ) - - timer.start() - model.train(batch_handler, **model_kwargs) - timer.stop() - serial_time = timer.elapsed / n_epoch - - print( - 'Elapsed (parallel / serial): {} / {}'.format( - parallel_time, serial_time - ) - ) - assert parallel_time < serial_time - - def test_train_st_weight_update(n_epoch=2): """Test basic spatiotemporal model training with discriminators and adversarial loss updating."""
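
Note: the one-line change to get_batch() in
sup3r/preprocessing/batch_queues/abstract.py above makes the batch queue fall
back to direct sampling whenever the queue is disabled or currently empty,
instead of blocking on dequeue(). The snippet below is a minimal,
self-contained sketch of that fallback pattern, not the actual sup3r
implementation; DummyBatchQueue and its internals are illustrative
assumptions, while the attribute names mode, queue_cap, queue_len,
sample_batch, and get_batch mirror those in the diff.

import queue


class DummyBatchQueue:
    """Toy stand-in for an asynchronous batch queue (illustrative only)."""

    def __init__(self, mode='lazy', queue_cap=4):
        self.mode = mode
        self.queue_cap = queue_cap
        # Queue holding pre-built batches; unused when queue_cap == 0.
        self.queue = queue.Queue(maxsize=max(queue_cap, 1))

    @property
    def queue_len(self):
        """Number of batches currently waiting in the queue."""
        return self.queue.qsize()

    def sample_batch(self):
        """Build a batch directly from a sampler (placeholder)."""
        return 'sampled-batch'

    def get_batch(self):
        """Return a queued batch, falling back to direct sampling when the
        queue is unused (eager mode or queue_cap == 0) or momentarily empty
        (queue_len == 0), so the caller is not left waiting on dequeue."""
        if self.mode == 'eager' or self.queue_cap == 0 or self.queue_len == 0:
            return self.sample_batch()
        return self.queue.get()


# With nothing enqueued yet, get_batch() samples directly instead of blocking.
q = DummyBatchQueue()
assert q.get_batch() == 'sampled-batch'

The extra queue_len == 0 check presumably keeps training moving while worker
threads are still filling the queue, which matters given the commit's finding
that max_workers > 1 only speeds up sampling, not training.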