Commit

C51 initial files

beardyFace committed Jan 6, 2025
1 parent 4fd16e3 commit d72ec0c
Showing 4 changed files with 197 additions and 0 deletions.
89 changes: 89 additions & 0 deletions cares_reinforcement_learning/algorithm/value/C51.py
@@ -0,0 +1,89 @@
"""
Original Paper: https://arxiv.org/pdf/1707.06887
"""

import copy
import logging
import os
from typing import Any

import numpy as np
import torch
import torch.nn.functional as F

from cares_reinforcement_learning.memory import MemoryBuffer
from cares_reinforcement_learning.networks.C51 import Network
from cares_reinforcement_learning.util.configurations import C51Config


class C51:
def __init__(
self,
network: Network,
config: C51Config,
device: torch.device,
):
self.type = "value"
self.device = device

self.network = network.to(device)
self.target_network = copy.deepcopy(self.network).to(self.device)

self.gamma = config.gamma

self.network_optimiser = torch.optim.Adam(
self.network.parameters(), lr=config.lr
)

    def select_action_from_policy(self, state) -> int:
        self.network.eval()
        with torch.no_grad():
            state_tensor = torch.FloatTensor(state).to(self.device)
            state_tensor = state_tensor.unsqueeze(0)
            # The network returns the expected Q-values and the per-action atom distributions
            q_values, _ = self.network(state_tensor)
            action = torch.argmax(q_values).item()
        self.network.train()
        return action

def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
experiences = memory.sample_uniform(batch_size)
states, actions, rewards, next_states, dones, _ = experiences

# Convert into tensor
states_tensor = torch.FloatTensor(np.asarray(states)).to(self.device)
actions_tensor = torch.LongTensor(np.asarray(actions)).to(self.device)
rewards_tensor = torch.FloatTensor(np.asarray(rewards)).to(self.device)
next_states_tensor = torch.FloatTensor(np.asarray(next_states)).to(self.device)
dones_tensor = torch.LongTensor(np.asarray(dones)).to(self.device)

        # Generate the expected Q values given the state at time t and t + 1
        # (the network returns both the expected Q-values and the atom distributions)
        q_values, _ = self.network(states_tensor)
        with torch.no_grad():
            next_q_values, _ = self.network(next_states_tensor)

        best_q_values = q_values[torch.arange(q_values.size(0)), actions_tensor]
        best_next_q_values = torch.max(next_q_values, dim=1).values

        q_target = rewards_tensor + self.gamma * (1 - dones_tensor) * best_next_q_values

info = {}

# Update the Network
loss = F.mse_loss(best_q_values, q_target)
self.network_optimiser.zero_grad()
loss.backward()
self.network_optimiser.step()

info["loss"] = loss.item()

return info

def save_models(self, filepath: str, filename: str) -> None:
if not os.path.exists(filepath):
os.makedirs(filepath)

torch.save(self.network.state_dict(), f"{filepath}/{filename}_network.pht")
logging.info("models has been saved...")

def load_models(self, filepath: str, filename: str) -> None:
self.network.load_state_dict(torch.load(f"{filepath}/{filename}_network.pht"))
logging.info("models has been loaded...")
2 changes: 2 additions & 0 deletions cares_reinforcement_learning/networks/C51/__init__.py
@@ -0,0 +1,2 @@
from .network import Network
from .network import DefaultNetwork
91 changes: 91 additions & 0 deletions cares_reinforcement_learning/networks/C51/network.py
@@ -0,0 +1,91 @@
import torch
from torch import nn

from cares_reinforcement_learning.networks.common import MLP
from cares_reinforcement_learning.util.configurations import C51Config


class BaseNetwork(nn.Module):
def __init__(
self,
input_size: int,
output_size: int,
atom_size: int,
v_min: float,
v_max: float,
network: MLP | nn.Sequential,
):
super().__init__()

self.input_size = input_size
        # output_size is the number of discrete actions; atom_size is the number
        # of atoms used to discretise each action's value distribution
        self.output_size = output_size
        self.atom_size = atom_size

self.v_min = v_min
self.v_max = v_max

        # Register the support as a buffer so it follows the module across devices
        self.register_buffer(
            "support", torch.linspace(self.v_min, self.v_max, self.atom_size)
        )

self.network = network

def forward(self, state: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
output = self.network(state)

        # Reshape the flat logits to (batch, num_actions, atom_size) and normalise each
        # action's atoms into a probability distribution over the support
        q_atoms = output.view(-1, self.output_size, self.atom_size)
        dist = torch.softmax(q_atoms, dim=-1)
        dist = dist.clamp(min=1e-3)  # avoid zero probabilities (and log(0) downstream)

        # Expected Q-value per action: atom values weighted by their probabilities
        q = torch.sum(dist * self.support, dim=2)

return q, dist


# This is the default network for C51, used for reference and testing of default network configurations
class DefaultNetwork(BaseNetwork):
def __init__(
self,
observation_size: int,
num_actions: int,
):
hidden_sizes = [512, 512]
atom_size = 51
output_size = num_actions * atom_size

v_min = 0.0
v_max = 200.0

network = nn.Sequential(
nn.Linear(observation_size, hidden_sizes[0]),
nn.ReLU(),
nn.Linear(hidden_sizes[0], hidden_sizes[1]),
nn.ReLU(),
nn.Linear(hidden_sizes[1], output_size),
)
super().__init__(
input_size=observation_size,
            output_size=num_actions,
atom_size=atom_size,
v_min=v_min,
v_max=v_max,
network=network,
)


class Network(BaseNetwork):
def __init__(self, observation_size: int, num_actions: int, config: C51Config):

        # One output logit per (action, atom) pair
        output_size = num_actions * config.num_atoms

network = MLP(
input_size=observation_size,
output_size=output_size,
config=config.network_config,
)
super().__init__(
input_size=observation_size,
            output_size=num_actions,
atom_size=config.num_atoms,
v_min=config.v_min,
v_max=config.v_max,
network=network,
)
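
For reference, a quick shape check of the network above, using made-up observation and action sizes (not part of the commit); forward returns the expected Q-values alongside the per-action atom distributions.

# Illustrative only - made-up sizes (4 observation dimensions, 2 actions)
# to show the shapes BaseNetwork.forward returns.
import torch

from cares_reinforcement_learning.networks.C51 import DefaultNetwork

net = DefaultNetwork(observation_size=4, num_actions=2)
states = torch.randn(32, 4)  # a batch of 32 observations

q, dist = net(states)
print(q.shape)     # torch.Size([32, 2])      expected Q-value per action
print(dist.shape)  # torch.Size([32, 2, 51])  probability over the 51 atoms per action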
15 changes: 15 additions & 0 deletions cares_reinforcement_learning/util/configurations.py
@@ -159,6 +159,21 @@ class DuelingDQNConfig(AlgorithmConfig):
advantage_stream_config: MLPConfig = MLPConfig(hidden_sizes=[512])


class C51Config(AlgorithmConfig):
algorithm: str = Field("C51", Literal=True)
lr: float = 1e-3
gamma: float = 0.99

exploration_min: float = 1e-3
exploration_decay: float = 0.95

num_atoms: int = 51
v_min: float = 0.0
v_max: float = 200.0

network_config: MLPConfig = MLPConfig(hidden_sizes=[512, 512])


###################################
# PPO Algorithms #
###################################
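
For context, one way the pieces added in this commit might be wired together (illustrative only; the repository's own factories and training loop may construct these differently, and it assumes the remaining AlgorithmConfig fields all have defaults):

# Illustrative wiring only - not part of the commit. Sizes are made up and
# C51Config() assumes the remaining AlgorithmConfig fields have defaults.
import torch

from cares_reinforcement_learning.algorithm.value.C51 import C51
from cares_reinforcement_learning.networks.C51 import Network
from cares_reinforcement_learning.util.configurations import C51Config

config = C51Config()  # lr=1e-3, gamma=0.99, 51 atoms on [0.0, 200.0]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

network = Network(observation_size=4, num_actions=2, config=config)
agent = C51(network=network, config=config, device=device)

action = agent.select_action_from_policy([0.1, 0.2, 0.3, 0.4])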