Skip to content

Commit

Permalink
Add pydocstyle to pre-commit, fix get_keys_to_action and environment frameskip (#560)
Browse files Browse the repository at this point in the history
  • Loading branch information
pseudo-rnd-thoughts authored Sep 18, 2024
1 parent bb43b81 commit 768d6f4
Show file tree
Hide file tree
Showing 11 changed files with 109 additions and 95 deletions.
28 changes: 14 additions & 14 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
rev: v4.6.0
hooks:
- id: check-symlinks
- id: destroyed-symlinks
Expand All @@ -25,7 +25,7 @@ repos:
# args:
# - --ignore-words-list=
- repo: https://github.com/PyCQA/flake8
rev: 7.0.0
rev: 7.1.1
hooks:
- id: flake8
args:
Expand All @@ -36,7 +36,7 @@ repos:
- --show-source
- --statistics
- repo: https://github.com/asottile/pyupgrade
rev: v3.15.0
rev: v3.17.0
hooks:
- id: pyupgrade
args: ["--py38-plus"]
Expand All @@ -46,16 +46,16 @@ repos:
- id: isort
args: ["--profile", "black"]
- repo: https://github.com/python/black
rev: 23.12.1
rev: 24.8.0
hooks:
- id: black
# - repo: https://github.com/pycqa/pydocstyle
# rev: 6.3.0
# hooks:
# - id: pydocstyle
## exclude: ^
# args:
# - --source
# - --explain
# - --convention=google
# additional_dependencies: ["tomli"]
- repo: https://github.com/pycqa/pydocstyle
rev: 6.3.0
hooks:
- id: pydocstyle
exclude: ^(docs/)|setup.py
args:
- --source
- --explain
- --convention=google
additional_dependencies: ["tomli"]
6 changes: 3 additions & 3 deletions docs/environments.md
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ The differences are listed in the following table:
|---------|--------------|------------------------------|----------------------|
| v0 | `(2, 5,)` | `0.25` | `False` |
| v4 | `(2, 5,)` | `0.0` | `False` |
| v5 | `5` | `0.25` | `False` |
| v5 | `4` | `0.25` | `False` |

> Version v5 follows the best practices outlined in [[2]](#2). Thus, it is recommended to transition to v5 and
customize the environment using the arguments above, if necessary.
Expand All @@ -220,8 +220,8 @@ are in the "ALE" namespace. The suffix "-ram" is still available. Thus, we get t

| Name | `obs_type=` | `frameskip=` | `repeat_action_probability=` |
|-------------------|-------------|--------------|------------------------------|
| ALE/Amidar-v5 | `"rgb"` | `5` | `0.25` |
| ALE/Amidar-ram-v5 | `"ram"` | `5` | `0.25` |
| ALE/Amidar-v5 | `"rgb"` | `4` | `0.25` |
| ALE/Amidar-ram-v5 | `"ram"` | `4` | `0.25` |

## Flavors

Expand Down
9 changes: 5 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Setup file for ALE."""

import os
import re
import subprocess
Expand All @@ -6,7 +8,7 @@
from setuptools import Extension, setup
from setuptools.command.build_ext import build_ext

here = os.path.abspath(os.path.dirname(__file__))
current_working_file = os.path.abspath(os.path.dirname(__file__))


class CMakeExtension(Extension):
Expand Down Expand Up @@ -108,8 +110,7 @@ def build_extension(self, ext):


def parse_version(version_file):
"""
Parse version from `version_file`.
"""Parse version from `version_file`.
If we're running on CI, i.e., CIBUILDWHEEL is set, then we'll parse
the version from `GITHUB_REF` using the official semver regex.
Expand All @@ -132,7 +133,7 @@ def parse_version(version_file):

if __name__ == "__main__":
# Allow for running `pip wheel` from other directories
here and os.chdir(here)
current_working_file and os.chdir(current_working_file)
# Most config options are in `setup.cfg`. These are the
# only dynamic options we need at build time.
setup(
Expand Down
2 changes: 2 additions & 0 deletions src/ale/python/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Python module for interacting with the ALE C++ interface and Gymnasium wrapper."""

import os
import platform
import sys
Expand Down
2 changes: 2 additions & 0 deletions src/ale/python/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class LoggerMode:
"""
:type: str
"""

@property
def value(self) -> int:
"""
Expand All @@ -54,6 +55,7 @@ class Action:
"""
:type: str
"""

@property
def value(self) -> int:
"""
Expand Down
79 changes: 38 additions & 41 deletions src/ale/python/env.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Gymnasium wrapper around the Arcade Learning Environment (ALE)."""

from __future__ import annotations

import sys
Expand All @@ -19,17 +21,16 @@


class AtariEnvStepMetadata(TypedDict):
"""Step info options."""

lives: int
episode_frame_number: int
frame_number: int
seeds: NotRequired[tuple[int, int]]


class AtariEnv(gymnasium.Env, utils.EzPickle):
"""
(A)rcade (L)earning (Gym) (Env)ironment.
A Gym wrapper around the Arcade Learning Environment (ALE).
"""
"""Gymnasium wrapper around the Arcade Learning Environment (ALE)."""

# FPS can differ per ROM, therefore, dynamically collect the fps once the game is loaded
metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 30}
Expand All @@ -48,8 +49,8 @@ def __init__(
max_num_frames_per_episode: int | None = None,
render_mode: Literal["human", "rgb_array"] | None = None,
):
"""
Initialize the ALE for Gymnasium.
"""Initialize the ALE for Gymnasium.
Default parameters are taken from Machado et al., 2018.
Args:
Expand Down Expand Up @@ -219,13 +220,13 @@ def load_game(self) -> None:
if self._game_difficulty is not None:
self.ale.setDifficulty(self._game_difficulty)

def reset( # pyright: ignore[reportIncompatibleMethodOverride]
def reset(
self,
*,
seed: int | None = None,
options: dict[str, Any] | None = None,
) -> tuple[np.ndarray, AtariEnvStepMetadata]:
"""Resets environment and returns initial observation."""
"""Resets environment and returns initial episode observation."""
super().reset(seed=seed, options=options)

# sets the seeds if it's specified for both ALE and frameskip np
Expand All @@ -248,8 +249,7 @@ def step( # pyright: ignore[reportIncompatibleMethodOverride]
self,
action: int | np.ndarray,
) -> tuple[np.ndarray, float, bool, bool, AtariEnvStepMetadata]:
"""
Perform one agent step, i.e., repeats `action` frameskip # of steps.
"""Perform one agent step, i.e., repeats `action` frameskip # of steps.
Args:
action: int | np.ndarray =>
Expand All @@ -260,8 +260,7 @@ def step( # pyright: ignore[reportIncompatibleMethodOverride]
tuple[np.ndarray, float, bool, bool, Dict[str, Any]] =>
observation, reward, terminal, truncation, metadata
Note: `metadata` contains the keys "lives" and "rgb" if
render_mode == 'rgb_array'.
Note: `metadata` contains the keys "lives", "episode_frame_number" and "frame_number".
"""
# If frameskip is a length 2 tuple then it's stochastic
# frameskip between [frameskip[0], frameskip[1]] uniformly.
Expand Down Expand Up @@ -305,15 +304,7 @@ def step( # pyright: ignore[reportIncompatibleMethodOverride]
return self._get_obs(), reward, is_terminal, is_truncated, self._get_info()

def render(self) -> np.ndarray | None:
"""
Render is not supported by ALE. We use a paradigm similar to
Gym3 which allows you to specify `render_mode` during construction.
For example,
gym.make("ale-py:Pong-v0", render_mode="human")
will display the ALE and maintain the proper interval to match the
FPS target set by the ROM.
"""
"""Renders the ALE with `rgb_array` and `human` options."""
if self.render_mode == "rgb_array":
return self.ale.getScreenRGB()
elif self.render_mode == "human":
Expand All @@ -325,18 +316,17 @@ def render(self) -> np.ndarray | None:
)

def _get_obs(self) -> np.ndarray:
"""
Retrieves the current observation.
This is dependent on `self._obs_type`.
"""
"""Retrieves the current observation using `obs_type`."""
if self._obs_type == "ram":
return self.ale.getRAM()
elif self._obs_type == "rgb":
return self.ale.getScreenRGB()
elif self._obs_type == "grayscale":
return self.ale.getScreenGrayscale()
else:
raise error.Error(f"Unrecognized observation type: {self._obs_type}")
raise error.Error(
f"Unrecognized observation type: {self._obs_type}, expected: 'ram', 'rgb' and 'grayscale'."
)

def _get_info(self) -> AtariEnvStepMetadata:
return {
Expand All @@ -347,17 +337,23 @@ def _get_info(self) -> AtariEnvStepMetadata:

@lru_cache(1)
def get_keys_to_action(self) -> dict[tuple[int, ...], ale_py.Action]:
"""
Return keymapping -> actions for human play.
"""Return keymapping -> actions for human play.
Up, down, left and right are wasd keys with fire being space.
No-op is 'e'.
Returns:
Dictionary of key values to actions
"""
UP = ord("w")
LEFT = ord("a")
RIGHT = ord("d")
DOWN = ord("s")
FIRE = ord(" ")
NOOP = ord("e")

mapping = {
ale_py.Action.NOOP: (None,),
ale_py.Action.NOOP: (NOOP,),
ale_py.Action.UP: (UP,),
ale_py.Action.FIRE: (FIRE,),
ale_py.Action.DOWN: (DOWN,),
Expand Down Expand Up @@ -389,9 +385,7 @@ def get_keys_to_action(self) -> dict[tuple[int, ...], ale_py.Action]:
def map_action_idx(
self, left_center_right: int, down_center_up: int, fire: bool
) -> int:
"""
Return an action idx given unit actions for underlying env.
"""
"""Return an action idx given unit actions for underlying env."""
# no op and fire
if left_center_right == 0 and down_center_up == 0 and not fire:
return ale_py.Action.NOOP
Expand Down Expand Up @@ -441,25 +435,28 @@ def map_action_idx(
# just in case
else:
raise LookupError(
"Did not expect to get here, "
"expected `left_center_right` and `down_center_up` to be in {-1, 0, 1} "
"and `fire` to only be `True` or `False`. "
"Unexpected action mapping, expected `left_center_right` and `down_center_up` to be in {-1, 0, 1} and `fire` to only be `True` or `False`. "
f"Received {left_center_right=}, {down_center_up=} and {fire=}."
)

def get_action_meanings(self) -> list[str]:
"""
Return the meaning of each integer action.
"""
"""Return the meaning of each action."""
keys = ale_py.Action.__members__.values()
values = ale_py.Action.__members__.keys()
mapping = dict(zip(keys, values))
return [mapping[action] for action in self._action_set]

def clone_state(self, include_rng: bool = False) -> ale_py.ALEState:
"""Clone emulator state w/o system state. Restoring this state will
*not* give an identical environment. For complete cloning and restoring
of the full state, see `{clone,restore}_full_state()`."""
"""Clone emulator state.
To reproduce identical states, set `include_rng` to `True`.
Args:
include_rng: Whether to include the system RNG within the state
Returns:
The cloned ALE state
"""
return self.ale.cloneState(include_rng=include_rng)

def restore_state(self, state: ale_py.ALEState) -> None:
Expand Down
17 changes: 10 additions & 7 deletions src/ale/python/registration.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Registration for Atari environments."""

from __future__ import annotations

from collections import defaultdict
Expand All @@ -8,25 +10,24 @@


class EnvFlavour(NamedTuple):
"""Environment flavour for env id suffix and kwargs."""

suffix: str
kwargs: Mapping[str, Any] | Callable[[str], Mapping[str, Any]]


class EnvConfig(NamedTuple):
"""Environment config for version, kwargs and flavours."""

version: str
kwargs: Mapping[str, Any]
flavours: Sequence[EnvFlavour]


def _rom_id_to_name(rom: str) -> str:
"""
Let the ROM ID be the ROM identifier in snake_case.
For example, `space_invaders`
The ROM name is the ROM ID in pascalcase.
For example, `SpaceInvaders`
"""Converts the ROM ID (snake_case) to the ROM name in PascalCase.
This function converts the ROM ID to the ROM name.
i.e., snakecase -> pascalcase
For example, `space_invaders` to `SpaceInvaders`
"""
return rom.title().replace("_", "")

Expand Down Expand Up @@ -72,6 +73,7 @@ def _register_rom_configs(


def register_v0_v4_envs():
"""Registers all v0 and v4 environments."""
legacy_games = [
"adventure",
"air_raid",
Expand Down Expand Up @@ -177,6 +179,7 @@ def register_v0_v4_envs():


def register_v5_envs():
"""Register all v5 environments."""
all_games = roms.get_all_rom_ids()
obs_types = ["rgb", "ram"]

Expand Down
2 changes: 2 additions & 0 deletions src/ale/python/roms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""ROM module with functions for collecting individual and all ROM files."""

from __future__ import annotations

import functools
Expand Down
Loading

0 comments on commit 768d6f4

Please sign in to comment.