Sample into Algorithm (#140)
Pushed experience sampling into the algorithms.
beardyFace authored Apr 9, 2024
1 parent b6e8b20 commit 6a3c465
Showing 12 changed files with 54 additions and 179 deletions.
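
The change repeated across these files is a single interface move: train_policy (and DYNA_SAC's train_world_model) now receive the replay memory and a batch size and draw their own sample, rather than being handed a pre-sampled experiences tuple by the training loop. For orientation, here is a minimal sketch of the memory interface the refactored code appears to call, inferred only from the calls visible in this diff; the Protocol, type hints, and docstrings are illustrative, not taken from the repository:

from typing import Any, Protocol, Sequence, Tuple


class ReplayMemory(Protocol):
    """Hypothetical interface inferred from the calls made in this commit."""

    def sample(self, batch_size: int) -> Tuple[Any, ...]:
        # Unpacked in the diff as:
        #   states, actions, rewards, next_states, dones, indices, weights
        ...

    def sample_consecutive(self, batch_size: int) -> Tuple[Any, ...]:
        # Used by DYNA_SAC.train_world_model; unpacked as:
        #   states, actions, rewards, next_states, _, next_actions, next_rewards
        ...

    def flush(self) -> Tuple[Any, ...]:
        # Used by the on-policy PPO update; unpacked as:
        #   states, actions, rewards, next_states, dones, log_probs
        ...

    def update_priorities(
        self, indices: Sequence[int], priorities: Sequence[float]
    ) -> None:
        # Called by the prioritised-replay variants (PERTD3, RDTD3) after each update.
        ...

Correspondingly, a training loop that previously called something like agent.train_policy(memory.sample(batch_size)) would now call agent.train_policy(memory, batch_size).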
55 changes: 15 additions & 40 deletions cares_reinforcement_learning/algorithm/mbrl/DYNA_SAC.py
@@ -161,21 +161,14 @@ def _train_policy(self, states, actions, rewards, next_states, dones):
        info["actor_loss"] = actor_loss
        return info

-    def train_world_model(self, experiences):
-        """
-        Sample the buffer again for training the world model can reach higher rewards.
+    def train_world_model(self, memory, batch_size):

+        experiences = memory.sample_consecutive(batch_size)

+        states, actions, rewards, next_states, _, next_actions, next_rewards = (
+            experiences
+        )

-        :param experiences:
-        """
-        (
-            states,
-            actions,
-            rewards,
-            next_states,
-            _,
-            next_actions,
-            next_rewards,
-        ) = experiences
        states = torch.FloatTensor(np.asarray(states)).to(self.device)
        actions = torch.FloatTensor(np.asarray(actions)).to(self.device)
        rewards = torch.FloatTensor(np.asarray(rewards)).to(self.device).unsqueeze(1)
@@ -184,12 +177,8 @@ def train_world_model(self, experiences):
            torch.FloatTensor(np.asarray(next_rewards)).to(self.device).unsqueeze(1)
        )
        next_actions = torch.FloatTensor(np.asarray(next_actions)).to(self.device)
-        assert len(states.shape) >= 2
-        assert len(actions.shape) == 2
-        assert len(rewards.shape) == 2 and rewards.shape[1] == 1
-        assert len(next_rewards.shape) == 2 and next_rewards.shape[1] == 1
-        assert len(next_states.shape) >= 2
-        # # Step 1 train the world model.

+        # Step 1 train the world model.
        self.world_model.train_world(
            states=states,
            actions=actions,
@@ -199,30 +188,20 @@ def train_world_model(self, experiences):
            next_rewards=next_rewards,
        )

-    def train_policy(self, experiences):
-        """
-        Interface to training loop.
-        """
+    def train_policy(self, memory, batch_size):
        self.learn_counter += 1
-        (
-            states,
-            actions,
-            rewards,
-            next_states,
-            dones,
-        ) = experiences

+        experiences = memory.sample(batch_size)
+        states, actions, rewards, next_states, dones, _, _ = experiences
        self.batch_size = len(states)

        # Convert into tensor
        states = torch.FloatTensor(np.asarray(states)).to(self.device)
        actions = torch.FloatTensor(np.asarray(actions)).to(self.device)
        rewards = torch.FloatTensor(np.asarray(rewards)).to(self.device).unsqueeze(1)
        next_states = torch.FloatTensor(np.asarray(next_states)).to(self.device)
        dones = torch.LongTensor(np.asarray(dones)).to(self.device).unsqueeze(1)
-        assert len(states.shape) >= 2
-        assert len(actions.shape) == 2
-        assert len(rewards.shape) == 2 and rewards.shape[1] == 1
-        assert len(next_states.shape) >= 2

        # Step 2 train as usual
        self._train_policy(
            states=states,
@@ -235,10 +214,6 @@ def train_policy(self, experiences):
        self._dyna_generate_and_train(next_states=next_states)

    def _dyna_generate_and_train(self, next_states):
-        """
-        Only off-policy Dyna will work.
-        :param next_states:
-        """
        pred_states = []
        pred_actions = []
        pred_rs = []
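
The world-model update above now draws its own batch of consecutive transition pairs via memory.sample_consecutive(batch_size), which is what supplies the next_actions and next_rewards alongside the usual fields. As a rough illustration of the idea only (not the repository's MemoryBuffer implementation), a consecutive sampler over a flat list of (state, action, reward, next_state, done) tuples might look like this:

import random
from typing import List, Tuple

import numpy as np


def sample_consecutive(buffer: List[Tuple], batch_size: int):
    """Illustrative only: pair each sampled transition t with transition t + 1.

    Returns (states, actions, rewards, next_states, dones, next_actions, next_rewards),
    matching the 7-tuple unpacked by train_world_model in this diff. A real
    implementation would also avoid pairing across episode boundaries.
    """
    last_valid = len(buffer) - 1
    indices = random.sample(range(last_valid), k=min(batch_size, last_valid))

    states, actions, rewards, next_states, dones = [], [], [], [], []
    next_actions, next_rewards = [], []
    for i in indices:
        state, action, reward, next_state, done = buffer[i]
        next_action, next_reward = buffer[i + 1][1], buffer[i + 1][2]
        states.append(state)
        actions.append(action)
        rewards.append(reward)
        next_states.append(next_state)
        dones.append(done)
        next_actions.append(next_action)  # action taken from next_state
        next_rewards.append(next_reward)  # reward observed after that action
    return (
        np.asarray(states),
        np.asarray(actions),
        np.asarray(rewards),
        np.asarray(next_states),
        np.asarray(dones),
        np.asarray(next_actions),
        np.asarray(next_rewards),
    )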
23 changes: 4 additions & 19 deletions cares_reinforcement_learning/algorithm/policy/CTD4.py
@@ -91,15 +91,14 @@ def fusion_kalman(self, std_1, mean_1, std_2, mean_2):
        fusion_std = torch.sqrt(fusion_variance)
        return fusion_mean, fusion_std

-    def train_policy(self, experiences):
-        info = {}

+    def train_policy(self, memory, batch_size):
        self.learn_counter += 1

        self.target_noise_scale *= self.noise_decay
        self.target_noise_scale = max(self.min_noise, self.target_noise_scale)

-        states, actions, rewards, next_states, dones, indices, _ = experiences
-        info["indices"] = indices
+        experiences = memory.sample(batch_size)
+        states, actions, rewards, next_states, dones, _, _ = experiences

        batch_size = len(states)

@@ -286,20 +285,6 @@ def train_policy(self, experiences):
                param.data * self.tau + target_param.data * (1.0 - self.tau)
            )

-        info["actor_loss"] = actor_loss
-
-        # Building Dictionary
-        # TODO David fill in info here to match other methods
-        # info["q_target"] = q_target
-        # info["q_values_one"] = q_values_one
-        # info["q_values_two"] = q_values_two
-        # info["q_values_min"] = torch.minimum(q_values_one, q_values_two)
-        # info["critic_loss_total"] = critic_loss_total
-        # info["critic_loss_one"] = critic_loss_one
-        # info["critic_loss_two"] = critic_loss_two
-
-        return info

    def save_models(self, filename, filepath="models"):
        path = f"{filepath}/models" if filepath != "models" else filepath
        dir_exists = os.path.exists(path)
13 changes: 2 additions & 11 deletions cares_reinforcement_learning/algorithm/policy/DDPG.py
@@ -48,11 +48,9 @@ def select_action_from_policy(self, state, evaluation=None, noise_scale=0):
        self.actor_net.train()
        return action

-    def train_policy(self, experiences):
-        info = {}

+    def train_policy(self, memory, batch_size):
+        experiences = memory.sample(batch_size)
        states, actions, rewards, next_states, dones, indices, _ = experiences
-        info["indices"] = indices

        batch_size = len(states)

@@ -103,13 +101,6 @@ def train_policy(self, experiences):
                param.data * self.tau + target_param.data * (1.0 - self.tau)
            )

-        info["actor_loss"] = actor_loss
-        info["critic_loss"] = critic_loss
-        info["q_values_min"] = q_values
-        info["q_values"] = q_values
-
-        return info

    def save_models(self, filename, filepath="models"):
        path = f"{filepath}/models" if filepath != "models" else filepath
        dir_exists = os.path.exists(path)
19 changes: 3 additions & 16 deletions cares_reinforcement_learning/algorithm/policy/NaSATD3.py
@@ -101,18 +101,16 @@ def select_action_from_policy(self, state, evaluation=False, noise_scale=0.1):
        self.actor.train()
        return action

-    def train_policy(self, experiences):
-        info = {}

+    def train_policy(self, memory, batch_size):
        self.encoder.train()
        self.decoder.train()
        self.actor.train()
        self.critic.train()

        self.learn_counter += 1

-        states, actions, rewards, next_states, dones, indices, _ = experiences
-        info["indices"] = indices
+        experiences = memory.sample(batch_size)
+        states, actions, rewards, next_states, dones, _, _ = experiences

        batch_size = len(states)

@@ -214,17 +212,6 @@ def train_policy(self, experiences):
        if self.intrinsic_on:
            self.train_predictive_model(states, actions, next_states)

-        # Building Dictionary
-        info["q_target"] = q_target
-        info["q_values_one"] = q_values_one
-        info["q_values_two"] = q_values_two
-        info["q_values_min"] = torch.minimum(q_values_one, q_values_two)
-        info["critic_loss_total"] = critic_loss_total
-        info["critic_loss_one"] = critic_loss_one
-        info["critic_loss_two"] = critic_loss_two
-
-        return info

    def get_intrinsic_reward(self, state, action, next_state):
        with torch.no_grad():
            state_tensor = torch.FloatTensor(state).to(self.device)
19 changes: 3 additions & 16 deletions cares_reinforcement_learning/algorithm/policy/PERTD3.py
@@ -62,12 +62,11 @@ def select_action_from_policy(self, state, evaluation=False, noise_scale=0.1):
        self.actor_net.train()
        return action

-    def train_policy(self, experiences):
+    def train_policy(self, memory, batch_size):
        self.learn_counter += 1
-        info = {}

+        experiences = memory.sample(batch_size)
        states, actions, rewards, next_states, dones, indices, weights = experiences
-        info["indices"] = indices

        batch_size = len(states)

@@ -156,19 +155,7 @@ def train_policy(self, experiences):
                param.data * self.tau + target_param.data * (1.0 - self.tau)
            )

-        info["actor_loss"] = actor_loss
-
-        # Building Dictionary
-        info["q_target"] = q_target
-        info["q_values_one"] = q_values_one
-        info["q_values_two"] = q_values_two
-        info["q_values_min"] = torch.minimum(q_values_one, q_values_two)
-        info["critic_loss_total"] = critic_loss_total
-        info["critic_loss_one"] = critic_loss_one
-        info["critic_loss_two"] = critic_loss_two
-        info["priorities"] = priorities
-
-        return info
+        memory.update_priorities(indices, priorities)

    def save_models(self, filename, filepath="models"):
        path = f"{filepath}/models" if filepath != "models" else filepath
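
With sampling moved inside the algorithm, the prioritised-replay bookkeeping stays inside it too: indices and importance weights come back from memory.sample(batch_size), and the recomputed priorities are written straight back with memory.update_priorities(indices, priorities) rather than returned in an info dictionary. A rough sketch of that flow follows; the priority rule here (absolute TD error plus a small epsilon) is a common convention and stands in for whatever PERTD3 actually computes:

import torch


def update_per_priorities_sketch(memory, critic_net, q_target, states, actions, indices, eps=1e-6):
    """Hypothetical helper: recompute priorities from TD error and push them back.

    Assumes `memory.update_priorities(indices, priorities)` as used in this commit
    and a twin-critic network returning two Q estimates, as in the TD3 family.
    """
    with torch.no_grad():
        q_values_one, q_values_two = critic_net(states, actions)
        td_error = q_target - torch.minimum(q_values_one, q_values_two)

    priorities = (td_error.abs() + eps).squeeze(-1).cpu().numpy()
    memory.update_priorities(indices, priorities)
    return priorities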
5 changes: 3 additions & 2 deletions cares_reinforcement_learning/algorithm/policy/PPO.py
@@ -84,10 +84,11 @@ def calculate_rewards_to_go(self, batch_rewards, batch_dones):
        batch_rtgs = torch.tensor(rtgs, dtype=torch.float).to(self.device)  # shape 5000
        return batch_rtgs

-    def train_policy(self, experience):
+    def train_policy(self, memory, batch_size=0):
        info = {}

-        states, actions, rewards, next_states, dones, log_probs = experience
+        experiences = memory.flush()
+        states, actions, rewards, next_states, dones, log_probs = experiences

        states = torch.FloatTensor(np.asarray(states)).to(self.device)
        actions = torch.FloatTensor(np.asarray(actions)).to(self.device)
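
PPO is on-policy, so instead of sampling a fixed-size batch it takes everything gathered since the last update: memory.flush() hands back (states, actions, rewards, next_states, dones, log_probs), and the batch_size=0 default just keeps the signature uniform with the off-policy agents. A toy flush-style buffer, purely to illustrate the pattern (not the repository's MemoryBuffer):

class RolloutBuffer:
    """Toy on-policy buffer: collect whole transitions, return them all on flush."""

    def __init__(self):
        self.transitions = []

    def add(self, state, action, reward, next_state, done, log_prob):
        self.transitions.append((state, action, reward, next_state, done, log_prob))

    def flush(self):
        # Hand back everything collected so far as parallel lists, then reset.
        if not self.transitions:
            return [], [], [], [], [], []
        states, actions, rewards, next_states, dones, log_probs = map(
            list, zip(*self.transitions)
        )
        self.transitions = []
        return states, actions, rewards, next_states, dones, log_probs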
20 changes: 4 additions & 16 deletions cares_reinforcement_learning/algorithm/policy/RDTD3.py
@@ -71,13 +71,12 @@ def select_action_from_policy(self, state, evaluation=False, noise_scale=0.1):
        self.actor_net.train()
        return action

-    def train_policy(self, experience):
+    def train_policy(self, memory, batch_size):
        self.learn_counter += 1
-        info = {}

        # Sample replay buffer
-        states, actions, rewards, next_states, dones, indices, weights = experience
-        info["indices"] = indices
+        experiences = memory.sample(batch_size)
+        states, actions, rewards, next_states, dones, indices, weights = experiences

        batch_size = len(states)

@@ -205,8 +204,6 @@ def train_policy(self, experience):
                param.data * self.tau + target_param.data * (1.0 - self.tau)
            )

-        info["actor_loss"] = actor_loss
-
        ################################################
        # Update Scales
        if self.learn_counter == 1:
@@ -227,16 +224,7 @@ def train_policy(self, experience):
            self.scale_r = np.mean(numpy_td_err) / (np.mean(numpy_reward_err))
            self.scale_s = np.mean(numpy_td_err) / (np.mean(numpy_state_err))

-        info["q_target"] = q_target
-        info["q_values_one"] = output_one
-        info["q_values_two"] = output_two
-        info["q_values_min"] = torch.minimum(output_one, output_two)
-        info["critic_loss_total"] = critic_loss_total
-        info["critic_loss_one"] = critic_one_loss
-        info["critic_loss_two"] = critic_two_loss
-        info["priorities"] = priorities
-
-        return info
+        memory.update_priorities(indices, priorities)

    def save_models(self, filename, filepath="models"):
        path = f"{filepath}/models" if filepath != "models" else filepath
18 changes: 3 additions & 15 deletions cares_reinforcement_learning/algorithm/policy/SAC.py
@@ -86,12 +86,11 @@ def select_action_from_policy(self, state, evaluation=False, noise_scale=0):
    def alpha(self):
        return self.log_alpha.exp()

-    def train_policy(self, experiences):
+    def train_policy(self, memory, batch_size):
        self.learn_counter += 1
-        info = {}

-        states, actions, rewards, next_states, dones, indices, _ = experiences
-        info["indices"] = indices
+        experiences = memory.sample(batch_size)
+        states, actions, rewards, next_states, dones, _, _ = experiences

        batch_size = len(states)

@@ -154,17 +153,6 @@ def train_policy(self, experiences):
                param.data * self.tau + target_param.data * (1.0 - self.tau)
            )

-        info["q_target"] = q_target
-        info["q_values_one"] = q_values_one
-        info["q_values_two"] = q_values_two
-        info["q_values_min"] = torch.minimum(q_values_one, q_values_two)
-        info["critic_loss_total"] = critic_loss_total
-        info["critic_loss_one"] = critic_loss_one
-        info["critic_loss_two"] = critic_loss_two
-        info["actor_loss"] = actor_loss
-
-        return info

    def save_models(self, filename, filepath="models"):
        path = f"{filepath}/models" if filepath != "models" else filepath
        dir_exists = os.path.exists(path)
20 changes: 3 additions & 17 deletions cares_reinforcement_learning/algorithm/policy/TD3.py
@@ -60,12 +60,11 @@ def select_action_from_policy(self, state, evaluation=False, noise_scale=0.1):
        self.actor_net.train()
        return action

-    def train_policy(self, experiences):
+    def train_policy(self, memory, batch_size):
        self.learn_counter += 1
-        info = {}

-        states, actions, rewards, next_states, dones, indices, _ = experiences
-        info["indices"] = indices
+        experiences = memory.sample(batch_size)
+        states, actions, rewards, next_states, dones, _, _ = experiences

        batch_size = len(states)

@@ -139,19 +138,6 @@ def train_policy(self, experiences):
                param.data * self.tau + target_param.data * (1.0 - self.tau)
            )

-        info["actor_loss"] = actor_loss
-
-        # Building Dictionary
-        info["q_target"] = q_target
-        info["q_values_one"] = q_values_one
-        info["q_values_two"] = q_values_two
-        info["q_values_min"] = torch.minimum(q_values_one, q_values_two)
-        info["critic_loss_total"] = critic_loss_total
-        info["critic_loss_one"] = critic_loss_one
-        info["critic_loss_two"] = critic_loss_two
-
-        return info

    def save_models(self, filename, filepath="models"):
        path = f"{filepath}/models" if filepath != "models" else filepath
        dir_exists = os.path.exists(path)
6 changes: 4 additions & 2 deletions cares_reinforcement_learning/algorithm/value/DQN.py
@@ -27,10 +27,12 @@ def select_action_from_policy(self, state):
        self.network.train()
        return action

-    def train_policy(self, experiences):
-        states, actions, rewards, next_states, dones = experiences
+    def train_policy(self, memory, batch_size):
        info = {}

+        experiences = memory.sample(batch_size)
+        states, actions, rewards, next_states, dones, _, _ = experiences

        # Convert into tensor
        states = torch.FloatTensor(np.asarray(states)).to(self.device)
        actions = torch.LongTensor(np.asarray(actions)).to(self.device)
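
One detail worth noting across the non-prioritised algorithms in this commit (DQN, TD3, SAC, CTD4, NaSATD3): memory.sample(batch_size) appears to always return the same 7-tuple, and agents that don't use prioritised replay simply discard the trailing indices and weights with `_, _`. Keeping the return shape uniform is what lets every agent share the same call site. A stand-in uniform buffer illustrating that shape (again an assumption for illustration, not the repository's implementation):

import random

import numpy as np


class UniformReplayBuffer:
    """Stand-in uniform buffer that mimics the 7-tuple a prioritised buffer returns."""

    def __init__(self):
        self.buffer = []

    def add(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        indices = random.sample(range(len(self.buffer)), k=min(batch_size, len(self.buffer)))
        states, actions, rewards, next_states, dones = zip(
            *(self.buffer[i] for i in indices)
        )
        # Uniform sampling: every importance weight is 1.0, but the tuple shape matches
        # the prioritised case, so callers can unpack `..., _, _ = experiences` either way.
        weights = np.ones(len(indices), dtype=np.float32)
        return (
            list(states),
            list(actions),
            list(rewards),
            list(next_states),
            list(dones),
            list(indices),
            weights,
        )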