Skip to content

Commit

Permalink
Full algorithm and ae tests running now
Browse files Browse the repository at this point in the history
  • Loading branch information
beardyFace committed Aug 16, 2024
1 parent 1fbccf0 commit 6c8c072
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 49 deletions.
26 changes: 0 additions & 26 deletions tests/test_ae.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import inspect
import os

import pytest
import torch
Expand All @@ -8,32 +7,7 @@
from cares_reinforcement_learning.encoders.autoencoder_factory import AEFactory
from cares_reinforcement_learning.encoders.configurations import AEConfig, BurgessConfig

IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"


@pytest.mark.skipif(not IN_GITHUB_ACTIONS, reason="Running more complex test locally")
def test_ae_factory():
factory = AEFactory()

ae_configurations = {}
for name, cls in inspect.getmembers(configurations, inspect.isclass):
if issubclass(cls, AEConfig) and cls != AEConfig and cls != BurgessConfig:
name = name.replace("Config", "")
ae_configurations[name] = cls

for ae, config in ae_configurations.items():

observation_size = (9, 84, 84)

config = config(latent_dim=100)

ae = factory.create_autoencoder(
observation_size=observation_size, config=config
)
assert ae is not None, f"{ae} was not created successfully"


@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Test doesn't work in Github Actions.")
def test_ae():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Expand Down
23 changes: 0 additions & 23 deletions tests/test_algorithms.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import inspect
import os
from random import randrange

import numpy as np
Expand All @@ -9,28 +8,6 @@
from cares_reinforcement_learning.util import NetworkFactory, configurations
from cares_reinforcement_learning.util.configurations import AlgorithmConfig

# IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"


# @pytest.mark.skipif(not IN_GITHUB_ACTIONS, reason="Running more complex test locally")
# def test_network_factory():
# factory = NetworkFactory()

# algorithm_configurations = {}
# for name, cls in inspect.getmembers(configurations, inspect.isclass):
# if issubclass(cls, AlgorithmConfig) and cls != AlgorithmConfig:
# name = name.replace("Config", "")
# algorithm_configurations[name] = cls

# for algorithm, config in algorithm_configurations.items():
# config = config()
# observation_size = (9, 84, 84) if config.image_observation else 10
# action_num = 2
# network = factory.create_network(
# observation_size=observation_size, action_num=action_num, config=config
# )
# assert network is not None, f"{algorithm} was not created successfully"


def _policy_buffer(
memory_buffer,
Expand Down

0 comments on commit 6c8c072

Please sign in to comment.