diff --git a/cares_reinforcement_learning/util/helpers.py b/cares_reinforcement_learning/util/helpers.py
index e3f3d3d2..c6c1a9e0 100644
--- a/cares_reinforcement_learning/util/helpers.py
+++ b/cares_reinforcement_learning/util/helpers.py
@@ -1,8 +1,16 @@
 import torch
+import numpy as np
+import random
 import pandas as pd
 import matplotlib.pyplot as plt
 
+def set_seed(seed):
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+
 def plot_reward_curve(data_reward):
     data = pd.DataFrame.from_dict(data_reward)
 
 
diff --git a/example/example_training_loops.py b/example/example_training_loops.py
index c8ac76f4..c61b8d30 100644
--- a/example/example_training_loops.py
+++ b/example/example_training_loops.py
@@ -8,6 +8,7 @@
 from cares_reinforcement_learning.util import Record
 from cares_reinforcement_learning.util import EnvironmentFactory
 from cares_reinforcement_learning.util import arguement_parser as ap
+from cares_reinforcement_learning.util import helpers as hlp
 
 import cares_reinforcement_learning.train_loops.policy_loop as pbe
 import cares_reinforcement_learning.train_loops.value_loop as vbe
@@ -22,11 +23,6 @@
 from pathlib import Path
 from datetime import datetime
 
-def set_seed(seed):
-    torch.manual_seed(seed)
-    np.random.seed(seed)
-    random.seed(seed)
-
 def main():
     parser = ap.create_parser()
     args = vars(parser.parse_args())  # converts to a dictionary
@@ -53,7 +49,7 @@
     training_iterations = args['number_training_iterations']
     for training_iteration in range(0, training_iterations):
         logging.info(f"Training iteration {training_iteration+1}/{training_iterations} with Seed: {args['seed']}")
-        set_seed(args['seed'])
+        hlp.set_seed(args['seed'])
         env.set_seed(args['seed'])
 
         logging.info(f"Algorithm: {args['algorithm']}")
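For context, here is a minimal, self-contained sketch of what the relocated helper does. The function body mirrors the `set_seed` added to `helpers.py` in this diff; the demo seed value and the assertion at the end are illustrative, not part of the PR.

```python
import random

import numpy as np
import torch


def set_seed(seed):
    # Mirrors the helper added in this PR: seed PyTorch's CPU RNG,
    # every visible CUDA device, NumPy, and Python's built-in `random`.
    # torch.cuda.manual_seed_all is a safe no-op when no GPU is present.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


# Illustrative check (assumed seed value): re-seeding reproduces the
# same draws from all three libraries.
set_seed(42)
a = (torch.rand(3), np.random.rand(), random.random())

set_seed(42)
b = (torch.rand(3), np.random.rand(), random.random())

assert torch.equal(a[0], b[0]) and a[1:] == b[1:]
```

Beyond deduplicating the example script, the moved helper also adds `torch.cuda.manual_seed_all(seed)`, which the old inline version lacked, so CUDA RNG state (including multi-GPU setups) is now seeded alongside the CPU generators.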