Skip to content

Commit

Permalink
edits on default kwargs to make train more simple
Browse files Browse the repository at this point in the history
  • Loading branch information
grantbuster committed Feb 15, 2023
1 parent 151be8b commit 4297a7a
Showing 1 changed file with 4 additions and 7 deletions.
11 changes: 4 additions & 7 deletions phygnn/model_interfaces/tf_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,8 +385,10 @@ def train_model(self, features, labels, epochs=100, shuffle=True,
fit_kwargs : dict | None
kwargs for tensorflow.keras.models.fit
"""
if parse_kwargs is None:
parse_kwargs = {}

parse_kwargs = parse_kwargs or {}
fit_kwargs = fit_kwargs or {}
stop_kwargs = stop_kwargs or {'monitor': 'val_loss', 'patience': 10}

if (isinstance(features, np.ndarray)
and features.shape[-1] == self.feature_dims):
Expand All @@ -405,12 +407,7 @@ def train_model(self, features, labels, epochs=100, shuffle=True,
logger.warning(msg)
warn(msg)

if fit_kwargs is None:
fit_kwargs = {}

if early_stop:
if stop_kwargs is None:
stop_kwargs = {'monitor': 'val_loss', 'patience': 10}
early_stop = tf.keras.callbacks.EarlyStopping(**stop_kwargs)
callbacks = fit_kwargs.pop('callbacks', None)
if callbacks is None:
Expand Down

0 comments on commit 4297a7a

Please sign in to comment.