diff --git a/tests/test_ae.py b/tests/test_ae.py index c391ff95..c45fb7c4 100644 --- a/tests/test_ae.py +++ b/tests/test_ae.py @@ -1,5 +1,4 @@ import inspect -import os import pytest import torch @@ -8,32 +7,7 @@ from cares_reinforcement_learning.encoders.autoencoder_factory import AEFactory from cares_reinforcement_learning.encoders.configurations import AEConfig, BurgessConfig -IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" - -@pytest.mark.skipif(not IN_GITHUB_ACTIONS, reason="Running more complex test locally") -def test_ae_factory(): - factory = AEFactory() - - ae_configurations = {} - for name, cls in inspect.getmembers(configurations, inspect.isclass): - if issubclass(cls, AEConfig) and cls != AEConfig and cls != BurgessConfig: - name = name.replace("Config", "") - ae_configurations[name] = cls - - for ae, config in ae_configurations.items(): - - observation_size = (9, 84, 84) - - config = config(latent_dim=100) - - ae = factory.create_autoencoder( - observation_size=observation_size, config=config - ) - assert ae is not None, f"{ae} was not created successfully" - - -@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Test doesn't work in Github Actions.") def test_ae(): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py index 3dbe7220..dfa159b7 100644 --- a/tests/test_algorithms.py +++ b/tests/test_algorithms.py @@ -1,5 +1,4 @@ import inspect -import os from random import randrange import numpy as np @@ -9,28 +8,6 @@ from cares_reinforcement_learning.util import NetworkFactory, configurations from cares_reinforcement_learning.util.configurations import AlgorithmConfig -# IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" - - -# @pytest.mark.skipif(not IN_GITHUB_ACTIONS, reason="Running more complex test locally") -# def test_network_factory(): -# factory = NetworkFactory() - -# algorithm_configurations = {} -# for name, cls in inspect.getmembers(configurations, inspect.isclass): -# if issubclass(cls, AlgorithmConfig) and cls != AlgorithmConfig: -# name = name.replace("Config", "") -# algorithm_configurations[name] = cls - -# for algorithm, config in algorithm_configurations.items(): -# config = config() -# observation_size = (9, 84, 84) if config.image_observation else 10 -# action_num = 2 -# network = factory.create_network( -# observation_size=observation_size, action_num=action_num, config=config -# ) -# assert network is not None, f"{algorithm} was not created successfully" - def _policy_buffer( memory_buffer,