Skip to content

Commit

Permalink
Save model at eval loops not training steps
Browse files Browse the repository at this point in the history
  • Loading branch information
beardyFace committed Nov 10, 2024
1 parent 30eb08f commit 21147eb
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 49 deletions.
2 changes: 0 additions & 2 deletions cares_reinforcement_learning/util/configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,12 @@ class TrainingConfig(SubscriptableClass):
Attributes:
seeds (list[int]): list of random seeds for reproducibility. Default is [10].
plot_frequency (int): Frequency at which to plot training progress. Default is 100.
checkpoint_frequency (int): Frequency at which to save model checkpoints. Default is 100.
number_steps_per_evaluation (int): Number of steps per evaluation. Default is 10000.
number_eval_episodes (int): Number of episodes to evaluate during training. Default is 10.
"""

seeds: list[int] = [10]
plot_frequency: int = 100
checkpoint_frequency: int = 100
number_steps_per_evaluation: int = 10000
number_eval_episodes: int = 10

Expand Down
59 changes: 12 additions & 47 deletions cares_reinforcement_learning/util/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ def __init__(
algorithm: str,
task: str,
plot_frequency: int = 10,
checkpoint_frequency: int | None = None,
network: nn.Module | None = None,
) -> None:

Expand All @@ -44,12 +43,6 @@ def __init__(
self.task = task

self.plot_frequency = plot_frequency
self.checkpoint_frequency = checkpoint_frequency

if self.checkpoint_frequency is None:
logging.warning(
"checkpoint_frequency not provided. Model will not be auto-saved and saving should be managed externally with save_model."
)

self.train_data_path = f"{self.directory}/data/train.csv"
self.train_data = (
Expand All @@ -63,18 +56,8 @@ def __init__(
if os.path.exists(self.eval_data_path)
else pd.DataFrame()
)
self.info_data_path = f"{self.directory}/data/info.csv"
self.info_data = (
pd.read_csv(self.info_data_path)
if os.path.exists(self.info_data_path)
else pd.DataFrame()
)

if (
not self.train_data.empty
or not self.eval_data.empty
or not self.info_data.empty
):
if not self.train_data.empty or not self.eval_data.empty:
logging.warning("Data files not empty. Appending to existing data")

self.network = network
Expand Down Expand Up @@ -102,18 +85,13 @@ def start_video(self, file_name, frame, fps=30):
def stop_video(self) -> None:
self.video.release()

def save_model(self, identifier):
self.network.save_models(f"{self.algorithm}-{identifier}", self.directory)
def save_model(self, file_name):
if self.network is not None:
self.network.save_models(f"{file_name}", self.directory)

def log_video(self, frame: np.ndarray) -> None:
self.video.write(frame)

def log_info(self, info: dict, display: bool = False) -> None:
self.info_data = pd.concat(
[self.info_data, pd.DataFrame([info])], ignore_index=True
)
self.save_data(self.info_data, self.info_data_path, info, display=display)

def log_train(self, display: bool = False, **logs) -> None:
self.log_count += 1

Expand All @@ -132,29 +110,15 @@ def log_train(self, display: bool = False, **logs) -> None:
20,
)

is_at_checkpoint = (self.checkpoint_frequency is not None) and (
self.log_count % self.checkpoint_frequency == 0
)

reward = logs["episode_reward"]

is_new_best_reward = reward > self.best_reward

if is_new_best_reward:
if reward > self.best_reward:
logging.info(
f"New highest reward of {reward} during training! Saving model..."
)
self.best_reward = reward

if self.network is not None:
if is_at_checkpoint:
self.network.save_models(
f"{self.algorithm}-checkpoint-{self.log_count}", self.directory
)
if is_new_best_reward:
logging.info(
f"New highest reward of {reward} during training! Saving models..."
)
self.network.save_models(
f"{self.algorithm}-highest-reward-training", self.directory
)
self.save_model(f"{self.algorithm}-highest-reward-training")

def log_eval(self, display: bool = False, **logs) -> None:
self.eval_data = pd.concat(
Expand All @@ -170,6 +134,8 @@ def log_eval(self, display: bool = False, **logs) -> None:
"eval",
)

self.save_model(f"{self.algorithm}-eval-{logs['total_steps']}")

def save_data(
self, data_frame: pd.DataFrame, path: str, logs: dict, display: bool = True
) -> None:
Expand Down Expand Up @@ -212,8 +178,7 @@ def save(self) -> None:
20,
)

if self.network is not None:
self.network.save_models(self.algorithm, self.directory)
self.save_model(self.algorithm)

def __initialise_directories(self) -> None:

Expand Down

0 comments on commit 21147eb

Please sign in to comment.