# Copyright 2023 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Default Hyperparameter configuration."""

import ml_collections


def get_config():
    """Get the default hyperparameter configuration.

    Returns:
        An `ml_collections.ConfigDict` with the default model, dataset, and
        training hyperparameters for the ImageNet/ResNet example.
    """
    config = ml_collections.ConfigDict()

    # As defined in the `models` module.
    config.model = 'ResNet50'
    # `name` argument of tensorflow_datasets.builder()
    config.dataset = 'imagenet2012:5.*.*'

    config.learning_rate = 0.1
    config.warmup_epochs = 5.0
    config.momentum = 0.9
    config.batch_size = 128
    config.shuffle_buffer_size = 16 * 128
    config.prefetch = 10

    config.num_epochs = 100.0
    config.log_every_steps = 100

    config.cache = False
    config.half_precision = False

    # If num_train_steps==-1 then the number of training steps is calculated
    # from num_epochs using the entire dataset. Similarly for steps_per_eval.
    config.num_train_steps = -1
    config.steps_per_eval = -1
    return config


def metrics():
    """Return the names of the metrics reported during training and eval."""
    return [
        'train_loss',
        'eval_loss',
        'train_accuracy',
        'eval_accuracy',
        'steps_per_second',
        'train_learning_rate',
    ]