Skip to content

Commit

Permalink
Add wrappers.vector.TransformObs/Action single obs/action space arg…
Browse files Browse the repository at this point in the history
…ument (#1288)

Co-authored-by: Howard <[email protected]>
  • Loading branch information
howardh and howardh authored Jan 12, 2025
1 parent c6c5815 commit 75bd3be
Show file tree
Hide file tree
Showing 4 changed files with 208 additions and 4 deletions.
18 changes: 16 additions & 2 deletions gymnasium/wrappers/vector/vectorize_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from gymnasium import Space
from gymnasium.core import ActType, Env
from gymnasium.logger import warn
from gymnasium.vector import VectorActionWrapper, VectorEnv
from gymnasium.vector.utils import batch_space, concatenate, create_empty_array, iterate
from gymnasium.wrappers import transform_action
Expand Down Expand Up @@ -61,18 +62,31 @@ def __init__(
env: VectorEnv,
func: Callable[[ActType], Any],
action_space: Space | None = None,
single_action_space: Space | None = None,
):
"""Constructor for the lambda action wrapper.
Args:
env: The vector environment to wrap
func: A function that will transform an action. If this transformed action is outside the action space of ``env.action_space`` then provide an ``action_space``.
action_space: The action spaces of the wrapper, if None, then it is assumed the same as ``env.action_space``.
action_space: The action spaces of the wrapper. If None, then it is computed from ``single_action_space``. If ``single_action_space`` is not provided either, then it is assumed to be the same as ``env.action_space``.
single_action_space: The action space of the non-vectorized environment. If None, then it is assumed the same as ``env.single_action_space``.
"""
super().__init__(env)

if action_space is not None:
if action_space is None:
if single_action_space is not None:
self.single_action_space = single_action_space
self.action_space = batch_space(single_action_space, self.num_envs)
else:
self.action_space = action_space
if single_action_space is not None:
self.single_action_space = single_action_space
# TODO: We could compute single_action_space from the action_space if only the latter is provided and avoid the warning below.
if self.action_space != batch_space(self.single_action_space, self.num_envs):
warn(
f"For {env}, the action space and the batched single action space don't match as expected, action_space={env.action_space}, batched single_action_space={batch_space(self.single_action_space, self.num_envs)}"
)

self.func = func

Expand Down
21 changes: 19 additions & 2 deletions gymnasium/wrappers/vector/vectorize_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,35 @@ def __init__(
env: VectorEnv,
func: Callable[[ObsType], Any],
observation_space: Space | None = None,
single_observation_space: Space | None = None,
):
"""Constructor for the transform observation wrapper.
Args:
env: The vector environment to wrap
func: A function that will transform the vector observation. If this transformed observation is outside the observation space of ``env.observation_space`` then provide an ``observation_space``.
observation_space: The observation spaces of the wrapper, if None, then it is assumed the same as ``env.observation_space``.
observation_space: The observation spaces of the wrapper. If None, then it is computed from ``single_observation_space``. If ``single_observation_space`` is not provided either, then it is assumed to be the same as ``env.observation_space``.
single_observation_space: The observation space of the non-vectorized environment. If None, then it is assumed the same as ``env.single_observation_space``.
"""
super().__init__(env)

if observation_space is not None:
if observation_space is None:
if single_observation_space is not None:
self.single_observation_space = single_observation_space
self.observation_space = batch_space(
single_observation_space, self.num_envs
)
else:
self.observation_space = observation_space
if single_observation_space is not None:
self._single_observation_space = single_observation_space
# TODO: We could compute single_observation_space from the observation_space if only the latter is provided and avoid the warning below.
if self.observation_space != batch_space(
self.single_observation_space, self.num_envs
):
warn(
f"For {env}, the observation space and the batched single observation space don't match as expected, observation_space={env.observation_space}, batched single_observation_space={batch_space(self.single_observation_space, self.num_envs)}"
)

self.func = func

Expand Down
73 changes: 73 additions & 0 deletions tests/wrappers/vector/test_transform_action.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""Test suite for vector TransformAction wrapper."""

import numpy as np
import pytest

from gymnasium import spaces, wrappers
from gymnasium.vector import SyncVectorEnv
from tests.testing_env import GenericTestEnv


def create_env():
return GenericTestEnv(
action_space=spaces.Box(
low=np.array([0, -10, -5], dtype=np.float32),
high=np.array([10, -5, 10], dtype=np.float32),
)
)


def test_action_space_from_single_action_space(
n_envs: int = 5,
):
vec_env = SyncVectorEnv([create_env for _ in range(n_envs)])
vec_env = wrappers.vector.TransformAction(
vec_env,
func=lambda x: x + 100,
single_action_space=spaces.Box(
low=np.array([0, -10, -5], dtype=np.float32) + 100,
high=np.array([10, -5, 10], dtype=np.float32) + 100,
),
)

# Check action space
assert isinstance(vec_env.action_space, spaces.Box)
assert vec_env.action_space.shape == (n_envs, 3)
assert vec_env.action_space.dtype == np.float32
assert (
vec_env.action_space.low == np.array([[100, 90, 95]] * n_envs, dtype=np.float32)
).all()
assert (
vec_env.action_space.high
== np.array([[110, 95, 110]] * n_envs, dtype=np.float32)
).all()

# Check single action space
assert isinstance(vec_env.single_action_space, spaces.Box)
assert vec_env.single_action_space.shape == (3,)
assert vec_env.single_action_space.dtype == np.float32
assert (
vec_env.single_action_space.low == np.array([100, 90, 95], dtype=np.float32)
).all()
assert (
vec_env.single_action_space.high == np.array([110, 95, 110], dtype=np.float32)
).all()


def test_warning_on_mismatched_single_action_space(
n_envs: int = 2,
):
vec_env = SyncVectorEnv([create_env for _ in range(n_envs)])
# We only specify action_space without single_action_space, so single_action_space inherits its value from the wrapped env which would not match. This mismatch should give us a warning.
with pytest.warns(
Warning,
match=r"the action space and the batched single action space don't match as expected",
):
vec_env = wrappers.vector.TransformAction(
vec_env,
func=lambda x: x + 100,
action_space=spaces.Box(
low=np.array([[0, -10, -5]] * n_envs, dtype=np.float32) + 100,
high=np.array([[10, -5, 10]] * n_envs, dtype=np.float32) + 100,
),
)
100 changes: 100 additions & 0 deletions tests/wrappers/vector/test_transform_observation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""Test suite for vector TransformObservation wrapper."""

import numpy as np
import pytest

from gymnasium import spaces, wrappers
from gymnasium.vector import SyncVectorEnv
from tests.testing_env import GenericTestEnv


def create_env():
return GenericTestEnv(
observation_space=spaces.Box(
low=np.array([0, -10, -5], dtype=np.float32),
high=np.array([10, -5, 10], dtype=np.float32),
)
)


def test_transform(n_envs: int = 2):
vec_env = SyncVectorEnv([create_env for _ in range(n_envs)])
vec_env = wrappers.vector.TransformObservation(
vec_env,
func=lambda x: x + 100,
single_observation_space=spaces.Box(
low=np.array([0, -10, -5], dtype=np.float32),
high=np.array([10, -5, 10], dtype=np.float32),
),
)

obs, _ = vec_env.reset(seed=123)
vec_env.observation_space.seed(123)
vec_env.action_space.seed(123)

assert (obs >= np.array([100, 90, 95], dtype=np.float32)).all()
assert (obs <= np.array([110, 95, 110], dtype=np.float32)).all()

obs, *_ = vec_env.step(vec_env.action_space.sample())

assert (obs >= np.array([100, 90, 95], dtype=np.float32)).all()
assert (obs <= np.array([110, 95, 110], dtype=np.float32)).all()


def test_observation_space_from_single_observation_space(
n_envs: int = 5,
):
vec_env = SyncVectorEnv([create_env for _ in range(n_envs)])
vec_env = wrappers.vector.TransformObservation(
vec_env,
func=lambda x: x + 100,
single_observation_space=spaces.Box(
low=np.array([0, -10, -5], dtype=np.float32) + 100,
high=np.array([10, -5, 10], dtype=np.float32) + 100,
),
)

# Check observation space
assert isinstance(vec_env.observation_space, spaces.Box)
assert vec_env.observation_space.shape == (n_envs, 3)
assert vec_env.observation_space.dtype == np.float32
assert (
vec_env.observation_space.low
== np.array([[100, 90, 95]] * n_envs, dtype=np.float32)
).all()
assert (
vec_env.observation_space.high
== np.array([[110, 95, 110]] * n_envs, dtype=np.float32)
).all()

# Check single observation space
assert isinstance(vec_env.single_observation_space, spaces.Box)
assert vec_env.single_observation_space.shape == (3,)
assert vec_env.single_observation_space.dtype == np.float32
assert (
vec_env.single_observation_space.low
== np.array([100, 90, 95], dtype=np.float32)
).all()
assert (
vec_env.single_observation_space.high
== np.array([110, 95, 110], dtype=np.float32)
).all()


def test_warning_on_mismatched_single_observation_space(
n_envs: int = 2,
):
vec_env = SyncVectorEnv([create_env for _ in range(n_envs)])
# We only specify observation_space without single_observation_space, so single_observation_space inherits its value from the wrapped env which would not match. This mismatch should give us a warning.
with pytest.warns(
Warning,
match=r"the observation space and the batched single observation space don't match as expected",
):
vec_env = wrappers.vector.TransformObservation(
vec_env,
func=lambda x: x + 100,
observation_space=spaces.Box(
low=np.array([[0, -10, -5]] * n_envs, dtype=np.float32) + 100,
high=np.array([[10, -5, 10]] * n_envs, dtype=np.float32) + 100,
),
)

0 comments on commit 75bd3be

Please sign in to comment.