Skip to content

Commit

Permalink
use custom loss wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
jmduarte committed Dec 19, 2023
1 parent 5468221 commit 0773390
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from cyclical_learning_rate import CyclicLR
from models import *
from utils import *
from loss import custom_loss
from loss import custom_loss_wrapper
from DataGenerator import DataGenerator

import matplotlib.pyplot as plt
Expand Down Expand Up @@ -123,6 +123,7 @@ def train_dataGenerator(args):
n_features_pf = 6
n_features_pf_cat = 2
normFac = 1.
custom_loss = custom_loss_wrapper(normFac)
epochs = args.epochs
batch_size = args.batch_size
preprocessed = True
Expand Down Expand Up @@ -254,6 +255,7 @@ def train_loadAllData(args):
n_features_pf = 6
n_features_pf_cat = 2
normFac = 1.
custom_loss = custom_loss_wrapper(normFac)
epochs = args.epochs
batch_size = args.batch_size
preprocessed = True
Expand Down

0 comments on commit 0773390

Please sign in to comment.