Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add wrappers.vector.TransformObs/Action single obs/action space argument #1288

Merged
merged 5 commits into from
Jan 12, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions gymnasium/wrappers/vector/vectorize_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from gymnasium import Space
from gymnasium.core import ActType, Env
from gymnasium.logger import warn
from gymnasium.vector import VectorActionWrapper, VectorEnv
from gymnasium.vector.utils import batch_space, concatenate, create_empty_array, iterate
from gymnasium.wrappers import transform_action
Expand Down Expand Up @@ -61,18 +62,31 @@ def __init__(
env: VectorEnv,
func: Callable[[ActType], Any],
action_space: Space | None = None,
single_action_space: Space | None = None,
):
"""Constructor for the lambda action wrapper.

Args:
env: The vector environment to wrap
func: A function that will transform an action. If this transformed action is outside the action space of ``env.action_space`` then provide an ``action_space``.
action_space: The action spaces of the wrapper, if None, then it is assumed the same as ``env.action_space``.
action_space: The action spaces of the wrapper. If None, then it is computed from ``single_action_space``. If ``single_action_space`` is not provided either, then it is assumed to be the same as ``env.action_space``.
single_action_space: The action space of the non-vectorized environment. If None, then it is assumed the same as ``env.single_action_space``.
"""
super().__init__(env)

if action_space is not None:
if action_space is None:
if single_action_space is not None:
self.single_action_space = single_action_space
self.action_space = batch_space(single_action_space, self.num_envs)
else:
self.action_space = action_space
if single_action_space is not None:
self.single_action_space = single_action_space
# TODO: We could compute single_action_space from the action_space if only the latter is provided and avoid the warning below.
if self.action_space != batch_space(self.single_action_space, self.num_envs):
warn(
"The action space and the batched single action space don't match as expected."
howardh marked this conversation as resolved.
Show resolved Hide resolved
)

self.func = func

Expand Down
21 changes: 19 additions & 2 deletions gymnasium/wrappers/vector/vectorize_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,35 @@ def __init__(
env: VectorEnv,
func: Callable[[ObsType], Any],
observation_space: Space | None = None,
single_observation_space: Space | None = None,
):
"""Constructor for the transform observation wrapper.

Args:
env: The vector environment to wrap
func: A function that will transform the vector observation. If this transformed observation is outside the observation space of ``env.observation_space`` then provide an ``observation_space``.
observation_space: The observation spaces of the wrapper, if None, then it is assumed the same as ``env.observation_space``.
observation_space: The observation spaces of the wrapper. If None, then it is computed from ``single_observation_space``. If ``single_observation_space`` is not provided either, then it is assumed to be the same as ``env.observation_space``.
single_observation_space: The observation space of the non-vectorized environment. If None, then it is assumed the same as ``env.single_observation_space``.
"""
super().__init__(env)

if observation_space is not None:
if observation_space is None:
if single_observation_space is not None:
self.single_observation_space = single_observation_space
self.observation_space = batch_space(
single_observation_space, self.num_envs
)
else:
self.observation_space = observation_space
if single_observation_space is not None:
self._single_observation_space = single_observation_space
# TODO: We could compute single_observation_space from the observation_space if only the latter is provided and avoid the warning below.
if self.observation_space != batch_space(
self.single_observation_space, self.num_envs
):
warn(
"The observation space and the batched single observation space don't match as expected."
)

self.func = func

Expand Down
53 changes: 53 additions & 0 deletions tests/wrappers/vector/test_transform_action.py
howardh marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""Test suite for vector TransformAction wrapper."""

import numpy as np

from gymnasium import spaces, wrappers
from gymnasium.vector import SyncVectorEnv
from tests.testing_env import GenericTestEnv


def create_env():
return GenericTestEnv(
action_space=spaces.Box(
low=np.array([0, -10, -5], dtype=np.float32),
high=np.array([10, -5, 10], dtype=np.float32),
)
)


def test_observation_space_from_single_observation_space(
n_envs: int = 5,
):
vec_env = SyncVectorEnv([create_env for _ in range(n_envs)])
vec_env = wrappers.vector.TransformAction(
vec_env,
func=lambda x: x + 100,
single_action_space=spaces.Box(
low=np.array([0, -10, -5], dtype=np.float32) + 100,
high=np.array([10, -5, 10], dtype=np.float32) + 100,
),
)

# Check action space
assert isinstance(vec_env.action_space, spaces.Box)
assert vec_env.action_space.shape == (n_envs, 3)
assert vec_env.action_space.dtype == np.float32
assert (
vec_env.action_space.low == np.array([[100, 90, 95]] * n_envs, dtype=np.float32)
).all()
assert (
vec_env.action_space.high
== np.array([[110, 95, 110]] * n_envs, dtype=np.float32)
).all()

# Check single action space
assert isinstance(vec_env.single_action_space, spaces.Box)
assert vec_env.single_action_space.shape == (3,)
assert vec_env.single_action_space.dtype == np.float32
assert (
vec_env.single_action_space.low == np.array([100, 90, 95], dtype=np.float32)
).all()
assert (
vec_env.single_action_space.high == np.array([110, 95, 110], dtype=np.float32)
).all()
97 changes: 97 additions & 0 deletions tests/wrappers/vector/test_transform_observation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
"""Test suite for vector TransformObservation wrapper."""

import numpy as np
import pytest

from gymnasium import spaces, wrappers
from gymnasium.vector import SyncVectorEnv
from tests.testing_env import GenericTestEnv


def create_env():
return GenericTestEnv(
observation_space=spaces.Box(
low=np.array([0, -10, -5], dtype=np.float32),
high=np.array([10, -5, 10], dtype=np.float32),
)
)


def test_transform(n_envs: int = 2):
vec_env = SyncVectorEnv([create_env for _ in range(n_envs)])
vec_env = wrappers.vector.TransformObservation(
vec_env,
func=lambda x: x + 100,
single_observation_space=spaces.Box(
low=np.array([0, -10, -5], dtype=np.float32),
high=np.array([10, -5, 10], dtype=np.float32),
),
)

obs, _ = vec_env.reset(seed=123)
vec_env.observation_space.seed(123)
vec_env.action_space.seed(123)

assert (obs >= np.array([100, 90, 95], dtype=np.float32)).all()
assert (obs <= np.array([110, 95, 110], dtype=np.float32)).all()

obs, *_ = vec_env.step(vec_env.action_space.sample())

assert (obs >= np.array([100, 90, 95], dtype=np.float32)).all()
assert (obs <= np.array([110, 95, 110], dtype=np.float32)).all()


def test_observation_space_from_single_observation_space(
n_envs: int = 5,
):
vec_env = SyncVectorEnv([create_env for _ in range(n_envs)])
vec_env = wrappers.vector.TransformObservation(
vec_env,
func=lambda x: x + 100,
single_observation_space=spaces.Box(
low=np.array([0, -10, -5], dtype=np.float32) + 100,
high=np.array([10, -5, 10], dtype=np.float32) + 100,
),
)

# Check observation space
assert isinstance(vec_env.observation_space, spaces.Box)
assert vec_env.observation_space.shape == (n_envs, 3)
assert vec_env.observation_space.dtype == np.float32
assert (
vec_env.observation_space.low
== np.array([[100, 90, 95]] * n_envs, dtype=np.float32)
).all()
assert (
vec_env.observation_space.high
== np.array([[110, 95, 110]] * n_envs, dtype=np.float32)
).all()

# Check single observation space
assert isinstance(vec_env.single_observation_space, spaces.Box)
assert vec_env.single_observation_space.shape == (3,)
assert vec_env.single_observation_space.dtype == np.float32
assert (
vec_env.single_observation_space.low
== np.array([100, 90, 95], dtype=np.float32)
).all()
assert (
vec_env.single_observation_space.high
== np.array([110, 95, 110], dtype=np.float32)
).all()


def test_warning_on_mismatched_single_observation_space(
n_envs: int = 5,
):
vec_env = SyncVectorEnv([create_env for _ in range(n_envs)])
# We only specify observation_space without single_observation_space, so single_observation_space inherits its value from the wrapped env which would not match. This mismatch should give us a warning.
with pytest.warns(Warning):
howardh marked this conversation as resolved.
Show resolved Hide resolved
vec_env = wrappers.vector.TransformObservation(
vec_env,
func=lambda x: x + 100,
observation_space=spaces.Box(
low=np.array([[0, -10, -5]] * n_envs, dtype=np.float32) + 100,
high=np.array([[10, -5, 10]] * n_envs, dtype=np.float32) + 100,
),
)
Loading