diff --git a/cares_reinforcement_learning/algorithm/mbrl/DYNA_SAC.py b/cares_reinforcement_learning/algorithm/mbrl/DYNA_SAC.py index f9d37bac..2c2078c9 100644 --- a/cares_reinforcement_learning/algorithm/mbrl/DYNA_SAC.py +++ b/cares_reinforcement_learning/algorithm/mbrl/DYNA_SAC.py @@ -161,21 +161,14 @@ def _train_policy(self, states, actions, rewards, next_states, dones): info["actor_loss"] = actor_loss return info - def train_world_model(self, experiences): - """ - Sample the buffer again for training the world model can reach higher rewards. + def train_world_model(self, memory, batch_size): + + experiences = memory.sample_consecutive(batch_size) + + states, actions, rewards, next_states, _, next_actions, next_rewards = ( + experiences + ) - :param experiences: - """ - ( - states, - actions, - rewards, - next_states, - _, - next_actions, - next_rewards, - ) = experiences states = torch.FloatTensor(np.asarray(states)).to(self.device) actions = torch.FloatTensor(np.asarray(actions)).to(self.device) rewards = torch.FloatTensor(np.asarray(rewards)).to(self.device).unsqueeze(1) @@ -184,12 +177,8 @@ def train_world_model(self, experiences): torch.FloatTensor(np.asarray(next_rewards)).to(self.device).unsqueeze(1) ) next_actions = torch.FloatTensor(np.asarray(next_actions)).to(self.device) - assert len(states.shape) >= 2 - assert len(actions.shape) == 2 - assert len(rewards.shape) == 2 and rewards.shape[1] == 1 - assert len(next_rewards.shape) == 2 and next_rewards.shape[1] == 1 - assert len(next_states.shape) >= 2 - # # Step 1 train the world model. + + # Step 1 train the world model. self.world_model.train_world( states=states, actions=actions, @@ -199,30 +188,20 @@ def train_world_model(self, experiences): next_rewards=next_rewards, ) - def train_policy(self, experiences): - """ - Interface to training loop. - - """ + def train_policy(self, memory, batch_size): self.learn_counter += 1 - ( - states, - actions, - rewards, - next_states, - dones, - ) = experiences + + experiences = memory.sample(batch_size) + states, actions, rewards, next_states, dones, _, _ = experiences self.batch_size = len(states) + # Convert into tensor states = torch.FloatTensor(np.asarray(states)).to(self.device) actions = torch.FloatTensor(np.asarray(actions)).to(self.device) rewards = torch.FloatTensor(np.asarray(rewards)).to(self.device).unsqueeze(1) next_states = torch.FloatTensor(np.asarray(next_states)).to(self.device) dones = torch.LongTensor(np.asarray(dones)).to(self.device).unsqueeze(1) - assert len(states.shape) >= 2 - assert len(actions.shape) == 2 - assert len(rewards.shape) == 2 and rewards.shape[1] == 1 - assert len(next_states.shape) >= 2 + # Step 2 train as usual self._train_policy( states=states, @@ -235,10 +214,6 @@ def train_policy(self, experiences): self._dyna_generate_and_train(next_states=next_states) def _dyna_generate_and_train(self, next_states): - """ - Only off-policy Dyna will work. 
- :param next_states: - """ pred_states = [] pred_actions = [] pred_rs = [] diff --git a/cares_reinforcement_learning/algorithm/policy/CTD4.py b/cares_reinforcement_learning/algorithm/policy/CTD4.py index 29ff0ab4..c85f53a7 100644 --- a/cares_reinforcement_learning/algorithm/policy/CTD4.py +++ b/cares_reinforcement_learning/algorithm/policy/CTD4.py @@ -91,15 +91,14 @@ def fusion_kalman(self, std_1, mean_1, std_2, mean_2): fusion_std = torch.sqrt(fusion_variance) return fusion_mean, fusion_std - def train_policy(self, experiences): - info = {} - + def train_policy(self, memory, batch_size): self.learn_counter += 1 + self.target_noise_scale *= self.noise_decay self.target_noise_scale = max(self.min_noise, self.target_noise_scale) - states, actions, rewards, next_states, dones, indices, _ = experiences - info["indices"] = indices + experiences = memory.sample(batch_size) + states, actions, rewards, next_states, dones, _, _ = experiences batch_size = len(states) @@ -286,20 +285,6 @@ def train_policy(self, experiences): param.data * self.tau + target_param.data * (1.0 - self.tau) ) - info["actor_loss"] = actor_loss - - # Building Dictionary - # TODO David fill in info here to match other methods - # info["q_target"] = q_target - # info["q_values_one"] = q_values_one - # info["q_values_two"] = q_values_two - # info["q_values_min"] = torch.minimum(q_values_one, q_values_two) - # info["critic_loss_total"] = critic_loss_total - # info["critic_loss_one"] = critic_loss_one - # info["critic_loss_two"] = critic_loss_two - - return info - def save_models(self, filename, filepath="models"): path = f"{filepath}/models" if filepath != "models" else filepath dir_exists = os.path.exists(path) diff --git a/cares_reinforcement_learning/algorithm/policy/DDPG.py b/cares_reinforcement_learning/algorithm/policy/DDPG.py index 6516c588..4dbd0e99 100644 --- a/cares_reinforcement_learning/algorithm/policy/DDPG.py +++ b/cares_reinforcement_learning/algorithm/policy/DDPG.py @@ -48,11 +48,9 @@ def select_action_from_policy(self, state, evaluation=None, noise_scale=0): self.actor_net.train() return action - def train_policy(self, experiences): - info = {} - + def train_policy(self, memory, batch_size): + experiences = memory.sample(batch_size) states, actions, rewards, next_states, dones, indices, _ = experiences - info["indices"] = indices batch_size = len(states) @@ -103,13 +101,6 @@ def train_policy(self, experiences): param.data * self.tau + target_param.data * (1.0 - self.tau) ) - info["actor_loss"] = actor_loss - info["critic_loss"] = critic_loss - info["q_values_min"] = q_values - info["q_values"] = q_values - - return info - def save_models(self, filename, filepath="models"): path = f"{filepath}/models" if filepath != "models" else filepath dir_exists = os.path.exists(path) diff --git a/cares_reinforcement_learning/algorithm/policy/NaSATD3.py b/cares_reinforcement_learning/algorithm/policy/NaSATD3.py index 20424181..d900a414 100644 --- a/cares_reinforcement_learning/algorithm/policy/NaSATD3.py +++ b/cares_reinforcement_learning/algorithm/policy/NaSATD3.py @@ -101,9 +101,7 @@ def select_action_from_policy(self, state, evaluation=False, noise_scale=0.1): self.actor.train() return action - def train_policy(self, experiences): - info = {} - + def train_policy(self, memory, batch_size): self.encoder.train() self.decoder.train() self.actor.train() @@ -111,8 +109,8 @@ def train_policy(self, experiences): self.learn_counter += 1 - states, actions, rewards, next_states, dones, indices, _ = experiences - info["indices"] = 
indices + experiences = memory.sample(batch_size) + states, actions, rewards, next_states, dones, _, _ = experiences batch_size = len(states) @@ -214,17 +212,6 @@ def train_policy(self, experiences): if self.intrinsic_on: self.train_predictive_model(states, actions, next_states) - # Building Dictionary - info["q_target"] = q_target - info["q_values_one"] = q_values_one - info["q_values_two"] = q_values_two - info["q_values_min"] = torch.minimum(q_values_one, q_values_two) - info["critic_loss_total"] = critic_loss_total - info["critic_loss_one"] = critic_loss_one - info["critic_loss_two"] = critic_loss_two - - return info - def get_intrinsic_reward(self, state, action, next_state): with torch.no_grad(): state_tensor = torch.FloatTensor(state).to(self.device) diff --git a/cares_reinforcement_learning/algorithm/policy/PERTD3.py b/cares_reinforcement_learning/algorithm/policy/PERTD3.py index e5db0baf..fbba5c24 100644 --- a/cares_reinforcement_learning/algorithm/policy/PERTD3.py +++ b/cares_reinforcement_learning/algorithm/policy/PERTD3.py @@ -62,12 +62,11 @@ def select_action_from_policy(self, state, evaluation=False, noise_scale=0.1): self.actor_net.train() return action - def train_policy(self, experiences): + def train_policy(self, memory, batch_size): self.learn_counter += 1 - info = {} + experiences = memory.sample(batch_size) states, actions, rewards, next_states, dones, indices, weights = experiences - info["indices"] = indices batch_size = len(states) @@ -156,19 +155,7 @@ def train_policy(self, experiences): param.data * self.tau + target_param.data * (1.0 - self.tau) ) - info["actor_loss"] = actor_loss - - # Building Dictionary - info["q_target"] = q_target - info["q_values_one"] = q_values_one - info["q_values_two"] = q_values_two - info["q_values_min"] = torch.minimum(q_values_one, q_values_two) - info["critic_loss_total"] = critic_loss_total - info["critic_loss_one"] = critic_loss_one - info["critic_loss_two"] = critic_loss_two - info["priorities"] = priorities - - return info + memory.update_priorities(indices, priorities) def save_models(self, filename, filepath="models"): path = f"{filepath}/models" if filepath != "models" else filepath diff --git a/cares_reinforcement_learning/algorithm/policy/PPO.py b/cares_reinforcement_learning/algorithm/policy/PPO.py index 9284cbd4..2318e8af 100644 --- a/cares_reinforcement_learning/algorithm/policy/PPO.py +++ b/cares_reinforcement_learning/algorithm/policy/PPO.py @@ -84,10 +84,11 @@ def calculate_rewards_to_go(self, batch_rewards, batch_dones): batch_rtgs = torch.tensor(rtgs, dtype=torch.float).to(self.device) # shape 5000 return batch_rtgs - def train_policy(self, experience): + def train_policy(self, memory, batch_size=0): info = {} - states, actions, rewards, next_states, dones, log_probs = experience + experiences = memory.flush() + states, actions, rewards, next_states, dones, log_probs = experiences states = torch.FloatTensor(np.asarray(states)).to(self.device) actions = torch.FloatTensor(np.asarray(actions)).to(self.device) diff --git a/cares_reinforcement_learning/algorithm/policy/RDTD3.py b/cares_reinforcement_learning/algorithm/policy/RDTD3.py index cb9c9534..fb3a42db 100644 --- a/cares_reinforcement_learning/algorithm/policy/RDTD3.py +++ b/cares_reinforcement_learning/algorithm/policy/RDTD3.py @@ -71,13 +71,12 @@ def select_action_from_policy(self, state, evaluation=False, noise_scale=0.1): self.actor_net.train() return action - def train_policy(self, experience): + def train_policy(self, memory, batch_size): self.learn_counter 
+= 1 - info = {} # Sample replay buffer - states, actions, rewards, next_states, dones, indices, weights = experience - info["indices"] = indices + experiences = memory.sample(batch_size) + states, actions, rewards, next_states, dones, indices, weights = experiences batch_size = len(states) @@ -205,8 +204,6 @@ def train_policy(self, experience): param.data * self.tau + target_param.data * (1.0 - self.tau) ) - info["actor_loss"] = actor_loss - ################################################ # Update Scales if self.learn_counter == 1: @@ -227,16 +224,7 @@ def train_policy(self, experience): self.scale_r = np.mean(numpy_td_err) / (np.mean(numpy_reward_err)) self.scale_s = np.mean(numpy_td_err) / (np.mean(numpy_state_err)) - info["q_target"] = q_target - info["q_values_one"] = output_one - info["q_values_two"] = output_two - info["q_values_min"] = torch.minimum(output_one, output_two) - info["critic_loss_total"] = critic_loss_total - info["critic_loss_one"] = critic_one_loss - info["critic_loss_two"] = critic_two_loss - info["priorities"] = priorities - - return info + memory.update_priorities(indices, priorities) def save_models(self, filename, filepath="models"): path = f"{filepath}/models" if filepath != "models" else filepath diff --git a/cares_reinforcement_learning/algorithm/policy/SAC.py b/cares_reinforcement_learning/algorithm/policy/SAC.py index bc61212b..1332934d 100644 --- a/cares_reinforcement_learning/algorithm/policy/SAC.py +++ b/cares_reinforcement_learning/algorithm/policy/SAC.py @@ -86,12 +86,11 @@ def select_action_from_policy(self, state, evaluation=False, noise_scale=0): def alpha(self): return self.log_alpha.exp() - def train_policy(self, experiences): + def train_policy(self, memory, batch_size): self.learn_counter += 1 - info = {} - states, actions, rewards, next_states, dones, indices, _ = experiences - info["indices"] = indices + experiences = memory.sample(batch_size) + states, actions, rewards, next_states, dones, _, _ = experiences batch_size = len(states) @@ -154,17 +153,6 @@ def train_policy(self, experiences): param.data * self.tau + target_param.data * (1.0 - self.tau) ) - info["q_target"] = q_target - info["q_values_one"] = q_values_one - info["q_values_two"] = q_values_two - info["q_values_min"] = torch.minimum(q_values_one, q_values_two) - info["critic_loss_total"] = critic_loss_total - info["critic_loss_one"] = critic_loss_one - info["critic_loss_two"] = critic_loss_two - info["actor_loss"] = actor_loss - - return info - def save_models(self, filename, filepath="models"): path = f"{filepath}/models" if filepath != "models" else filepath dir_exists = os.path.exists(path) diff --git a/cares_reinforcement_learning/algorithm/policy/TD3.py b/cares_reinforcement_learning/algorithm/policy/TD3.py index aef32f53..63971ff0 100644 --- a/cares_reinforcement_learning/algorithm/policy/TD3.py +++ b/cares_reinforcement_learning/algorithm/policy/TD3.py @@ -60,12 +60,11 @@ def select_action_from_policy(self, state, evaluation=False, noise_scale=0.1): self.actor_net.train() return action - def train_policy(self, experiences): + def train_policy(self, memory, batch_size): self.learn_counter += 1 - info = {} - states, actions, rewards, next_states, dones, indices, _ = experiences - info["indices"] = indices + experiences = memory.sample(batch_size) + states, actions, rewards, next_states, dones, _, _ = experiences batch_size = len(states) @@ -139,19 +138,6 @@ def train_policy(self, experiences): param.data * self.tau + target_param.data * (1.0 - self.tau) ) - info["actor_loss"] 
= actor_loss - - # Building Dictionary - info["q_target"] = q_target - info["q_values_one"] = q_values_one - info["q_values_two"] = q_values_two - info["q_values_min"] = torch.minimum(q_values_one, q_values_two) - info["critic_loss_total"] = critic_loss_total - info["critic_loss_one"] = critic_loss_one - info["critic_loss_two"] = critic_loss_two - - return info - def save_models(self, filename, filepath="models"): path = f"{filepath}/models" if filepath != "models" else filepath dir_exists = os.path.exists(path) diff --git a/cares_reinforcement_learning/algorithm/value/DQN.py b/cares_reinforcement_learning/algorithm/value/DQN.py index de7e5898..b68d1980 100644 --- a/cares_reinforcement_learning/algorithm/value/DQN.py +++ b/cares_reinforcement_learning/algorithm/value/DQN.py @@ -27,10 +27,12 @@ def select_action_from_policy(self, state): self.network.train() return action - def train_policy(self, experiences): - states, actions, rewards, next_states, dones = experiences + def train_policy(self, memory, batch_size): info = {} + experiences = memory.sample(batch_size) + states, actions, rewards, next_states, dones, _, _ = experiences + # Convert into tensor states = torch.FloatTensor(np.asarray(states)).to(self.device) actions = torch.LongTensor(np.asarray(actions)).to(self.device) diff --git a/cares_reinforcement_learning/algorithm/value/DoubleDQN.py b/cares_reinforcement_learning/algorithm/value/DoubleDQN.py index ee25ad3b..2b8069b3 100644 --- a/cares_reinforcement_learning/algorithm/value/DoubleDQN.py +++ b/cares_reinforcement_learning/algorithm/value/DoubleDQN.py @@ -35,9 +35,9 @@ def select_action_from_policy(self, state): self.network.train() return action - def train_policy(self, experiences): - states, actions, rewards, next_states, dones = experiences - info = {} + def train_policy(self, memory, batch_size): + experiences = memory.sample(batch_size) + states, actions, rewards, next_states, dones, _, _ = experiences # Convert into tensor states = torch.FloatTensor(np.asarray(states)).to(self.device) @@ -70,12 +70,6 @@ def train_policy(self, experiences): param.data * self.tau + target_param.data * (1.0 - self.tau) ) - info["q_target"] = q_target - info["q_values_min"] = q_value - info["network_loss"] = loss - - return info - def save_models(self, filename, filepath="models"): path = f"{filepath}/models" if filepath != "models" else filepath dir_exists = os.path.exists(path) diff --git a/cares_reinforcement_learning/memory/prioritised_replay_buffer.py b/cares_reinforcement_learning/memory/prioritised_replay_buffer.py index d46191c2..cf6f837b 100644 --- a/cares_reinforcement_learning/memory/prioritised_replay_buffer.py +++ b/cares_reinforcement_learning/memory/prioritised_replay_buffer.py @@ -120,28 +120,19 @@ def sample(self, batch_size): weights.tolist(), ) - def update_priority(self, info): + def update_priorities(self, indices, priorities): """ - Update the priorities of the replay buffer based on the given information. + Update the priorities of the replay buffer at the given indices. - Args: - info (dict): A dictionary containing the following keys: - - "indices" (list): A list of indices corresponding to the samples in the replay buffer. - - "priorities" (torch.Tensor, optional): A tensor containing the new priorities for the samples. - If not provided, default priorities of 1.0 are assigned to all samples. + Parameters: + - indices (array-like): The indices of the replay buffer to update. + - priorities (array-like): The new priorities to assign to the specified indices. 
Returns: - None + None """ - ind = info["indices"] - priorities = ( - info["priorities"] - if "priorities" in info - else torch.tensor([1.0] * len(info["indices"])) - ) - self.max_priority = max(priorities.max(), self.max_priority) - self.tree.batch_set(ind, priorities) + self.tree.batch_set(indices, priorities) def flush(self): """
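# --- Usage sketch (not part of the patch) --------------------------------------
# The hunks above move batch sampling out of the training loop and into each
# algorithm: train_policy(experiences) becomes train_policy(memory, batch_size),
# DYNA_SAC's train_world_model now draws consecutive transitions itself via
# memory.sample_consecutive(batch_size), and the prioritised variants (PERTD3,
# RDTD3) call memory.update_priorities(indices, priorities) internally instead of
# returning an info dict for the caller to feed back through
# memory.update_priority(info). The function below sketches the caller-side loop
# this interface implies. It is illustrative only: the names `env`, `agent`,
# `memory`, the memory.add signature, and the gymnasium-style env API are
# assumptions for the example, not code taken from this repository.


def train_loop(env, agent, memory, batch_size, max_steps):
    """Minimal caller-side training loop under the refactored interface (sketch)."""
    state, _ = env.reset()
    for _ in range(max_steps):
        action = agent.select_action_from_policy(state)
        next_state, reward, terminated, truncated, _ = env.step(action)
        memory.add(state, action, reward, next_state, terminated)

        # The agent now pulls its own batch from the buffer and, for the PER
        # variants, writes the new priorities straight back into it; neither the
        # sampled experiences nor an info dict pass through this loop anymore.
        agent.train_policy(memory, batch_size)

        state = next_state
        if terminated or truncated:
            state, _ = env.reset()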