Skip to content

Commit

Permalink
Rename CollectiveState to StateCollection
Browse files Browse the repository at this point in the history
This better reflects the fact that CollectiveState is not a State.
  • Loading branch information
holl- committed Jan 25, 2020
1 parent ac659bf commit 4312c66
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 34 deletions.
2 changes: 1 addition & 1 deletion documentation/Scene_Format_Specification.md
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ The following content was created by running the [simpleplume.py](../demos/simpl
"module": "phi.physics.objects"
}
],
"type": "CollectiveState",
"type": "StateCollection",
"module": "phi.physics.collective"
}
}
Expand Down
43 changes: 23 additions & 20 deletions phi/physics/collective.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


@struct.definition()
class CollectiveState(struct.Struct):
class StateCollection(struct.Struct):

def __init__(self, states=None, **kwargs):
struct.Struct.__init__(self, **struct.kwargs(locals()))
Expand Down Expand Up @@ -115,58 +115,61 @@ def shape(self):
return struct.map(lambda state: state.shape, self, recursive=False)


CollectiveState = StateCollection


class CollectivePhysics(Physics):

def __init__(self):
Physics.__init__(self, {})
self.physics = {} # map from name to Physics

def step(self, collectivestate, dt=1.0, **dependent_states):
def step(self, state_collection, dt=1.0, **dependent_states):
assert len(dependent_states) == 0
if len(collectivestate) == 0:
return collectivestate
unhandled_states = list(collectivestate.states.values())
if len(state_collection) == 0:
return state_collection
unhandled_states = list(state_collection.states.values())
next_states = {}
partial_next_collectivestate = CollectiveState(next_states)
partial_next_state_collection = StateCollection(next_states)

for sweep in range(len(collectivestate)):
for sweep in range(len(state_collection)):
for state in tuple(unhandled_states):
physics = self.for_(state)
if self._all_dependencies_fulfilled(physics.blocking_dependencies, collectivestate, partial_next_collectivestate):
next_state = self.substep(state, collectivestate, dt, partial_next_collectivestate=partial_next_collectivestate)
if self._all_dependencies_fulfilled(physics.blocking_dependencies, state_collection, partial_next_state_collection):
next_state = self.substep(state, state_collection, dt, partial_next_state_collection=partial_next_state_collection)
assert next_state.name == state.name, "The state name must remain constant during step(). Caused by '%s' on state '%s'." % (type(physics).__name__, state)
next_states[next_state.name] = next_state
unhandled_states.remove(state)
partial_next_collectivestate = CollectiveState(next_states)
partial_next_state_collection = StateCollection(next_states)
if len(unhandled_states) == 0:
ordered_states = [partial_next_collectivestate[state] for state in collectivestate.states]
return partial_next_collectivestate.copied_with(states=ordered_states)
ordered_states = [partial_next_state_collection[state] for state in state_collection.states]
return partial_next_state_collection.copied_with(states=ordered_states)

# Error
errstr = 'Cyclic blocking_dependencies in simulation: %s' % unhandled_states
for state in tuple(unhandled_states):
physics = self.for_(state)
state_dict = self._gather_dependencies(physics.blocking_dependencies, collectivestate, {})
state_dict = self._gather_dependencies(physics.blocking_dependencies, state_collection, {})
errstr += '\nState "%s" with physics "%s" depends on %s' % (state, physics, state_dict)
raise AssertionError(errstr)

def substep(self, state, collectivestate, dt, override_physics=None, partial_next_collectivestate=None):
def substep(self, state, state_collection, dt, override_physics=None, partial_next_state_collection=None):
physics = self.for_(state) if override_physics is None else override_physics
# --- gather dependencies
dependent_states = {}
self._gather_dependencies(physics.dependencies, collectivestate, dependent_states)
if partial_next_collectivestate is not None:
self._gather_dependencies(physics.blocking_dependencies, partial_next_collectivestate, dependent_states)
self._gather_dependencies(physics.dependencies, state_collection, dependent_states)
if partial_next_state_collection is not None:
self._gather_dependencies(physics.blocking_dependencies, partial_next_state_collection, dependent_states)
# --- execute step ---
next_state = physics.step(state, dt, **dependent_states)
return next_state

