Add divergence to logs
Divergence is the difference between the validation and training loss
moving averages, logged as a single value so the gap between the two
is easy to read in the logs.
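For example, a validation moving average of 0.34 against a training
moving average of 0.28 logs a divergence of +0.06; a value that keeps
growing is a common early sign of overfitting.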
rockerBOO committed Jan 12, 2025
1 parent 264167f commit 4c61adc
Showing 1 changed file with 10 additions and 3 deletions.
train_network.py
@@ -1418,14 +1418,16 @@ def remove_model(old_ckpt_name):
 
                 if is_tracking:
                     logs = {
-                        "loss/validation/step/current": current_loss,
+                        "loss/validation/step_current": current_loss,
+                        "val_step": (epoch * validation_steps) + val_step,
                     }
                     accelerator.log(logs, step=global_step)
 
             if is_tracking:
                 loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average
                 logs = {
-                    "loss/validation/step/average": val_step_loss_recorder.moving_average,
+                    "loss/validation/step_average": val_step_loss_recorder.moving_average,
+                    "loss/validation/step_divergence": loss_validation_divergence,
                 }
                 accelerator.log(logs, step=global_step)
 
@@ -1485,7 +1487,12 @@ def remove_model(old_ckpt_name):
 
             if is_tracking:
                 avr_loss: float = val_epoch_loss_recorder.moving_average
-                logs = {"loss/validation/epoch_average": avr_loss, "epoch": epoch + 1}
+                loss_validation_divergence = val_step_loss_recorder.moving_average - avr_loss
+                logs = {
+                    "loss/validation/epoch_average": avr_loss,
+                    "loss/validation/epoch_divergence": loss_validation_divergence,
+                    "epoch": epoch + 1
+                }
                 accelerator.log(logs, step=global_step)
 
         # END OF EPOCH
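
As a rough sketch of what the new metrics capture, the Python below
recomputes the step divergence with a toy moving-average recorder.
SimpleLossRecorder and the sample losses are illustrative stand-ins,
not the repository's actual LossRecorder API. The val_step key added
above is simply a global validation-step index; for example, with
validation_steps = 4, epoch 2 and val_step 1 yield (2 * 4) + 1 = 9.

    # Toy stand-in for the training script's loss recorder (hypothetical names).
    class SimpleLossRecorder:
        def __init__(self):
            self.losses: list[float] = []

        def add(self, loss: float) -> None:
            self.losses.append(loss)

        @property
        def moving_average(self) -> float:
            return sum(self.losses) / len(self.losses)

    train_recorder = SimpleLossRecorder()
    val_recorder = SimpleLossRecorder()

    # Feed in a few (training, validation) loss pairs.
    for train_loss, val_loss in [(0.30, 0.35), (0.25, 0.33), (0.22, 0.34)]:
        train_recorder.add(train_loss)
        val_recorder.add(val_loss)

    # Same arithmetic as the commit: validation average minus training average.
    # A positive, growing value means validation loss is pulling away from
    # training loss, which is a common early sign of overfitting.
    divergence = val_recorder.moving_average - train_recorder.moving_average
    print(f"loss/validation/step_divergence: {divergence:.4f}")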
