From bd2db8b460c626208e401335b50463843273b33f Mon Sep 17 00:00:00 2001
From: Robin Henry
Date: Wed, 13 Nov 2024 09:01:46 +0000
Subject: [PATCH] Migrate from the old to the new Gym reset/step API

reset() now returns an (observation, info) pair and step() returns a
5-tuple (obs, reward, terminated, truncated, info) instead of the old
4-tuple (obs, reward, done, info). The environment attribute `done` is
renamed to `terminated`; episodes are never truncated.

---
 examples/random_agent.py           |  8 ++++----
 gym_anm/envs/anm6_env/anm6.py      |  8 ++++----
 gym_anm/envs/anm6_env/anm6_easy.py |  7 +++----
 gym_anm/envs/anm_env.py            | 28 +++++++++++++++++-----------
 tests/test_dcopf_agent.py          |  2 +-
 5 files changed, 29 insertions(+), 24 deletions(-)

diff --git a/examples/random_agent.py b/examples/random_agent.py
index 6a73d3b..10eb8e9 100644
--- a/examples/random_agent.py
+++ b/examples/random_agent.py
@@ -11,16 +11,16 @@
 def run():
     env = gym.make("gym_anm:ANM6Easy-v0")
-    o = env.reset()
+    o, _ = env.reset()
 
     for i in range(10):
         a = env.action_space.sample()
-        o, r, done, info = env.step(a)
+        o, r, terminated, _, info = env.step(a)
         env.render()
         time.sleep(0.5)  # otherwise the rendering is too fast for the human eye
 
-        if done:
-            o = env.reset()
+        if terminated:
+            o, _ = env.reset()
 
     env.close()
 
 
diff --git a/gym_anm/envs/anm6_env/anm6.py b/gym_anm/envs/anm6_env/anm6.py
index f2701b0..6b68658 100644
--- a/gym_anm/envs/anm6_env/anm6.py
+++ b/gym_anm/envs/anm6_env/anm6.py
@@ -111,7 +111,7 @@ def render(self, mode="human", skip_frames=0):
         self._update_render(dev_p, dev_q, branch_s, des_soc, gen_p_max, bus_v_magn, costs, network_collapsed)
 
     def step(self, action):
-        obs, r, done, info = super().step(action)
+        obs, r, terminated, truncated, info = super().step(action)
 
         # Increment the date (for rendering).
         self.date += self.timestep_length
@@ -119,13 +119,13 @@
         # Increment the year count.
         self.year_count = (self.date - self.date_init).days // 365
 
-        return obs, r, done, info
+        return obs, r, terminated, truncated, info
 
     def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
         # Save rendering setup to restore after the reset().
         render_mode = self.render_mode
 
-        obs = super().reset(seed=seed, options=options)
+        obs, info = super().reset(seed=seed, options=options)
 
         # Restore the rendering setup.
         self.render_mode = render_mode
@@ -138,7 +138,7 @@
         self.date_init = random_date(self.np_random, 2020)
         self.date = self.date_init
 
-        return obs
+        return obs, info
 
     def reset_date(self, date_init):
         """Reset the date displayed in the visualization (and the year count)."""
diff --git a/gym_anm/envs/anm6_env/anm6_easy.py b/gym_anm/envs/anm6_env/anm6_easy.py
index 5b66fe1..b128e53 100644
--- a/gym_anm/envs/anm6_env/anm6_easy.py
+++ b/gym_anm/envs/anm6_env/anm6_easy.py
@@ -65,14 +65,13 @@ def next_vars(self, s_t):
         return np.array(vars)
 
     def reset(self, **kwargs):
-        obs = super().reset(**kwargs)
+        obs, info = super().reset(**kwargs)
 
         # Reset the time of the day based on the auxiliary variable.
-        date = self.date
         new_date = self.date + self.state[-1] * self.timestep_length
         super().reset_date(new_date)
 
-        return obs
+        return obs, info
 
 
 def _get_load_time_series():
@@ -145,7 +144,7 @@ def _get_gen_time_series():
     for i in range(T):
         print(i)
         a = env.action_space.sample()
-        o, r, _, _ = env.step(a)
+        o, r, _, _, _ = env.step(a)
         env.render()
         time.sleep(0.5)
 
diff --git a/gym_anm/envs/anm_env.py b/gym_anm/envs/anm_env.py
index 9c55976..770cad0 100644
--- a/gym_anm/envs/anm_env.py
+++ b/gym_anm/envs/anm_env.py
@@ -53,7 +53,7 @@ class ANMEnv(gym.Env):
         The observation space from which observation vectors are constructed.
     observation_N : int
         The number of observation variables.
-    done : bool
+    terminated : bool
         True if a terminal state has been reached (if the network collapsed);
         False otherwise.
     render_mode : str
@@ -256,7 +256,7 @@ def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
 
         super().reset(seed=seed, options=options)
 
-        self.done = False
+        self.terminated = False
         self.render_mode = None
         self.timestep = 0
         self.e_loss = 0.0
@@ -304,11 +304,11 @@ def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
 
         # Cast state and obs vectors to 0 (arbitrary) if a terminal state has
         # been reached.
-        if self.done:
+        if self.terminated:
             self.state = self._terminal_state(self.state_N)
             obs = self._terminal_state(self.observation_N)
 
-        return obs
+        return obs, {}
 
     def observation(self, s_t):
         """
@@ -345,20 +345,26 @@ def step(self, action):
             The observation vector :math:`o_{t+1}`.
         reward : float
             The reward associated with the transition :math:`r_t`.
-        done : bool
+        terminated : bool
             True if a terminal state has been reached; False otherwise.
+        truncated : bool
+            True if the episode was truncated; always False in this environment.
         info : dict
             A dictionary with further information (used for debugging).
         """
 
         err_msg = "Action %r (%s) invalid." % (action, type(action))
         assert self.action_space.contains(action), err_msg
+
+        # Initialize the truncated flag and the info dict (truncation is never used).
+        truncated = False
+        info = {}
 
         # 0. Remain in a terminal state and output reward=0 if the environment
         # has already reached a terminal state.
-        if self.done:
+        if self.terminated:
             obs = self._terminal_state(self.observation_N)
-            return obs, 0.0, self.done, {}
+            return obs, 0.0, self.terminated, truncated, info
 
         # 1a. Sample the internal stochastic variables.
         vars = self.next_vars(self.state)
@@ -412,10 +418,10 @@
 
         # A terminal state has been reached if no solution to the power
         # flow equations is found.
-        self.done = not pfe_converged
+        self.terminated = not pfe_converged
 
         # 3b. Clip the reward.
-        if not self.done:
+        if not self.terminated:
             self.e_loss = np.sign(e_loss) * np.clip(np.abs(e_loss), 0, self.costs_clipping[0])
             self.penalty = np.clip(penalty, 0, self.costs_clipping[1])
             r = -(self.e_loss + self.penalty)
@@ -426,7 +432,7 @@
             self.penalty = self.costs_clipping[1]
 
         # 4. Construct the state and observation vector.
-        if not self.done:
+        if not self.terminated:
             for k in range(self.K):
                 self.state[k - self.K] = aux[k]
             self.state = self._construct_state()
@@ -444,7 +450,7 @@
         # 5. Update the timestep.
         self.timestep += 1
 
-        return obs, r, self.terminated, truncated, info
 
     def render(self, mode="human"):
         """
diff --git a/tests/test_dcopf_agent.py b/tests/test_dcopf_agent.py
index 37efd0b..357c7ab 100644
--- a/tests/test_dcopf_agent.py
+++ b/tests/test_dcopf_agent.py
@@ -19,7 +19,7 @@ def setUp(self):
         done = True
         while done:
             self.env.reset()
-            done = self.env.done
+            done = self.env.terminated
 
         self.safety_margin = 0.9
         self.B = self.env.simulator.Y_bus.imag.toarray()
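
For reviewers, a minimal sketch of an agent-environment loop under the new API. This is illustrative only: it assumes the "gym_anm:ANM6Easy-v0" registration used in examples/random_agent.py and a gym version (>= 0.26) whose step() returns the 5-tuple.

import gym

env = gym.make("gym_anm:ANM6Easy-v0")

# reset() now returns an (observation, info) pair rather than a bare observation.
obs, info = env.reset(seed=42)

for _ in range(5):
    action = env.action_space.sample()
    # step() now returns (obs, reward, terminated, truncated, info) rather than
    # the old (obs, reward, done, info) 4-tuple.
    obs, reward, terminated, truncated, info = env.step(action)
    # gym-anm never truncates episodes itself, so truncated stays False unless
    # a wrapper (e.g. a time limit) sets it.
    if terminated or truncated:
        obs, info = env.reset()

env.close()

Note that the environment never sets truncated=True; the flag exists only so that step() matches the new five-element interface.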