From 0fdc1476277109c31a5d3fac5e1798d0377921f8 Mon Sep 17 00:00:00 2001 From: wangermeng2021 Date: Tue, 29 Jun 2021 17:00:37 +0800 Subject: [PATCH] Fix bugs on saving model. --- README.md | 6 ++++++ model/PVT.py | 8 ++++++++ train.py | 10 +++++----- 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 1b3943e..6c7ccf1 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,14 @@ # PVT-tensorflow2 +[![Python 3.7](https://img.shields.io/badge/Python-3.7-3776AB)](https://www.python.org/downloads/release/python-370/) +[![TensorFlow 2.4](https://img.shields.io/badge/TensorFlow-2.4-FF6F00?logo=tensorflow)](https://github.com/tensorflow/tensorflow/releases/tag/v2.4.0) + A Tensorflow2.x implementation of Pyramid Vision Transformer as described in [Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions](https://arxiv.org/abs/2102.12122) ## Update Log +[2021-06-29] +* Fix bug on saving model + [2021-03-20] * Add PVT-tiny,PVT-small,PVT-medium,PVT-large. 
diff --git a/model/PVT.py b/model/PVT.py index 5fbb749..849ccf3 100644 --- a/model/PVT.py +++ b/model/PVT.py @@ -101,6 +101,14 @@ def build(self, input_shape): def call(self, x): return x+self.pos_embed + def get_config(self): + + config = super().get_config().copy() + config.update({ + 'img_len': self.img_len, + }) + return config + def get_pvt(img_size,num_classes,block_depth,mlp_ratio,drop_path_rate,first_level_patch_size,embed_dims,num_heads,sr_ratio,attention_drop_rate,drop_rate): block_drop_path_rate = np.linspace(0, drop_path_rate, sum(block_depth)) block_depth_index = 0 diff --git a/train.py b/train.py index 48aa21c..bd4e7f9 100644 --- a/train.py +++ b/train.py @@ -49,18 +49,18 @@ def main(args): os.makedirs(args.checkpoints) # lr_cb = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=20, verbose=1, min_lr=0) lr_cb = tf.keras.callbacks.LearningRateScheduler(lr_scheduler) - model_checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(filepath=args.checkpoints+'/best_weight_{epoch}_{accuracy:.3f}_{val_accuracy:.3f}', - monitor='val_accuracy',mode='max', - verbose=1,save_best_only=True) + model_checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(filepath=args.checkpoints+'/best_weight_{epoch}_{accuracy:.3f}_{val_accuracy:.3f}.h5', + monitor='val_accuracy',mode='max', + verbose=1,save_best_only=True,save_weights_only=True) cbs=[lr_cb, - # model_checkpoint_cb + model_checkpoint_cb ] model.compile(optimizer,loss_object,metrics=["accuracy"],) model.fit(train_generator, validation_data=val_generator, epochs=args.epochs, callbacks=cbs, - verbose=2, + verbose=1, ) if __name__== "__main__":