Skip to content

Commit

Permalink
from old to new reset/step API
Browse files Browse the repository at this point in the history
  • Loading branch information
robinhenry committed Nov 13, 2024
1 parent 070e6de commit bd2db8b
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 24 deletions.
8 changes: 4 additions & 4 deletions examples/random_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,16 @@

def run():
env = gym.make("gym_anm:ANM6Easy-v0")
o = env.reset()
o, _ = env.reset()

for i in range(10):
a = env.action_space.sample()
o, r, done, info = env.step(a)
o, r, terminated, _, info = env.step(a)
env.render()
time.sleep(0.5) # otherwise the rendering is too fast for the human eye

if done:
o = env.reset()
if terminated:
o, _ = env.reset()
env.close()


Expand Down
8 changes: 4 additions & 4 deletions gym_anm/envs/anm6_env/anm6.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,21 +111,21 @@ def render(self, mode="human", skip_frames=0):
self._update_render(dev_p, dev_q, branch_s, des_soc, gen_p_max, bus_v_magn, costs, network_collapsed)

def step(self, action):
obs, r, done, info = super().step(action)
obs, r, terminated, truncated, info = super().step(action)

# Increment the date (for rendering).
self.date += self.timestep_length

# Increment the year count.
self.year_count = (self.date - self.date_init).days // 365

return obs, r, done, info
return obs, r, terminated, truncated, info

def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
# Save rendering setup to restore after the reset().
render_mode = self.render_mode

obs = super().reset(seed=seed, options=options)
obs, info = super().reset(seed=seed, options=options)

# Restore the rendering setup.
self.render_mode = render_mode
Expand All @@ -138,7 +138,7 @@ def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
self.date_init = random_date(self.np_random, 2020)
self.date = self.date_init

return obs
return obs, info

def reset_date(self, date_init):
"""Reset the date displayed in the visualization (and the year count)."""
Expand Down
7 changes: 3 additions & 4 deletions gym_anm/envs/anm6_env/anm6_easy.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,13 @@ def next_vars(self, s_t):
return np.array(vars)

def reset(self, **kwargs):
obs = super().reset(**kwargs)
obs, info = super().reset(**kwargs)

# Reset the time of the day based on the auxiliary variable.
date = self.date
new_date = self.date + self.state[-1] * self.timestep_length
super().reset_date(new_date)

return obs
return obs, info


def _get_load_time_series():
Expand Down Expand Up @@ -145,7 +144,7 @@ def _get_gen_time_series():
for i in range(T):
print(i)
a = env.action_space.sample()
o, r, _, _ = env.step(a)
o, r, _, _, _ = env.step(a)
env.render()
time.sleep(0.5)

Expand Down
28 changes: 17 additions & 11 deletions gym_anm/envs/anm_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class ANMEnv(gym.Env):
The observation space from which observation vectors are constructed.
observation_N : int
The number of observation variables.
done : bool
terminated : bool
True if a terminal state has been reached (if the network collapsed);
False otherwise.
render_mode : str
Expand Down Expand Up @@ -256,7 +256,7 @@ def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):

super().reset(seed=seed, options=options)

self.done = False
self.terminated = False
self.render_mode = None
self.timestep = 0
self.e_loss = 0.0
Expand Down Expand Up @@ -304,11 +304,11 @@ def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):

# Cast state and obs vectors to 0 (arbitrary) if a terminal state has
# been reached.
if self.done:
if self.terminated:
self.state = self._terminal_state(self.state_N)
obs = self._terminal_state(self.observation_N)

return obs
return obs, {}

def observation(self, s_t):
"""
Expand Down Expand Up @@ -345,20 +345,26 @@ def step(self, action):
The observation vector :math:`o_{t+1}`.
reward : float
The reward associated with the transition :math:`r_t`.
done : bool
terminated : bool
True if a terminal state has been reached; False otherwise.
truncated: bool
True if the episode was truncated; False otherwise. Always False here.
info : dict
A dictionary with further information (used for debugging).
"""

err_msg = "Action %r (%s) invalid." % (action, type(action))
assert self.action_space.contains(action), err_msg

# Fix the truncated flag and info dict
truncated = False
info = {}

# 0. Remain in a terminal state and output reward=0 if the environment
# has already reached a terminal state.
if self.done:
if self.terminated:
obs = self._terminal_state(self.observation_N)
return obs, 0.0, self.done, {}
return obs, 0.0, self.terminated, truncated, info

# 1a. Sample the internal stochastic variables.
vars = self.next_vars(self.state)
Expand Down Expand Up @@ -412,10 +418,10 @@ def step(self, action):

# A terminal state has been reached if no solution to the power
# flow equations is found.
self.done = not pfe_converged
self.terminated = not pfe_converged

# 3b. Clip the reward.
if not self.done:
if not self.terminated:
self.e_loss = np.sign(e_loss) * np.clip(np.abs(e_loss), 0, self.costs_clipping[0])
self.penalty = np.clip(penalty, 0, self.costs_clipping[1])
r = -(self.e_loss + self.penalty)
Expand All @@ -426,7 +432,7 @@ def step(self, action):
self.penalty = self.costs_clipping[1]

# 4. Construct the state and observation vector.
if not self.done:
if not self.terminated:
for k in range(self.K):
self.state[k - self.K] = aux[k]
self.state = self._construct_state()
Expand All @@ -444,7 +450,7 @@ def step(self, action):
# 5. Update the timestep.
self.timestep += 1

return obs, r, self.done, {}
return obs, r, self.terminated, truncated, info

def render(self, mode="human"):
"""
Expand Down
2 changes: 1 addition & 1 deletion tests/test_dcopf_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def setUp(self):
done = True
while done:
self.env.reset()
done = self.env.done
done = self.env.terminated

self.safety_margin = 0.9
self.B = self.env.simulator.Y_bus.imag.toarray()
Expand Down

0 comments on commit bd2db8b

Please sign in to comment.