Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
howardh committed Jan 9, 2025
1 parent 702b21d commit 2d98fc8
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 45 deletions.
23 changes: 9 additions & 14 deletions gymnasium/wrappers/vector/vectorize_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from gymnasium import Space
from gymnasium.core import ActType, Env
from gymnasium.logger import warn
from gymnasium.vector import VectorActionWrapper, VectorEnv
from gymnasium.vector.utils import batch_space, concatenate, create_empty_array, iterate
from gymnasium.wrappers import transform_action
Expand Down Expand Up @@ -73,32 +74,26 @@ def __init__(
"""
super().__init__(env)

self._single_action_space_error = None
self._single_action_space = self.env.single_action_space
if action_space is None:
if single_action_space is not None:
self.single_action_space = single_action_space
self.action_space = batch_space(single_action_space, self.num_envs)
self._single_action_space = single_action_space
else:
self.action_space = action_space
if single_action_space is None:
self._single_action_space_error = "`single_action_space` not defined. A new action space was provided to the TransformAction wrapper, but not the single action space."
else:
self._single_action_space = single_action_space
if single_action_space is not None:
self.single_action_space = single_action_space
# TODO: We could compute single_action_space from the action_space if only the latter is provided and avoid the warning below.
if self.action_space != batch_space(self.single_action_space, self.num_envs):
warn(
"The action space and the batched single action space don't match as expected."
)

self.func = func

def actions(self, actions: ActType) -> ActType:
"""Applies the :attr:`func` to the actions."""
return self.func(actions)

@property
def single_action_space(self) -> Space:
"""The single action space of the environment."""
if self._single_action_space_error is not None:
raise AttributeError(self._single_action_space_error)
return self._single_action_space


class VectorizeTransformAction(VectorActionWrapper):
"""Vectorizes a single-agent transform action wrapper for vector environments.
Expand Down
23 changes: 9 additions & 14 deletions gymnasium/wrappers/vector/vectorize_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,35 +69,30 @@ def __init__(
"""
super().__init__(env)

self._single_observation_space_error = None
self._single_observation_space = self.env.single_observation_space
if observation_space is None:
if single_observation_space is not None:
self.single_observation_space = single_observation_space
self.observation_space = batch_space(
single_observation_space, self.num_envs
)
self._single_observation_space = single_observation_space
else:
self.observation_space = observation_space
if single_observation_space is None:
# TODO: We could compute this from the observation_space.
self._single_observation_space_error = "`single_observation_space` not defined. A new observation space was provided to the TransformObservation wrapper, but not the single observation space."
else:
if single_observation_space is not None:
self._single_observation_space = single_observation_space
# TODO: We could compute single_observation_space from the observation_space if only the latter is provided and avoid the warning below.
if self.observation_space != batch_space(
self.single_observation_space, self.num_envs
):
warn(
"The observation space and the batched single observation space don't match as expected."
)

self.func = func

def observations(self, observations: ObsType) -> ObsType:
"""Apply function to the vector observation."""
return self.func(observations)

@property
def single_observation_space(self) -> Space:
"""Returns the single observation space."""
if self._single_observation_space_error is not None:
raise AttributeError(self._single_observation_space_error)
return self._single_observation_space


class VectorizeTransformObservation(VectorObservationWrapper):
"""Vectorizes a single-agent transform observation wrapper for vector environments.
Expand Down
28 changes: 11 additions & 17 deletions tests/wrappers/vector/test_transform_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,23 +81,17 @@ def test_observation_space_from_single_observation_space(
).all()


def test_error_on_unspecified_single_observation_space(
def test_warning_on_mismatched_single_observation_space(
n_envs: int = 5,
):
vec_env = SyncVectorEnv([create_env for _ in range(n_envs)])
vec_env = wrappers.vector.TransformObservation(
vec_env,
func=lambda x: x + 100,
observation_space=spaces.Box(
low=np.array([[0, -10, -5]] * n_envs, dtype=np.float32) + 100,
high=np.array([[10, -5, 10]] * n_envs, dtype=np.float32) + 100,
),
)

# Environment should still work normally
obs, _ = vec_env.reset()
obs, *_ = vec_env.step(vec_env.action_space.sample())

# But if we try to access the single_observation_space, it should error
with pytest.raises(AttributeError):
vec_env.single_observation_space
# We only specify observation_space without single_observation_space, so single_observation_space inherits its value from the wrapped env which would not match. This mismatch should give us a warning.
with pytest.warns(Warning):
vec_env = wrappers.vector.TransformObservation(
vec_env,
func=lambda x: x + 100,
observation_space=spaces.Box(
low=np.array([[0, -10, -5]] * n_envs, dtype=np.float32) + 100,
high=np.array([[10, -5, 10]] * n_envs, dtype=np.float32) + 100,
),
)

0 comments on commit 2d98fc8

Please sign in to comment.