How TF Keras Saves a Model under a Distribution Strategy
If users want to use a distribution strategy to train a Keras model, they need to enter the strategy scope to build and compile the model, like this:
with strategy.scope():
    inputs = tf.keras.layers.Input(shape=(4,))
    dense = tf.keras.layers.Dense(4)(inputs)
    output = tf.keras.layers.Dense(1)(dense)
    model = tf.keras.models.Model(inputs=inputs, outputs=output)
    model.compile(optimizer='rmsprop',
                  loss='binary_crossentropy',
                  metrics=['accuracy'])
The strategy.scope() call returns a _CurrentDistributionContext. When we enter this context, we also enter the variable_scope, the variable_creator_scope and the device_scope:
def __enter__(self):
    # Allow this scope to be entered if this strategy is already in scope.
    if distribution_strategy_context.has_strategy():
        _require_cross_replica_or_default_context_extended(
            self._context.strategy.extended)
        self._same_scope_again_count += 1
    else:
        _push_per_thread_mode(self._context)
        if self._var_scope:
            self._var_scope.__enter__()
        self._var_creator_scope.__enter__()
        if self._device_scope:
            self._device_scope.__enter__()
    return self._context.strategy
- variable_scope: A context manager for defining ops that creates variables (layers).
- variable_creator_scope: Scope which defines a variable creation function to be used by variable().
- device_scope: Context-manager to force placement of operations and Tensors on a device.
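To make the variable_creator_scope concrete, here is a minimal sketch (logging_creator is a made-up name, not part of the TensorFlow source): the creator function intercepts every variable created inside the scope, which is the same hook a distribution strategy uses to control where and how variables are created.
import tensorflow as tf

def logging_creator(next_creator, **kwargs):
    # Intercept every variable created inside the scope, then delegate to the
    # default creator. A distribution strategy uses this hook to wrap or place
    # its variables.
    print("creating variable:", kwargs.get("name"))
    return next_creator(**kwargs)

with tf.variable_creator_scope(logging_creator):
    v = tf.Variable(1.0, name="demo_var")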
Now, we focus on the device_scope, which determines the variable placement. After entering the device_scope, Keras gets the local device information of the current process and sets it on the _EagerDeviceContext. We can view the device name from the context with the following code snippet:
from tensorflow.python.eager import context as _context

with strategy.scope():
    _ctx = _context._context
    print(_ctx)
    print(_ctx._thread_local_data.device_name)
The _ctx is the distributed execution context, and printing it displays all available devices, including CPU and GPU. _ctx._thread_local_data.device_name is the name of the device that runs the current thread.
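As a side note (a minimal sketch, not part of the original walkthrough), the same thread-local device name is what an explicit tf.device scope sets, assuming eager execution; this is the field the variable creation ops read later.
import tensorflow as tf
from tensorflow.python.eager import context as _context

ctx = _context.context()
with tf.device('/CPU:0'):
    # Inside the device scope the eager context holds the canonical form of
    # the requested device.
    print(ctx._thread_local_data.device_name)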
After setting the execution context, Keras executes ops on the device whose name is _ctx._thread_local_data.device_name. Before calling each layer of a Keras model, Keras needs to execute some TensorFlow ops in add_weight and initializer to create the ResourceVariable variables of each layer, so those variables are placed on the device recorded in the execution context. As is well known, TensorFlow ops are executed in the C++ runtime. The Python code that executes the ResourceVariable creation ops is generated from the C++ source file "resource_variable_ops.cc" into "tensorflow_core/python/ops/gen_resource_variable_ops.py". The op execution code is as follows:
_ctx = _context._context or _context.context()
if _ctx is not None and _ctx._thread_local_data.is_eager:
    try:
        _result = _pywrap_tensorflow.TFE_Py_FastPathExecute(
            _ctx._context_handle, _ctx._thread_local_data.device_name,
            "VarHandleOp", name, _ctx._post_execution_callbacks, "container",
            container, "shared_name", shared_name, "dtype", dtype, "shape", shape)
        return _result
    except _core._FallbackException:
        try:
            return var_handle_op_eager_fallback(
                container=container, shared_name=shared_name, dtype=dtype,
                shape=shape, name=name, ctx=_ctx)
        except _core._SymbolicException:
            pass  # Add nodes to the TensorFlow graph.
    except _core._NotOkStatusException as e:
        if name is not None:
            message = e.message + " name: " + name
        else:
            message = e.message
        _six.raise_from(_core._status_to_exception(e.code, message), None)
We can see that the runtime first gets the context constructed by strategy.scope() and then executes the ops on the device _ctx._thread_local_data.device_name.
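To confirm this behavior with a minimal sketch (assuming eager execution; the device string is only an example), building a Keras layer under an explicit device scope places the ResourceVariables created by add_weight on that device, just like the device set by the strategy scope:
import tensorflow as tf

with tf.device('/CPU:0'):
    # Dense.build() calls add_weight, which runs the VarHandleOp on the
    # device recorded in the current execution context.
    layer = tf.keras.layers.Dense(2)
    layer.build((None, 4))

print(layer.kernel.device)  # .../device:CPU:0
print(layer.bias.device)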
The CentralStorageStrategy uses ParameterServerStrategyExtended, which extends StrategyExtendedV2. The default device is None in StrategyExtendedV2, and ParameterServerStrategyExtended does not set a default device for the strategy. So, automatic placement is performed for CentralStorageStrategy. We can verify this conclusion with the following snippet.
import tensorflow as tf
from tensorflow.python.eager import context as _context

_ctx = _context._context or _context.context()
print("The context before entering strategy scope is:")
print(_ctx)
print("current device name: ", _ctx._thread_local_data.device_name)

strategy = tf.distribute.experimental.CentralStorageStrategy()
with strategy.scope():
    _ctx = _context._context or _context.context()
    print("The context in the strategy scope is:")
    print(_ctx)
    print("current device name: ", _ctx._thread_local_data.device_name)
MultiWorkerMirroredStrategy uses CollectiveAllReduceExtended, which sets _default_device for the strategy. CollectiveAllReduceExtended initializes the strategy according to the cluster_spec generated from TF_CONFIG and sets the current device for the strategy. So, with MultiWorkerMirroredStrategy, the variables are initialized on each worker's own device. To verify this conclusion, we can execute the following snippet:
import os
import json
import tensorflow as tf
from tensorflow.python.eager import context as _context

os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'chief': ["localhost:12347"],
        'worker': ["localhost:12348"]
    },
    'task': {'type': 'worker', 'index': 0}
})

strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
with strategy.scope():
    _ctx = _context._context or _context.context()
    print(_ctx)
    print("current device name: ", _ctx._thread_local_data.device_name)
Keras trains the model with model.fit(), which calls model_iteration in "training_arrays.py". The execution steps in model_iteration are:
- Make an execution function fn (e.g. model.train_on_batch) per replica.
- Run fn once per replica for each mini-batch of samples. tf.distribute.Strategy.experimental_run_v2 creates a _MirroredReplicaThread for each replica/worker device to run fn under that device's scope (see the sketch after the code below).
- Collect all outputs from each replica via the strategy.
def _make_execution_function_without_cloning(model, mode):
    """Creates a function to run one step of distributed model execution."""
    strategy = model._distribution_strategy

    with strategy.scope():
        per_replica_function = _make_replica_execution_function(model, mode)

        def distributed_function(input_fn):
            """A single step of the distributed execution across replicas."""
            x, y, sample_weights = input_fn()
            # Call `Model.{train,test,predict}_on_batch` on every replica passing
            # PerReplicas as arguments. On every replica inside this call, each
            # PerReplica object will return the value for that replica. The outputs
            # are PerReplicas too.
            outputs = strategy.experimental_run_v2(
                per_replica_function, args=(x, y, sample_weights))
            # Out of PerReplica outputs reduce or pick values to return.
            all_outputs = unwrap_outputs(
                strategy, outputs, with_loss_tensor=(mode != ModeKeys.PREDICT))
            return all_outputs

        ...
        execution_function = distributed_function

    return execution_function
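To illustrate the experimental_run_v2 step above, here is a minimal sketch (step_fn is a made-up function; we assume a TF version that still exposes experimental_run_v2, which was later renamed to Strategy.run). The function runs once per replica, and the outputs come back as per-replica values that the strategy can unwrap or reduce:
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def step_fn(x):
    # Work done on each replica; with several devices this runs once per replica.
    return x * 2.0

with strategy.scope():
    outputs = strategy.experimental_run_v2(step_fn, args=(tf.constant(1.0),))

# With a single device this is a plain tensor; with several devices it is a
# PerReplica container holding one value per replica.
print(outputs)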
The steps in tf.saved_model.save are:
- Construct an _AugmentedGraphView based on the Keras model instance.
- Generate signatures for the model instance.
- Gather all tensors in the graph view and group the tensors by device name.
- Save all tensors with io_ops.save.
- Save the graph view to the saved_model.pb file.
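From the user's point of view, these steps are triggered by a single call; here is a minimal sketch (the export directory is a made-up path) of exporting a model built under a strategy scope:
import tensorflow as tf

strategy = tf.distribute.experimental.CentralStorageStrategy()
with strategy.scope():
    inputs = tf.keras.layers.Input(shape=(4,))
    outputs = tf.keras.layers.Dense(1)(inputs)
    model = tf.keras.models.Model(inputs=inputs, outputs=outputs)

# Writes the sharded variables files and saved_model.pb into the export directory.
tf.saved_model.save(model, '/tmp/keras_saved_model')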
In this section, we focus on how Keras saves variables to files under the supported distribution strategies. After Keras has placed all variables on devices, each variable carries the information of the device it was placed on.
import tensorflow as tf

with tf.device(''):
    v = tf.Variable(tf.zeros([10, 10]))
    print(v.device)
tf.saved_model.save collects all variables in the current process and groups them into different shards by device name. Then Keras runs io_ops.save() to save each shard to a shard file under the corresponding device scope. In the following snippet, we simulate how Keras saves multiple variable groups, where each group represents the variables on one device.
import os
import uuid
import tensorflow as tf
from tensorflow.python import ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import gen_io_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.lib.io import file_io


def save_tensor(tensors, export_dir):
    filename_tensor = export_dir
    tensor_names = []
    tensor_slices = []
    for tensor in tensors:
        tensor_names.append(tensor.name)
        tensor_slices.append('')
    io_ops.save_v2(filename_tensor, tensor_names, tensor_slices, tensors)


def save_checkpoint_with_shards(shard_tensor, checkpoint_prefix):
    file_io.recursive_create_dir(os.path.dirname(checkpoint_prefix))
    _SHARDED_SUFFIX = "_temp_%s/part" % uuid.uuid4().hex
    tmp_checkpoint_prefix = string_ops.string_join(
        [checkpoint_prefix, _SHARDED_SUFFIX])
    sharded_prefixes = []
    num_shards = len(shard_tensor)
    for shard_id, tensors in shard_tensor:
        sharded_filename = gen_io_ops.sharded_filename(
            tmp_checkpoint_prefix, shard_id, num_shards)
        save_tensor(tensors, sharded_filename)
        sharded_prefixes.append(sharded_filename)
    gen_io_ops.merge_v2_checkpoints(
        sharded_prefixes, checkpoint_prefix, delete_old_dirs=True)


shard_tensor = [
    (0, [tf.Variable([1, 2, 3], name='t1-0'), tf.Variable([1, 2, 3], name='t1-1')]),
    (1, [tf.Variable([4, 5, 6], name='t2')]),
    (2, [tf.Variable([7, 8, 9], name='t3')]),
]
save_checkpoint_with_shards(shard_tensor, checkpoint_prefix='ckpt-0/variable')

restore_variable = io_ops.restore_v2(
    'ckpt-0/variable', ['t1-0:0'], [''], [tf.int32])[0]
print('Restore tensor is:', restore_variable)
The saved file contents are:
|-- ckpt-0
    |-- variable.data-00000-of-00003
    |-- variable.data-00001-of-00003
    |-- variable.data-00002-of-00003
    |-- variable.index
In "variable.data-00000-of-00003", "00000" is the shard index (the first shard) and "00003" is the total number of shards.