diff --git a/documentation/Scene_Format_Specification.md b/documentation/Scene_Format_Specification.md index b23906753..f207ed985 100644 --- a/documentation/Scene_Format_Specification.md +++ b/documentation/Scene_Format_Specification.md @@ -124,7 +124,7 @@ The following content was created by running the [simpleplume.py](../demos/simpl "module": "phi.physics.objects" } ], - "type": "CollectiveState", + "type": "StateCollection", "module": "phi.physics.collective" } } diff --git a/phi/physics/collective.py b/phi/physics/collective.py index cc29f987c..f59addbdc 100644 --- a/phi/physics/collective.py +++ b/phi/physics/collective.py @@ -5,7 +5,7 @@ @struct.definition() -class CollectiveState(struct.Struct): +class StateCollection(struct.Struct): def __init__(self, states=None, **kwargs): struct.Struct.__init__(self, **struct.kwargs(locals())) @@ -115,58 +115,61 @@ def shape(self): return struct.map(lambda state: state.shape, self, recursive=False) +CollectiveState = StateCollection + + class CollectivePhysics(Physics): def __init__(self): Physics.__init__(self, {}) self.physics = {} # map from name to Physics - def step(self, collectivestate, dt=1.0, **dependent_states): + def step(self, state_collection, dt=1.0, **dependent_states): assert len(dependent_states) == 0 - if len(collectivestate) == 0: - return collectivestate - unhandled_states = list(collectivestate.states.values()) + if len(state_collection) == 0: + return state_collection + unhandled_states = list(state_collection.states.values()) next_states = {} - partial_next_collectivestate = CollectiveState(next_states) + partial_next_state_collection = StateCollection(next_states) - for sweep in range(len(collectivestate)): + for sweep in range(len(state_collection)): for state in tuple(unhandled_states): physics = self.for_(state) - if self._all_dependencies_fulfilled(physics.blocking_dependencies, collectivestate, partial_next_collectivestate): - next_state = self.substep(state, collectivestate, dt, partial_next_collectivestate=partial_next_collectivestate) + if self._all_dependencies_fulfilled(physics.blocking_dependencies, state_collection, partial_next_state_collection): + next_state = self.substep(state, state_collection, dt, partial_next_state_collection=partial_next_state_collection) assert next_state.name == state.name, "The state name must remain constant during step(). Caused by '%s' on state '%s'." % (type(physics).__name__, state) next_states[next_state.name] = next_state unhandled_states.remove(state) - partial_next_collectivestate = CollectiveState(next_states) + partial_next_state_collection = StateCollection(next_states) if len(unhandled_states) == 0: - ordered_states = [partial_next_collectivestate[state] for state in collectivestate.states] - return partial_next_collectivestate.copied_with(states=ordered_states) + ordered_states = [partial_next_state_collection[state] for state in state_collection.states] + return partial_next_state_collection.copied_with(states=ordered_states) # Error errstr = 'Cyclic blocking_dependencies in simulation: %s' % unhandled_states for state in tuple(unhandled_states): physics = self.for_(state) - state_dict = self._gather_dependencies(physics.blocking_dependencies, collectivestate, {}) + state_dict = self._gather_dependencies(physics.blocking_dependencies, state_collection, {}) errstr += '\nState "%s" with physics "%s" depends on %s' % (state, physics, state_dict) raise AssertionError(errstr) - def substep(self, state, collectivestate, dt, override_physics=None, partial_next_collectivestate=None): + def substep(self, state, state_collection, dt, override_physics=None, partial_next_state_collection=None): physics = self.for_(state) if override_physics is None else override_physics # --- gather dependencies dependent_states = {} - self._gather_dependencies(physics.dependencies, collectivestate, dependent_states) - if partial_next_collectivestate is not None: - self._gather_dependencies(physics.blocking_dependencies, partial_next_collectivestate, dependent_states) + self._gather_dependencies(physics.dependencies, state_collection, dependent_states) + if partial_next_state_collection is not None: + self._gather_dependencies(physics.blocking_dependencies, partial_next_state_collection, dependent_states) # --- execute step --- next_state = physics.step(state, dt, **dependent_states) return next_state - def _gather_dependencies(self, dependencies, collectivestate, result_dict): + def _gather_dependencies(self, dependencies, state_collection, result_dict): for statedependency in dependencies: if statedependency.state_name is not None: - matching_states = collectivestate.find(statedependency.state_name) + matching_states = state_collection.find(statedependency.state_name) else: - matching_states = collectivestate.all_with_tag(statedependency.tag) + matching_states = state_collection.all_with_tag(statedependency.tag) if statedependency.single_state: assert len(matching_states) == 1, 'Dependency %s requires 1 state but found %d' % (statedependency, len(matching_states)) value = matching_states[0] diff --git a/phi/physics/physics.py b/phi/physics/physics.py index 64ba77de9..36036754f 100644 --- a/phi/physics/physics.py +++ b/phi/physics/physics.py @@ -45,7 +45,7 @@ def name(self, name): Names uniquely identify the system represented by this state. All states that represent a configuration of the same system must have the same name. -Names can also be used as a shortcut to reference states (e.g. in CollectiveState or World). +Names can also be used as a shortcut to reference states (e.g. in StateCollection or World). """ if name is None: return '%s_%d' % (self.__class__.__name__.lower(), id(self)) diff --git a/phi/physics/world.py b/phi/physics/world.py index 6abd7709b..82f551213 100644 --- a/phi/physics/world.py +++ b/phi/physics/world.py @@ -1,6 +1,6 @@ """ Worlds are used to manage simulations consisting of multiple states. -A world uses a CollectiveState with CollectivePhysics to resolve state dependencies. +A world uses a StateCollection with CollectivePhysics to resolve state dependencies. Worlds also facilitate referencing states when performing a forward simulation. A default World, called `world` is provided for convenience. @@ -11,7 +11,7 @@ import six -from .collective import CollectiveState +from .collective import StateCollection from .field.effect import Gravity from .physics import Physics, State, Static @@ -122,7 +122,7 @@ def reset(self, batch_size=None, add_default_objects=True): :param batch_size: int or None :param add_default_objects: if True, adds defaults like Gravity """ - self._state = CollectiveState() + self._state = StateCollection() self.physics = self._state.default_physics() self.observers = set() self.batch_size = batch_size @@ -133,7 +133,7 @@ def reset(self, batch_size=None, add_default_objects=True): def state(self): """ Returns the current state of the world. - :return: CollectiveState + :return: StateCollection """ return self._state @@ -148,10 +148,10 @@ def age(self): def state(self, state): """ Sets the current state of the world and informs all observers. - :param state: CollectiveState + :param state: StateCollection """ assert state is not None - assert isinstance(state, CollectiveState) + assert isinstance(state, StateCollection) self._state = state for observer in self.observers: observer(self) diff --git a/phi/tf/world.py b/phi/tf/world.py index dd59b8e53..12e9aba14 100644 --- a/phi/tf/world.py +++ b/phi/tf/world.py @@ -63,8 +63,8 @@ def __init__(self, physics, session, state_in, state_out, dt): self.session = session self.dt = dt - def step(self, collectivestate, dt=1.0, **dependent_states): - result = self.session.run(self.state_out, {self.state_in: collectivestate, self.dt: dt}) + def step(self, state_collection, dt=1.0, **dependent_states): + result = self.session.run(self.state_out, {self.state_in: state_collection, self.dt: dt}) return result diff --git a/tests/test_struct.py b/tests/test_struct.py index ca02f52a7..36c5066e2 100644 --- a/tests/test_struct.py +++ b/tests/test_struct.py @@ -3,7 +3,7 @@ import numpy from phi.geom import box -from phi.physics.collective import CollectiveState +from phi.physics.collective import StateCollection from phi.physics.domain import Domain from phi.physics.field import CenteredGrid, manta from phi import struct @@ -17,7 +17,7 @@ def generate_test_structs(): return [manta.centered_grid(numpy.zeros([1,4,1])), [('Item',)], {'A': 'Entry A', 'Vel': manta.staggered_grid(numpy.zeros([1,5,5,2]))}, - CollectiveState((Fluid(Domain([4])),))] + StateCollection((Fluid(Domain([4])),))] class TestStruct(TestCase): diff --git a/tests/test_world.py b/tests/test_world.py index 7358fbaf2..6adecbd38 100644 --- a/tests/test_world.py +++ b/tests/test_world.py @@ -2,7 +2,7 @@ import numpy -from phi.physics.collective import CollectiveState +from phi.physics.collective import StateCollection from phi.physics.domain import Domain from phi.physics.fluid import Fluid from phi.physics.world import World @@ -11,7 +11,7 @@ class TestWorld(TestCase): def test_names(self): - c = CollectiveState() + c = StateCollection() self.assertEqual(c.states, {}) c = c.state_added(Fluid(Domain([64]))) try: