From 09671d72db87d887c918316acfbc4a9e4b610d72 Mon Sep 17 00:00:00 2001 From: JD Ranpariya Date: Sun, 6 Oct 2024 15:44:37 +0200 Subject: [PATCH 1/8] add support for dtype in discrete space --- gymnasium/spaces/discrete.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/gymnasium/spaces/discrete.py b/gymnasium/spaces/discrete.py index 9a4575252..586d8220f 100644 --- a/gymnasium/spaces/discrete.py +++ b/gymnasium/spaces/discrete.py @@ -27,6 +27,7 @@ class Discrete(Space[np.int64]): def __init__( self, n: int | np.integer[Any], + dtype: str | type[np.integer[Any]] = np.int64, seed: int | np.random.Generator | None = None, start: int | np.integer[Any] = 0, ): @@ -36,6 +37,7 @@ def __init__( Args: n (int): The number of elements of this space. + dtype: This should be some kind of integer type. seed: Optionally, you can use this argument to seed the RNG that is used to sample from the ``Dict`` space. start (int): The smallest element of this space. """ @@ -47,9 +49,22 @@ def __init__( type(start), np.integer ), f"Expects `start` to be an integer, actual type: {type(start)}" - self.n = np.int64(n) - self.start = np.int64(start) - super().__init__((), np.int64, seed) + # determine dtype + if dtype is None: + raise ValueError( + "Discrete dtype must be explicitly provided, cannot be None." 
+ ) + self.dtype = np.dtype(dtype) + + # * check that dtype is an accepted dtype + if not (np.issubdtype(self.dtype, np.integer)): + raise ValueError( + f"Invalid Discrete dtype ({self.dtype}), must be an integer dtype" + ) + + self.n = self.dtype.type(n) + self.start = self.dtype.type(start) + super().__init__((), self.dtype, seed) @property def is_np_flattenable(self): From 4dcd0b6ca4c81051b33f3dca99a4ea442e705d0e Mon Sep 17 00:00:00 2001 From: JD Ranpariya Date: Sun, 6 Oct 2024 16:53:44 +0200 Subject: [PATCH 2/8] fix sample method to correctly return set dtype --- gymnasium/spaces/discrete.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/gymnasium/spaces/discrete.py b/gymnasium/spaces/discrete.py index 586d8220f..d8d7d5c58 100644 --- a/gymnasium/spaces/discrete.py +++ b/gymnasium/spaces/discrete.py @@ -71,7 +71,7 @@ def is_np_flattenable(self): """Checks whether this space can be flattened to a :class:`spaces.Box`.""" return True - def sample(self, mask: MaskNDArray | None = None) -> np.int64: + def sample(self, mask: MaskNDArray | None = None) -> np.integer[Any]: """Generates a single random sample from this space. 
A sample will be chosen uniformly at random with the mask if provided @@ -99,13 +99,13 @@ def sample(self, mask: MaskNDArray | None = None) -> np.int64: np.logical_or(mask == 0, valid_action_mask) ), f"All values of a mask should be 0 or 1, actual values: {mask}" if np.any(valid_action_mask): - return self.start + self.np_random.choice( + return self.start + self.dtype.type(self.np_random.choice( np.where(valid_action_mask)[0] - ) + )) else: return self.start - return self.start + self.np_random.integers(self.n) + return self.start + self.dtype.type(self.np_random.integers(self.n)) def contains(self, x: Any) -> bool: """Return boolean specifying if x is a valid member of this space.""" @@ -152,7 +152,7 @@ def __setstate__(self, state: Iterable[tuple[str, Any]] | Mapping[str, Any]): super().__setstate__(state) - def to_jsonable(self, sample_n: Sequence[np.int64]) -> list[int]: + def to_jsonable(self, sample_n: Sequence[np.integer[Any]]) -> list[int]: """Converts a list of samples to a list of ints.""" return [int(x) for x in sample_n] From e9efa3d99490598d541984e418fa629c285555e2 Mon Sep 17 00:00:00 2001 From: JD Ranpariya Date: Sun, 6 Oct 2024 16:56:10 +0200 Subject: [PATCH 3/8] fix sample method to correctly return set dtype --- gymnasium/spaces/discrete.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gymnasium/spaces/discrete.py b/gymnasium/spaces/discrete.py index d8d7d5c58..0cc7b96d0 100644 --- a/gymnasium/spaces/discrete.py +++ b/gymnasium/spaces/discrete.py @@ -99,9 +99,9 @@ def sample(self, mask: MaskNDArray | None = None) -> np.integer[Any]: np.logical_or(mask == 0, valid_action_mask) ), f"All values of a mask should be 0 or 1, actual values: {mask}" if np.any(valid_action_mask): - return self.start + self.dtype.type(self.np_random.choice( - np.where(valid_action_mask)[0] - )) + return self.start + self.dtype.type( + self.np_random.choice(np.where(valid_action_mask)[0]) + ) else: return self.start From 
f636835179c9b458dc6afcc7dac3f06f413ef47d Mon Sep 17 00:00:00 2001 From: JD Ranpariya Date: Sun, 6 Oct 2024 20:45:36 +0200 Subject: [PATCH 4/8] fix dtype conversion --- gymnasium/spaces/discrete.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gymnasium/spaces/discrete.py b/gymnasium/spaces/discrete.py index 0cc7b96d0..966390ca9 100644 --- a/gymnasium/spaces/discrete.py +++ b/gymnasium/spaces/discrete.py @@ -105,7 +105,7 @@ def sample(self, mask: MaskNDArray | None = None) -> np.integer[Any]: else: return self.start - return self.start + self.dtype.type(self.np_random.integers(self.n)) + return self.start + self.np_random.integers(self.n).astype(self.dtype) def contains(self, x: Any) -> bool: """Return boolean specifying if x is a valid member of this space.""" From 1e8f4c184f08f45752f8f6e80f6c2e450f7a6aa0 Mon Sep 17 00:00:00 2001 From: JD Ranpariya Date: Mon, 7 Oct 2024 14:14:25 +0200 Subject: [PATCH 5/8] Add Discrete(dtype) tests --- tests/spaces/test_discrete.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/tests/spaces/test_discrete.py b/tests/spaces/test_discrete.py index 71c4fcf51..ca1c249ef 100644 --- a/tests/spaces/test_discrete.py +++ b/tests/spaces/test_discrete.py @@ -1,6 +1,7 @@ from copy import deepcopy import numpy as np +import pytest from gymnasium.spaces import Discrete @@ -32,3 +33,30 @@ def test_sample_mask(): assert space.sample(mask=np.array([0, 1, 0, 0], dtype=np.int8)) == 3 assert space.sample(mask=np.array([0, 0, 0, 0], dtype=np.int8)) == 2 assert space.sample(mask=np.array([0, 1, 0, 1], dtype=np.int8)) in [3, 5] + + +@pytest.mark.parametrize("dtype, sample_dtype", [ + (int, np.int64), + (np.int64, np.int64), + (np.int32, np.int32), + (np.uint8, np.uint8), +]) +def test_dtype(dtype, sample_dtype): + space = Discrete(n=5, dtype=dtype, start=2) + + sample = space.sample() + sample_mask = space.sample(mask=np.array([0, 1, 0, 0, 0], dtype=np.int8)) + print(f'{sample=}, {sample_mask=}') + 
print(f'{type(sample)=}, {type(sample_mask)=}') + assert isinstance(sample, sample_dtype), type(sample) + assert isinstance(sample_mask, sample_dtype), type(sample_mask) + + +@pytest.mark.parametrize("dtype", [ + str, + np.float32, + np.complex64, +]) +def test_dtype_error(dtype): + with pytest.raises(ValueError, match="Invalid Discrete dtype"): + Discrete(4, dtype=dtype) From 90ca3525955dfbcc9bf671f8d2fc5e4a4505ab0c Mon Sep 17 00:00:00 2001 From: JD Ranpariya Date: Mon, 7 Oct 2024 14:17:28 +0200 Subject: [PATCH 6/8] Add Discrete(dtype) tests --- tests/spaces/test_discrete.py | 32 +++++++++++++++++++------------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/tests/spaces/test_discrete.py b/tests/spaces/test_discrete.py index ca1c249ef..8bc9261ef 100644 --- a/tests/spaces/test_discrete.py +++ b/tests/spaces/test_discrete.py @@ -35,28 +35,34 @@ def test_sample_mask(): assert space.sample(mask=np.array([0, 1, 0, 1], dtype=np.int8)) in [3, 5] -@pytest.mark.parametrize("dtype, sample_dtype", [ - (int, np.int64), - (np.int64, np.int64), - (np.int32, np.int32), - (np.uint8, np.uint8), -]) +@pytest.mark.parametrize( + "dtype, sample_dtype", + [ + (int, np.int64), + (np.int64, np.int64), + (np.int32, np.int32), + (np.uint8, np.uint8), + ], +) def test_dtype(dtype, sample_dtype): space = Discrete(n=5, dtype=dtype, start=2) sample = space.sample() sample_mask = space.sample(mask=np.array([0, 1, 0, 0, 0], dtype=np.int8)) - print(f'{sample=}, {sample_mask=}') - print(f'{type(sample)=}, {type(sample_mask)=}') + print(f"{sample=}, {sample_mask=}") + print(f"{type(sample)=}, {type(sample_mask)=}") assert isinstance(sample, sample_dtype), type(sample) assert isinstance(sample_mask, sample_dtype), type(sample_mask) -@pytest.mark.parametrize("dtype", [ - str, - np.float32, - np.complex64, -]) +@pytest.mark.parametrize( + "dtype", + [ + str, + np.float32, + np.complex64, + ], +) def test_dtype_error(dtype): with pytest.raises(ValueError, match="Invalid Discrete 
dtype"): Discrete(4, dtype=dtype) From e6619a2ded35a58abe3100b2aebe618065591818 Mon Sep 17 00:00:00 2001 From: JD Ranpariya Date: Mon, 7 Oct 2024 16:05:43 +0200 Subject: [PATCH 7/8] Add test for dtype None --- gymnasium/spaces/discrete.py | 4 +--- tests/spaces/test_discrete.py | 1 + 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/gymnasium/spaces/discrete.py b/gymnasium/spaces/discrete.py index 966390ca9..278985f58 100644 --- a/gymnasium/spaces/discrete.py +++ b/gymnasium/spaces/discrete.py @@ -51,9 +51,7 @@ def __init__( # determine dtype if dtype is None: - raise ValueError( - "Discrete dtype must be explicitly provided, cannot be None." - ) + raise ValueError("Invalid Discrete dtype (None), cannot be None.") self.dtype = np.dtype(dtype) # * check that dtype is an accepted dtype diff --git a/tests/spaces/test_discrete.py b/tests/spaces/test_discrete.py index 8bc9261ef..c94905795 100644 --- a/tests/spaces/test_discrete.py +++ b/tests/spaces/test_discrete.py @@ -58,6 +58,7 @@ def test_dtype(dtype, sample_dtype): @pytest.mark.parametrize( "dtype", [ + None, str, np.float32, np.complex64, From a6ed721bd22ea0e75a05493a3363c4f2cd198a7f Mon Sep 17 00:00:00 2001 From: Mark Towers Date: Mon, 7 Oct 2024 15:14:51 +0100 Subject: [PATCH 8/8] Improve dtype docstring --- gymnasium/spaces/discrete.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gymnasium/spaces/discrete.py b/gymnasium/spaces/discrete.py index 278985f58..54e7616ee 100644 --- a/gymnasium/spaces/discrete.py +++ b/gymnasium/spaces/discrete.py @@ -37,7 +37,7 @@ def __init__( Args: n (int): The number of elements of this space. - dtype: This should be some kind of integer type. + dtype: The space type, for example, ``int``, ``np.int64``, ``np.int32``, or ``np.uint8``. seed: Optionally, you can use this argument to seed the RNG that is used to sample from the ``Dict`` space. start (int): The smallest element of this space. """