From 4297a7a5d982077d349e94faf7ceb0852fb2dc66 Mon Sep 17 00:00:00 2001 From: grantbuster Date: Wed, 15 Feb 2023 14:49:26 -0700 Subject: [PATCH] edits on default kwargs to make train more simple --- phygnn/model_interfaces/tf_model.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/phygnn/model_interfaces/tf_model.py b/phygnn/model_interfaces/tf_model.py index 50dea86..b74897e 100644 --- a/phygnn/model_interfaces/tf_model.py +++ b/phygnn/model_interfaces/tf_model.py @@ -385,8 +385,10 @@ def train_model(self, features, labels, epochs=100, shuffle=True, fit_kwargs : dict | None kwargs for tensorflow.keras.models.fit """ - if parse_kwargs is None: - parse_kwargs = {} + + parse_kwargs = parse_kwargs or {} + fit_kwargs = fit_kwargs or {} + stop_kwargs = stop_kwargs or {'monitor': 'val_loss', 'patience': 10} if (isinstance(features, np.ndarray) and features.shape[-1] == self.feature_dims): @@ -405,12 +407,7 @@ def train_model(self, features, labels, epochs=100, shuffle=True, logger.warning(msg) warn(msg) - if fit_kwargs is None: - fit_kwargs = {} - if early_stop: - if stop_kwargs is None: - stop_kwargs = {'monitor': 'val_loss', 'patience': 10} early_stop = tf.keras.callbacks.EarlyStopping(**stop_kwargs) callbacks = fit_kwargs.pop('callbacks', None) if callbacks is None: