From 8ab282c6ac7e1676e0626e0913dcb4bd8aee87c3 Mon Sep 17 00:00:00 2001 From: Newton Sander Date: Mon, 24 Jul 2023 17:12:55 +0200 Subject: [PATCH] Cleaning up scenario __init__ --- docs/examples/plot_full_scenario.py | 12 ++++++------ src/neurotechdevkit/__init__.py | 2 +- src/neurotechdevkit/scenarios/_base.py | 18 +----------------- src/neurotechdevkit/scenarios/_scenario_1.py | 16 ++-------------- src/neurotechdevkit/scenarios/_scenario_2.py | 16 ++-------------- tests/neurotechdevkit/scenarios/test_base.py | 3 --- .../scenarios/test_materials.py | 15 ++++----------- .../neurotechdevkit/scenarios/test_results.py | 4 ++-- 8 files changed, 18 insertions(+), 68 deletions(-) diff --git a/docs/examples/plot_full_scenario.py b/docs/examples/plot_full_scenario.py index d49b9755..28697f9f 100644 --- a/docs/examples/plot_full_scenario.py +++ b/docs/examples/plot_full_scenario.py @@ -67,9 +67,7 @@ class FullScenario(Scenario2D): num_points=1000, ) ] - - def __init__(self, scenario_id: str, material_outline_upsample_factor: int = 8): - super().__init__(scenario_id, material_outline_upsample_factor) + material_outline_upsample_factor = 8 def compile_problem(self, center_frequency) -> stride.Problem: """The problem definition for the scenario.""" @@ -149,15 +147,17 @@ def _fill_mask(mask, start, end, dx): # %% -# ## Creating the scenario -scenario = FullScenario(FullScenario.scenario_id) +# ## Running the scenario + +scenario = FullScenario() # %% # ## Rendering the scenario layout scenario.render_layout() # %% -# ## Running the scenario +# ## Rendering the simulation +scenario.compile_problem(center_frequency=5e5) result = scenario.simulate_steady_state() assert isinstance(result, SteadyStateResult2D) result.render_steady_state_amplitudes(show_material_outlines=False) diff --git a/src/neurotechdevkit/__init__.py b/src/neurotechdevkit/__init__.py index 194ee22b..3b508dfb 100644 --- a/src/neurotechdevkit/__init__.py +++ b/src/neurotechdevkit/__init__.py @@ -60,7 +60,7 @@ def make(scenario_id: str) -> scenarios.Scenario: f"Scenario '{scenario_id}' does not exist. Please refer to documentation" " for the list of provided scenarios." ) - return _scenario_map[scenario_id](scenario_id=scenario_id) # type: ignore + return _scenario_map[scenario_id]() # type: ignore _scenario_map = { diff --git a/src/neurotechdevkit/scenarios/_base.py b/src/neurotechdevkit/scenarios/_base.py index 3bca2c76..28909e9a 100644 --- a/src/neurotechdevkit/scenarios/_base.py +++ b/src/neurotechdevkit/scenarios/_base.py @@ -75,25 +75,9 @@ class Scenario(abc.ABC): target: Target scenario_id: str - material_outline_upsample_factor: int slice_axis: int slice_position: float - - def __init__( - self, - scenario_id: str, - material_outline_upsample_factor: int = 16, - ): - """ - Initialize a new scenario. - - Args: - scenario_id (str): An identifier for the scenario. - material_outline_upsample_factor (int, optional): The factor by which to - upsample the material outline. Defaults to 16. - """ - self.scenario_id = scenario_id - self.material_outline_upsample_factor = material_outline_upsample_factor + material_outline_upsample_factor: int = 16 def render_layout( self, diff --git a/src/neurotechdevkit/scenarios/_scenario_1.py b/src/neurotechdevkit/scenarios/_scenario_1.py index 48c887e6..ab9bae4c 100644 --- a/src/neurotechdevkit/scenarios/_scenario_1.py +++ b/src/neurotechdevkit/scenarios/_scenario_1.py @@ -100,13 +100,7 @@ class Scenario1_2D(Scenario1, Scenario2D): ) ] origin = np.array([0.0, -0.035]) - - def __init__(self, scenario_id: str, material_outline_upsample_factor: int = 8): - """Instantiate Scenario1 with overwritten material_outline_upsample_factor.""" - super().__init__( - scenario_id=scenario_id, - material_outline_upsample_factor=material_outline_upsample_factor, - ) + material_outline_upsample_factor = 8 def compile_problem(self, center_frequency: float) -> stride.Problem: """ @@ -175,13 +169,7 @@ class Scenario1_3D(Scenario1, Scenario3D): ) slice_axis = 1 slice_position = 0.0 - - def __init__(self, scenario_id, material_outline_upsample_factor: int = 8): - """Instantiate Scenario1 with overwritten material_outline_upsample_factor.""" - super().__init__( - scenario_id=scenario_id, - material_outline_upsample_factor=material_outline_upsample_factor, - ) + material_outline_upsample_factor = 8 def compile_problem(self, center_frequency: float) -> stride.Problem: """ diff --git a/src/neurotechdevkit/scenarios/_scenario_2.py b/src/neurotechdevkit/scenarios/_scenario_2.py index 6d4822a6..8b605d1c 100644 --- a/src/neurotechdevkit/scenarios/_scenario_2.py +++ b/src/neurotechdevkit/scenarios/_scenario_2.py @@ -101,13 +101,7 @@ class Scenario2_2D(Scenario2, Scenario2D): num_points=1000, ) ] - - def __init__(self, scenario_id, material_outline_upsample_factor: int = 4): - """Instantiate Scenario2 with overwritten material_outline_upsample_factor.""" - super().__init__( - scenario_id=scenario_id, - material_outline_upsample_factor=material_outline_upsample_factor, - ) + material_outline_upsample_factor = 4 def compile_problem(self, center_frequency: float) -> stride.Problem: """ @@ -183,13 +177,7 @@ class Scenario2_3D(Scenario2, Scenario3D): ) slice_axis = 2 slice_position = 0.0 - - def __init__(self, scenario_id, material_outline_upsample_factor: int = 4): - """Instantiate Scenario2 with overwritten material_outline_upsample_factor.""" - super().__init__( - scenario_id=scenario_id, - material_outline_upsample_factor=material_outline_upsample_factor, - ) + material_outline_upsample_factor = 4 def compile_problem(self, center_frequency: float) -> stride.Problem: """ diff --git a/tests/neurotechdevkit/scenarios/test_base.py b/tests/neurotechdevkit/scenarios/test_base.py index 336543f4..4df1be3d 100644 --- a/tests/neurotechdevkit/scenarios/test_base.py +++ b/tests/neurotechdevkit/scenarios/test_base.py @@ -44,9 +44,6 @@ class ScenarioBaseTester(Scenario): def __init__(self): self.problem = self._compile_problem(center_frequency=5e5) - super().__init__( - scenario_id=self.scenario_id, material_outline_upsample_factor=3 - ) def _compile_problem(self, center_frequency: float) -> stride.Problem: extent = np.array([2.0, 3.0, 4.0]) diff --git a/tests/neurotechdevkit/scenarios/test_materials.py b/tests/neurotechdevkit/scenarios/test_materials.py index 6e3e21b2..a04a15b3 100644 --- a/tests/neurotechdevkit/scenarios/test_materials.py +++ b/tests/neurotechdevkit/scenarios/test_materials.py @@ -13,17 +13,10 @@ def compare_structs(struct1: Struct, struct2: Struct): assert struct1.render_color == struct2.render_color -class BaseScenario(Scenario2D): - """A scenario for testing the materials module.""" - - def __init__(self): - super().__init__(scenario_id="test", material_outline_upsample_factor=16) - - def test_custom_material_property(): """Test that a custom material property is used.""" - class ScenarioWithCustomMaterialProperties(BaseScenario): + class ScenarioWithCustomMaterialProperties(Scenario2D): material_layers = ["brain"] material_properties = { "brain": Material(vp=1600.0, rho=1100.0, alpha=0.0, render_color="#2E86AB") @@ -43,7 +36,7 @@ class ScenarioWithCustomMaterialProperties(BaseScenario): def test_new_material(): """Test that a new material is used.""" - class ScenarioWithCustomMaterial(BaseScenario): + class ScenarioWithCustomMaterial(Scenario2D): material_layers = ["brain", "eye"] material_properties = { "eye": Material(vp=1600.0, rho=1100.0, alpha=0.0, render_color="#2E86AB") @@ -64,7 +57,7 @@ class ScenarioWithCustomMaterial(BaseScenario): def test_material_absorption_is_calculated(): """Test that the material absorption is calculated for a frequency !=500e3.""" - class ScenarioWithBrainMaterial(BaseScenario): + class ScenarioWithBrainMaterial(Scenario2D): material_layers = ["brain"] material_properties = {} @@ -79,7 +72,7 @@ class ScenarioWithBrainMaterial(BaseScenario): def test_unknown_material_without_properties(): """Test that an unknown material without properties raises an error.""" - class ScenarioWithCustomMaterial(BaseScenario): + class ScenarioWithCustomMaterial(Scenario2D): material_layers = ["unknown_material"] material_properties = {} diff --git a/tests/neurotechdevkit/scenarios/test_results.py b/tests/neurotechdevkit/scenarios/test_results.py index 8f9e11e4..866d44e0 100644 --- a/tests/neurotechdevkit/scenarios/test_results.py +++ b/tests/neurotechdevkit/scenarios/test_results.py @@ -108,7 +108,7 @@ def pulsed_data_2d(): @pytest.fixture def a_test_scenario_2d(): """A real 2D scenario that can be saved to disk and reloaded.""" - scenario = scenarios.Scenario1_2D(scenario_id=scenarios.Scenario1_2D.scenario_id) + scenario = scenarios.Scenario1_2D() scenario.add_source( sources.FocusedSource2D( position=np.array([0.02, 0.02]), @@ -125,7 +125,7 @@ def a_test_scenario_2d(): @pytest.fixture def a_test_scenario_3d(): """A real 3D scenario that can be saved to disk and reloaded.""" - scenario = scenarios.Scenario1_3D(scenario_id=scenarios.Scenario1_3D.scenario_id) + scenario = scenarios.Scenario1_3D() scenario.add_source( sources.FocusedSource3D( position=np.array([0.02, 0.02, 0.0]),