def _gather_dependencies(self, dependencies, collectivestate, result_dict):
def _gather_dependencies(self, dependencies, state_collection, result_dict):
for statedependency in dependencies:
if statedependency.state_name is not None:
matching_states = collectivestate.find(statedependency.state_name)
matching_states = state_collection.find(statedependency.state_name)
else:
matching_states = collectivestate.all_with_tag(statedependency.tag)
matching_states = state_collection.all_with_tag(statedependency.tag)
if statedependency.single_state:
assert len(matching_states) == 1, 'Dependency %s requires 1 state but found %d' % (statedependency, len(matching_states))
value = matching_states[0]
Expand Down
2 changes: 1 addition & 1 deletion phi/physics/physics.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def name(self, name):
Names uniquely identify the system represented by this state.
All states that represent a configuration of the same system must have the same name.
Names can also be used as a shortcut to reference states (e.g. in CollectiveState or World).
Names can also be used as a shortcut to reference states (e.g. in StateCollection or World).
"""
if name is None:
return '%s_%d' % (self.__class__.__name__.lower(), id(self))
Expand Down
12 changes: 6 additions & 6 deletions phi/physics/world.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
Worlds are used to manage simulations consisting of multiple states.
A world uses a CollectiveState with CollectivePhysics to resolve state dependencies.
A world uses a StateCollection with CollectivePhysics to resolve state dependencies.
Worlds also facilitate referencing states when performing a forward simulation.
A default World, called `world` is provided for convenience.
Expand All @@ -11,7 +11,7 @@

import six

from .collective import CollectiveState
from .collective import StateCollection
from .field.effect import Gravity
from .physics import Physics, State, Static

Expand Down Expand Up @@ -122,7 +122,7 @@ def reset(self, batch_size=None, add_default_objects=True):
:param batch_size: int or None
:param add_default_objects: if True, adds defaults like Gravity
"""
self._state = CollectiveState()
self._state = StateCollection()
self.physics = self._state.default_physics()
self.observers = set()
self.batch_size = batch_size
Expand All @@ -133,7 +133,7 @@ def reset(self, batch_size=None, add_default_objects=True):
def state(self):
"""
Returns the current state of the world.
:return: CollectiveState
:return: StateCollection
"""
return self._state

Expand All @@ -148,10 +148,10 @@ def age(self):
def state(self, state):
"""
Sets the current state of the world and informs all observers.
:param state: CollectiveState
:param state: StateCollection
"""
assert state is not None
assert isinstance(state, CollectiveState)
assert isinstance(state, StateCollection)
self._state = state
for observer in self.observers:
observer(self)
Expand Down
4 changes: 2 additions & 2 deletions phi/tf/world.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ def __init__(self, physics, session, state_in, state_out, dt):
self.session = session
self.dt = dt

def step(self, collectivestate, dt=1.0, **dependent_states):
result = self.session.run(self.state_out, {self.state_in: collectivestate, self.dt: dt})
def step(self, state_collection, dt=1.0, **dependent_states):
result = self.session.run(self.state_out, {self.state_in: state_collection, self.dt: dt})
return result


Expand Down
4 changes: 2 additions & 2 deletions tests/test_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy

from phi.geom import box
from phi.physics.collective import CollectiveState
from phi.physics.collective import StateCollection
from phi.physics.domain import Domain
from phi.physics.field import CenteredGrid, manta
from phi import struct
Expand All @@ -17,7 +17,7 @@ def generate_test_structs():
return [manta.centered_grid(numpy.zeros([1,4,1])),
[('Item',)],
{'A': 'Entry A', 'Vel': manta.staggered_grid(numpy.zeros([1,5,5,2]))},
CollectiveState((Fluid(Domain([4])),))]
StateCollection((Fluid(Domain([4])),))]


class TestStruct(TestCase):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_world.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy

from phi.physics.collective import CollectiveState
from phi.physics.collective import StateCollection
from phi.physics.domain import Domain
from phi.physics.fluid import Fluid
from phi.physics.world import World
Expand All @@ -11,7 +11,7 @@
class TestWorld(TestCase):

def test_names(self):
c = CollectiveState()
c = StateCollection()
self.assertEqual(c.states, {})
c = c.state_added(Fluid(Domain([64])))
try:
Expand Down

0 comments on commit 4312c66

Please sign in to comment.