Skip to content

Commit

Permalink
add default.py
Browse files Browse the repository at this point in the history
  • Loading branch information
fwrrong authored Nov 11, 2023
1 parent af48d95 commit 296fac6
Showing 1 changed file with 70 additions and 0 deletions.
70 changes: 70 additions & 0 deletions examples/JAX/default.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright 2023 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright 2021 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Default Hyperparameter configuration."""

import ml_collections


def get_config():
"""Get the default hyperparameter configuration."""
config = ml_collections.ConfigDict()

# As defined in the `models` module.
config.model = 'ResNet50'
# `name` argument of tensorflow_datasets.builder()
config.dataset = 'imagenet2012:5.*.*'

config.learning_rate = 0.1
config.warmup_epochs = 5.0
config.momentum = 0.9
config.batch_size = 128
config.shuffle_buffer_size = 16 * 128
config.prefetch = 10

config.num_epochs = 100.0
config.log_every_steps = 100

config.cache = False
config.half_precision = False

# If num_train_steps==-1 then the number of training steps is calculated from
# num_epochs using the entire dataset. Similarly for steps_per_eval.
config.num_train_steps = -1
config.steps_per_eval = -1
return config


def metrics():
return [
'train_loss',
'eval_loss',
'train_accuracy',
'eval_accuracy',
'steps_per_second',
'train_learning_rate',
]

0 comments on commit 296fac6

Please sign in to comment.