Getting started with TensorFlow Large Model Support (TFLMS)
TensorFlow Large Model Support (TFLMS) provides an approach to training large models that do not fit into GPU memory. It takes a computational graph that is defined by the user and automatically adds swap-in and swap-out nodes for transferring tensors from the GPU to the host and vice versa. The computational graph is modified statically, so the modification must be done before a session starts.
How to use TFLMS
TFLMS needs to know some information about the user-defined model. One requirement is that the model must have scopes for its optimizers or solvers.
Enabling LMS for a model depends on how you write your training. The following guidelines cover three ways to train:
Session-based training
To train a model that uses session-based training, follow these steps:
- Define optimizer or solver scopes:

    with tf.name_scope('adam_optimizer'):
        train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
- Define an LMS object and run it:

    from tensorflow.contrib.lms import LMS
    lms_obj = LMS({'adam_optimizer'})
    lms_obj.run(graph=tf.get_default_graph())
You must add these lines before you start a training session, for example:
- Before you insert the LMS code:

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        batch = mnist.train.next_batch(50)
        train_step.run(feed_dict={x: batch[0], y_: batch[1]})
- After you insert the LMS code:

    from tensorflow.contrib.lms import LMS
    lms_obj = LMS({'adam_optimizer'})
    lms_obj.run(graph=tf.get_default_graph())

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        batch = mnist.train.next_batch(50)
        train_step.run(feed_dict={x: batch[0], y_: batch[1]})
For a working example of LMS integration with Session-based training, see /opt/DL/tensorflow/lib/python*/site-packages/tensorflow/contrib/lms/examples/mnist_deep_lms.py, which is an LMS-enabled version of /opt/DL/tensorflow/lib/python*/site-packages/tensorflow/examples/tutorials/mnist/mnist_deep.py.

Estimator-based training
To train a model that uses estimator-based training, follow these steps:
- Define optimizer or solver scopes:

    with tf.name_scope('adam_optimizer'):
        optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
        train_op = optimizer.minimize(
            loss=loss,
            global_step=tf.train.get_global_step())
- Define the LMSHook:

    # Hook for Large Model Support
    from tensorflow.contrib.lms import LMSHook
    # LMSHook and LMS share the same set of parameters. Here we just
    # use the default keyword arguments.
    lms_hook = LMSHook({'adam_optimizer'})
- Add the LMSHook to the Estimator hook list:

    mnist_classifier.train(
        input_fn=train_input_fn,
        steps=20000,
        hooks=[logging_hook, lms_hook])
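Putting these steps together, the following is a minimal sketch of Estimator-based training with LMS. The toy model_fn and the train_input_fn placeholder are illustrative only, not part of the LMS API; the essential details are the named optimizer scope and passing the LMSHook in the hooks list.

    import tensorflow as tf
    from tensorflow.contrib.lms import LMSHook

    def model_fn(features, labels, mode):
        # Toy model that handles only training; the important detail for LMS
        # is that the optimizer is created inside a named scope.
        logits = tf.layers.dense(features["x"], 10)
        loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
        with tf.name_scope('adam_optimizer'):
            optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
            train_op = optimizer.minimize(
                loss=loss, global_step=tf.train.get_global_step())
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

    # The hook is given the optimizer scope name and passed with the other hooks.
    lms_hook = LMSHook({'adam_optimizer'})
    classifier = tf.estimator.Estimator(model_fn=model_fn)
    classifier.train(input_fn=train_input_fn,  # train_input_fn is assumed to exist
                     steps=20000,
                     hooks=[lms_hook])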
For a working example of LMS integration with Estimator-based training, see /opt/DL/tensorflow/lib/python*/site-packages/tensorflow/contrib/lms/examples/cnn_mnist_lms.py, which is an LMS-enabled version of /opt/DL/tensorflow/lib/python*/site-packages/tensorflow/examples/tutorials/layers/cnn_mnist.py.

tf.keras-based training
To train a model that uses tf.keras-based training, follow these steps:
- Define the LMSKerasCallback:

    from tensorflow.contrib.lms import LMSKerasCallback
    # LMSKerasCallback and LMS share a set of keyword arguments. Here we just
    # use the default options.
    lms_callback = LMSKerasCallback()
- Pass the callback to the Keras fit or fit_generator function:

    model.fit_generator(generator=training_gen, callbacks=[lms_callback])
Required Parameters for LMS or LMSHook
- graph: The graph that is modified for LMS. This is the graph of the user-defined neural network. (Not required for LMSHook.)
- optimizer_scopes: Scopes for the optimizers or solvers.
Optional parameters for LMS or LMSHook
- starting_scope: Tensors that are reachable from the operations in this scope are swapped for LMS. Set this option to the scope of the first layer if you want to modify the whole graph. Default: None.
- starting_op_names: Tensors that are reachable from the operations with these names are swapped for LMS. Default: None.
- excl_scopes: A set of scopes for operations whose tensors are not swapped out to the host. Default: empty.
- incl_scopes: A set of scopes for operations whose tensors are swapped out to the host. Default: empty.
- excl_types: A set of types for operations whose tensors are not swapped out to the host. Default: empty.
- incl_types: A set of types for operations whose tensors are swapped out to the host. Default: empty.
- n_tensors: The number of tensors for LMS, counting from the starting_scope. To turn off LMS, set n_tensors to 0. Default: -1 (all reachable tensors are swapped for LMS).
- lb: Lower bound value for LMS. A tensor is swapped in during the backward phase at least lb nodes before it in the graph. Default: 1.
- ub: Upper bound value for LMS. Default: 10000.
- fuse_swapins: Fuse "close" swap-in operations into one operation. This can improve performance. Default: False.
- ctrld_strategy: Two strategies to find control dependency operations for swap-in operations: chain_rule and direct_order. The chain_rule strategy starts from a forward operation, goes forward, and finds a corresponding backward operation to be a control dependency operation. The direct_order strategy directly gets a backward operation in the topological order to be a control dependency operation. Both strategies depend on lb and ub to choose a control dependency operation. While direct_order is more exact than chain_rule with respect to lb and ub, it experimentally often results in a smaller maximum batch size than chain_rule. Default: chain_rule.
- swap_branches: If True, LMS swaps tensors in branches in the forward phase. Default: False.
- branch_threshold: If swap_branches is enabled and the topological-sort distance between the consuming operation and generating operation of a tensor is greater than branch_threshold, swap the tensor. Default: 0.
- debug: Debug mode for LMS. Default: False.
- debug_level: Debug level for LMS (1 or 2). Default: 1.
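As a sketch of how these keyword arguments might be combined, the following passes several of the optional parameters to LMS. The scope name and the specific values are illustrative only; LMSHook accepts the same keyword arguments.

    import tensorflow as tf
    from tensorflow.contrib.lms import LMS, LMSHook

    # Assumes the model graph with an 'adam_optimizer' scope has been built.
    lms_obj = LMS({'adam_optimizer'},
                  n_tensors=-1,         # swap all reachable tensors
                  lb=3,                 # swap a tensor back in at least 3 nodes before use
                  fuse_swapins=True,    # fuse "close" swap-in operations
                  swap_branches=True,
                  branch_threshold=40)
    lms_obj.run(graph=tf.get_default_graph())

    # LMSHook takes the same keyword arguments for Estimator-based training.
    lms_hook = LMSHook({'adam_optimizer'}, n_tensors=-1, lb=3)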
Scaling tips
TensorFlow sets a limit on the amount of memory that is allocated on the CUDA host (CPU) side.
The limit is often not high enough to act as a tensor swap space when you are swapping a large
amount of data or when you are using multiple GPUs in a multi-tower fashion with a tower for each
GPU as described in the TensorFlow documentation. Failure to set this limit higher results in out-of-memory errors such as: Allocator (cuda_host_bfc) ran out of memory trying to allocate. Note that the cuda_host_bfc allocator is mentioned rather than a GPU allocator.
A good practice is to start with a value that is four times the memory capacity of the GPUs times the number of GPUs that are used. For example, if you have four 16-GB GPUs in a system and use all four in a training run, TF_CUDA_HOST_MEM_LIMIT_IN_MB can be set to 262144 (4 x 16384 MB per GPU x 4 GPUs = 262144 MB) and adjusted from there as needed.
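As a minimal sketch of applying this, assuming TF_CUDA_HOST_MEM_LIMIT_IN_MB is read from the process environment when TensorFlow initializes, the limit could be raised from the training script before TensorFlow starts (or equivalently exported in the shell before launching the script):

    import os

    # Raise the CUDA host memory limit before TensorFlow initializes.
    gpu_mem_gb = 16                                # memory per GPU, in GB
    num_gpus = 4
    limit_mb = 4 * (gpu_mem_gb * 1024) * num_gpus  # 4 x 16384 MB x 4 GPUs = 262144 MB
    os.environ["TF_CUDA_HOST_MEM_LIMIT_IN_MB"] = str(limit_mb)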
Multi GPU scaling tips
To achieve better scaling performance with LMS on multiple GPUs, update the training script to use PowerAI Distributed Deep Learning and run the training script with the ddlrun command. Additionally, if you are running on a single system without InfiniBand set up, the --mpiarg -pami_noib parameter must be added to the ddlrun command line, for example:

    ddlrun --mpiarg -pami_noib -H host1 python train_model.py
For more information about using ddlrun, see /opt/DL/ddl/doc/README.md and /opt/DL/ddl-tensorflow/doc/README.md. You can also find information in this blog:
Improved Ease of Use for DDL in PowerAI
Performance tuning LMS
After you enable LMS graph modification in your code, find the combination of tuning parameters that gives the fastest training time and best accuracy with your model. The goal of the performance tuning is to swap out enough tensors to allow your training to run without causing out-of-memory errors, while not swapping so many that the extra swapping communication overhead degrades performance.
The two tuning parameters to focus on are n_tensors and lb.
Since n_tensors controls the number of tensors that are swapped, the higher it is set, the lower the peak GPU memory usage is. The lb parameter controls how soon the tensor is swapped back in before use. A low value of lb can make the training on the GPU pause and wait while the swap-in finishes, which degrades performance. A higher value of lb allows the tensor swap-in to finish before the tensor is needed and allows training to run without pause. The downside to swapping in too early is that more tensors are in GPU memory at any point in time, resulting in higher peak GPU memory usage.
The tuning thus becomes finding the correct balance between n_tensors and lb that provides the best performance for a given model. To start the performance tuning, consider setting n_tensors to -1, which swaps all reachable tensors. Set lb to the default 1, which is the latest possible swap-in. If tf.logging verbosity is set to tf.logging.INFO, LMS outputs a log statement with a count of the number of tensors swapped. It is useful to run with n_tensors=-1 for the first run to find this maximum value and then adjust it downward. If your model has branches, as some UNet models do, consider setting swap_branches=True and tuning the branch threshold as well. While you tune the parameters, it is often useful to surface the LMS parameters in the training script as command-line parameters or configuration file properties, as shown in the sketch below.
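For example, the following sketch surfaces a few LMS tuning parameters as command-line options with argparse; the flag names and the 'adam_optimizer' scope are illustrative only.

    import argparse
    import tensorflow as tf
    from tensorflow.contrib.lms import LMS

    parser = argparse.ArgumentParser()
    parser.add_argument('--n_tensors', type=int, default=-1)
    parser.add_argument('--lb', type=int, default=1)
    parser.add_argument('--swap_branches', action='store_true')
    parser.add_argument('--branch_threshold', type=int, default=0)
    args = parser.parse_args()

    # Assumes the model graph with an 'adam_optimizer' scope has been built.
    lms_obj = LMS({'adam_optimizer'},
                  n_tensors=args.n_tensors,
                  lb=args.lb,
                  swap_branches=args.swap_branches,
                  branch_threshold=args.branch_threshold)
    lms_obj.run(graph=tf.get_default_graph())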
By default, LMS analyzes your graph to find the starting operations to use for finding tensor swap candidates. You can bypass this analysis by placing your starting operations in a named scope and providing the scope on the starting_scope parameter, or by providing the names of the starting operations on the starting_op_names parameter. This can speed up repeated runs of LMS during tuning. Furthermore, you can enable debug=True and debug_level=1 so that LMS prints the name and type of the starting operations that it finds. These names can be passed in on the starting_op_names parameter on subsequent runs.
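A minimal sketch of this two-pass approach follows; the operation names are placeholders for whatever the debug output reports for your model, and each pass belongs in a separate training run rather than back to back in one script.

    import tensorflow as tf
    from tensorflow.contrib.lms import LMS

    # First run: let LMS analyze the graph and print the starting operations.
    lms_obj = LMS({'adam_optimizer'}, debug=True, debug_level=1)
    lms_obj.run(graph=tf.get_default_graph())

    # Subsequent runs: supply the names reported by the debug output to skip
    # the graph analysis.
    lms_obj = LMS({'adam_optimizer'},
                  starting_op_names={'conv1/Conv2D', 'conv1/Relu'})
    lms_obj.run(graph=tf.get_default_graph())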
It is recommended that you start by tuning training on a single GPU before you enable your code for multi-GPU execution with DDL.
LMS with saved models
Both TensorFlow and Keras have various ways to save models. Some of these methods save the model or graph definition and some methods save only the weights. Whether you need to enable large model support on the loaded model depends on several factors, such as whether you are loading the model for further training or for inference, as well as how the model was saved.
If TensorFlow MetaGraphs or SavedModels are saved after LMS adds swapping nodes to the model, the loaded model will contain swapping nodes. If only the model weights are saved, and are restored onto a model that is built using code, then the model has only LMS swapping nodes if LMS is run again on the model.
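For the weights-only case, a minimal sketch follows. It assumes the graph is rebuilt by the same model-building code and that the weights were saved with tf.train.Saver; build_model() and the checkpoint path are placeholders.

    import tensorflow as tf
    from tensorflow.contrib.lms import LMS

    build_model()                                   # rebuild the graph from code
    lms_obj = LMS({'adam_optimizer'})
    lms_obj.run(graph=tf.get_default_graph())       # add the swapping nodes again

    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, '/path/to/model.ckpt')  # restore only the weights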
Keras models that are saved with tf.keras.models.save_model do not have LMS swapping nodes in them. If swapping is required in the loaded model, the LMSKerasCallback can be passed to the tf.keras.models.load_model API. For example:
    from tensorflow.contrib.lms import LMSKerasCallback
    lms_callback = LMSKerasCallback()
    model = tf.keras.models.load_model(filepath, callbacks=[lms_callback])