Skip to content

Commit

Permalink
Switch observations from FrozenDicts to plain Dicts (#567)
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrew-Luo1 authored Dec 6, 2024
1 parent b21017a commit 49f03c0
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 25 deletions.
13 changes: 6 additions & 7 deletions brax/envs/fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

from brax import base
from brax.envs.base import PipelineEnv, State
from flax import core
import jax
from jax import numpy as jp

Expand Down Expand Up @@ -76,11 +75,11 @@ def reset(self, rng: jax.Array) -> State:
}

if self._obs_mode == ObservationMode.DICT_STATE:
obs = core.FrozenDict(obs)
obs = obs
elif self._obs_mode == ObservationMode.DICT_PIXELS:
obs = core.FrozenDict(pixels)
obs = pixels
elif self._obs_mode == ObservationMode.DICT_PIXELS_STATE:
obs = core.FrozenDict({**obs, **pixels})
obs = {**obs, **pixels}
elif self._obs_mode == ObservationMode.NDARRAY:
obs = obs['state']

Expand All @@ -106,11 +105,11 @@ def step(self, state: State, action: jax.Array) -> State:
}

if self._obs_mode == ObservationMode.DICT_STATE:
obs = core.FrozenDict(obs)
obs = obs
elif self._obs_mode == ObservationMode.DICT_PIXELS:
obs = core.FrozenDict(pixels)
obs = pixels
elif self._obs_mode == ObservationMode.DICT_PIXELS_STATE:
obs = core.FrozenDict({**obs, **pixels})
obs = {**obs, **pixels}
elif self._obs_mode == ObservationMode.NDARRAY:
obs = obs['state']

Expand Down
16 changes: 2 additions & 14 deletions brax/training/agents/ppo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ def _random_translate_pixels(
Returns:
A dictionary of observations with translated pixels
"""
obs = core.FrozenDict(obs)

@jax.vmap
def rt_all_views(
Expand Down Expand Up @@ -127,21 +126,10 @@ def rt_view(
def _remove_pixels(
obs: Union[jnp.ndarray, Mapping[str, jax.Array]],
) -> Union[jnp.ndarray, Mapping[str, jax.Array]]:
"""Removes pixel observations from the observation dict.
FrozenDicts are used to avoid incorrect gradients.
Args:
obs: a dictionary of observations
Returns:
A dictionary of observations with pixel observations removed
"""
"""Removes pixel observations from the observation dict."""
if not isinstance(obs, Mapping):
return obs
return core.FrozenDict(
{k: v for k, v in obs.items() if not k.startswith('pixels/')}
)
return {k: v for k, v in obs.items() if not k.startswith('pixels/')}


def train(
Expand Down
6 changes: 2 additions & 4 deletions brax/training/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,12 +370,11 @@ def make_policy_network_vision(
)

def apply(processor_params, policy_params, obs):
obs = core.FrozenDict(obs)
if state_obs_key:
state_obs = preprocess_observations_fn(
obs[state_obs_key], normalizer_select(processor_params, state_obs_key)
)
obs = obs.copy({state_obs_key: state_obs})
obs = core.copy(obs, {state_obs_key: state_obs})
return module.apply(policy_params, obs)

dummy_obs = {
Expand Down Expand Up @@ -405,12 +404,11 @@ def make_value_network_vision(
)

def apply(processor_params, policy_params, obs):
obs = core.FrozenDict(obs)
if state_obs_key:
state_obs = preprocess_observations_fn(
obs[state_obs_key], normalizer_select(processor_params, state_obs_key)
)
obs = obs.copy({state_obs_key: state_obs})
obs = core.copy(obs, {state_obs_key: state_obs})
return jnp.squeeze(value_module.apply(policy_params, obs), axis=-1)

dummy_obs = {
Expand Down

0 comments on commit 49f03c0

Please sign in to comment.