
DroQ Critic
beardyFace committed Dec 6, 2024
1 parent 355b426 · commit c4113b6
Showing 2 changed files with 11 additions and 3 deletions.
cares_reinforcement_learning/networks/DroQ/critic.py (4 changes: 1 addition & 3 deletions)
@@ -11,11 +11,9 @@ def __init__(
         self,
         observation_size: int,
         num_actions: int,
-        hidden_sizes: list[int] | None = None,
     ):
         input_size = observation_size + num_actions
-        if hidden_sizes is None:
-            hidden_sizes = [256, 256]
+        hidden_sizes = [256, 256]
 
         # Q1 architecture
         # pylint: disable-next=invalid-name
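
In effect, the DroQ critic's hidden layer widths are now fixed at [256, 256] inside __init__ rather than exposed as a constructor argument. A minimal usage sketch, assuming the class in this file is importable as Critic (the import path and sizes below are illustrative, not taken from this diff):

# Hypothetical import path; the diff only identifies the file networks/DroQ/critic.py.
from cares_reinforcement_learning.networks.DroQ import Critic

# hidden_sizes can no longer be passed in; [256, 256] is set inside __init__.
critic = Critic(observation_size=17, num_actions=6)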
cares_reinforcement_learning/util/common.py (10 changes: 10 additions & 0 deletions)
@@ -18,6 +18,8 @@ def __init__(
         input_size: int,
         hidden_sizes: list[int],
         output_size: int | None,
+        dropout_layer: Callable[..., nn.Module] | str | None = None,
+        dropout_layer_args: dict[str, Any] | None = None,
         norm_layer: Callable[..., nn.Module] | str | None = None,
         norm_layer_args: dict[str, Any] | None = None,
         hidden_activation_function: Callable[..., nn.Module] | str = nn.ReLU,
@@ -26,13 +28,18 @@ def __init__(
         output_activation_args: dict[str, Any] | None = None,
     ):
         super().__init__()
+        if dropout_layer_args is None:
+            dropout_layer_args = {}
         if norm_layer_args is None:
             norm_layer_args = {}
         if hidden_activation_function_args is None:
             hidden_activation_function_args = {}
         if output_activation_args is None:
             output_activation_args = {}
 
+        if isinstance(dropout_layer, str):
+            dropout_layer = get_pytorch_module_from_name(dropout_layer)
+
         if isinstance(norm_layer, str):
             norm_layer = get_pytorch_module_from_name(norm_layer)
 
@@ -51,6 +58,9 @@ def __init__(
         for next_size in hidden_sizes:
             layers.append(nn.Linear(input_size, next_size))
 
+            if dropout_layer is not None:
+                layers.append(dropout_layer(**dropout_layer_args))
+
             if norm_layer is not None:
                 layers.append(norm_layer(next_size, **norm_layer_args))
 
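
DroQ regularises each Q-function by applying dropout (typically together with layer normalisation) after every hidden linear layer; the new dropout_layer and dropout_layer_args hooks make that possible through the shared network builder. Below is a minimal, self-contained sketch of the layer-ordering logic these hunks add. The function name build_mlp, the sizes, and the dropout probability are illustrative stand-ins, not the repository's actual class or defaults:

from typing import Any, Callable

import torch
from torch import nn


def build_mlp(
    input_size: int,
    hidden_sizes: list[int],
    output_size: int | None,
    dropout_layer: Callable[..., nn.Module] | None = None,
    dropout_layer_args: dict[str, Any] | None = None,
    norm_layer: Callable[..., nn.Module] | None = None,
    norm_layer_args: dict[str, Any] | None = None,
) -> nn.Sequential:
    # Hypothetical stand-in for the helper in util/common.py; it reproduces only the
    # layer ordering these hunks establish: Linear -> dropout -> norm -> activation.
    dropout_layer_args = dropout_layer_args or {}
    norm_layer_args = norm_layer_args or {}

    layers: list[nn.Module] = []
    for next_size in hidden_sizes:
        layers.append(nn.Linear(input_size, next_size))

        # New in this commit: dropout directly after every hidden Linear layer.
        if dropout_layer is not None:
            layers.append(dropout_layer(**dropout_layer_args))

        # Pre-existing path: optional normalisation over the hidden width.
        if norm_layer is not None:
            layers.append(norm_layer(next_size, **norm_layer_args))

        layers.append(nn.ReLU())
        input_size = next_size

    if output_size is not None:
        layers.append(nn.Linear(input_size, output_size))
    return nn.Sequential(*layers)


# Example: a DroQ-style Q-network over a concatenated state-action vector.
# Sizes and dropout probability are illustrative, not taken from the repository.
q_net = build_mlp(
    input_size=17 + 6,
    hidden_sizes=[256, 256],
    output_size=1,
    dropout_layer=nn.Dropout,
    dropout_layer_args={"p": 0.01},
    norm_layer=nn.LayerNorm,
)
q_value = q_net(torch.randn(8, 17 + 6))  # shape (8, 1)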
