Getting started with TensorFlow Large Model Support (TFLMS)

TensorFlow Large Model Support (TFLMS) provides an approach to training large models that cannot fit into GPU memory. It takes a user-defined computational graph and automatically adds swap-in and swap-out nodes that transfer tensors from the GPUs to the host and back. Because the graph is modified statically, the modification must be done before a session starts.
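To make the idea of swap nodes concrete, the following toy sketch rewrites a graph stored as a plain Python adjacency list. It is not the TFLMS implementation: the dictionary representation, the add_swap_nodes helper, and the node names are hypothetical, and the sketch does not model how TFLMS decides when each swap-in runs.

```python
from typing import Dict, List, Set, Tuple

Graph = Dict[str, List[str]]  # op name -> names of ops that consume its output


def add_swap_nodes(graph: Graph, swap_edges: Set[Tuple[str, str]]) -> Graph:
    """Return a copy of `graph` in which each (producer, consumer) edge in
    `swap_edges` is routed through producer/swap_out -> producer/swap_in.

    Other consumers of the producer (for example, the next forward layer)
    keep reading the tensor directly; only the listed consumers read the
    copy that is swapped back in.
    """
    new_graph: Graph = {op: list(consumers) for op, consumers in graph.items()}
    for producer, consumer in sorted(swap_edges):
        out_node = f"{producer}/swap_out"  # stands for the GPU-to-host transfer
        in_node = f"{producer}/swap_in"    # stands for the host-to-GPU transfer
        new_graph[producer].remove(consumer)
        if out_node not in new_graph:
            new_graph[producer].append(out_node)
            new_graph[out_node] = [in_node]
            new_graph[in_node] = []
        new_graph[in_node].append(consumer)
    return new_graph


# A forward op whose output is also needed much later by its gradient op.
graph: Graph = {
    "conv1": ["conv2", "conv1_grad"],
    "conv2": ["conv2_grad"],
    "conv2_grad": ["conv1_grad"],
    "conv1_grad": [],
}
result = add_swap_nodes(graph, {("conv1", "conv1_grad")})
print(result["conv1"])          # ['conv2', 'conv1/swap_out']
print(result["conv1/swap_in"])  # ['conv1_grad']
```

Between the forward and backward uses, the tensor produced by conv1 only needs to live in host memory, which is what frees GPU memory.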

How to use TFLMS

TFLMS needs some information about the user-defined model. In particular, the model must define scopes for its optimizers or solvers.

How you enable LMS for a model depends on how the training is written. The following guidelines cover three ways to train:

Session-based training

To train a model that uses session-based training, follow these steps:

  1. Define optimizer or solver scopes
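    with tf.name_scope('adam_optimizer'):
        train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
  2. Define an LMS object and run it
    from tensorflow.contrib.lms import LMS
    # optimizer_scopes: must name the scope defined in step 1.
    lms_obj = LMS({'adam_optimizer'})
    # Statically rewrite the graph; do this before creating the session.
    lms_obj.run(graph=tf.get_default_graph())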

You must add these lines before you start a training session, for example:

  • Before you insert the LMS code
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        batch = mnist.train.next_batch(50)
        train_step.run(feed_dict={x: batch[0], y_: batch[1]})
  • After you insert the LMS code
    from tensorflow.contrib.lms import LMS
    lms_obj = LMS({'adam_optimizer'})
    lms_obj.run(graph=tf.get_default_graph())
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        batch = mnist.train.next_batch(50)
        train_step.run(feed_dict={x: batch[0], y_: batch[1]})

For a working example of LMS integration with session-based training, see /opt/DL/tensorflow/lib/python*/site-packages/tensorflow/contrib/lms/examples/mnist_deep_lms.py, which is an LMS-enabled version of /opt/DL/tensorflow/lib/python*/site-packages/tensorflow/examples/tutorials/mnist/mnist_deep.py.

Estimator-based training

To train a model that uses estimator-based training, follow these steps:

  1. Define optimizer or solver scopes
    # The scope name is arbitrary, but it must match the name passed to LMSHook.
    with tf.name_scope('adam_optimizer'):
        optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
        train_op = optimizer.minimize(
            loss=loss,
            global_step=tf.train.get_global_step())
  2. Define the LMSHook
    # Hook for Large Model Support
    from tensorflow.contrib.lms import LMSHook
    # LMSHook and LMS share the same set of parameters. Here we just
    # use the default keyword arguments.
    lms_hook = LMSHook({'adam_optimizer'})
  3. Add the LMSHook to the Estimator hook list
    mnist_classifier.train(
        input_fn=train_input_fn,
        steps=20000,
        hooks=[logging_hook, lms_hook])

For a working example of LMS integration with estimator-based training, see /opt/DL/tensorflow/lib/python*/site-packages/tensorflow/contrib/lms/examples/cnn_mnist_lms.py, which is an LMS-enabled version of /opt/DL/tensorflow/lib/python*/site-packages/tensorflow/examples/tutorials/layers/cnn_mnist.py.

tf.keras-based training

To train a model that uses tf.keras-based training, follow these steps:

  1. Define the LMSKerasCallback
    from tensorflow.contrib.lms import LMSKerasCallback
    # LMSKerasCallback and LMS share a set of keyword arguments. Here we just
    # use the default options.
    lms_callback = LMSKerasCallback()
  2. Pass the callback to the Keras fit or fit_generator function.
    model.fit_generator(generator=training_gen, callbacks=[lms_callback])

Required parameters for LMS or LMSHook

graph
The graph to modify for LMS, that is, the graph of the user-defined neural network. Not required for LMSHook.
optimizer_scopes
Scopes for the optimizers or solvers.

Optional parameters for LMS or LMSHook

starting_scope
Tensors that are reachable from the operations in this scope are swapped for LMS. Set this option to the scope of the first layer if you want to modify the whole graph. Default None.
starting_op_names
Tensors that are reachable from the operations with these names are swapped for LMS. Default None.
excl_scopes
A set of scopes for operations whose tensors are not swapped out to the host. Default empty.
incl_scopes
A set of scopes for operations whose tensors are swapped out to the host. Default empty.
excl_types
A set of types for operations whose tensors are not swapped out to the host. Default empty.
incl_types
A set of types for operations whose tensors are swapped out to the host. Default empty.
n_tensors
The number of tensors for LMS, counting from the starting_scope. To turn off LMS, set n_tensors to 0. Default -1 (all reachable tensors are swapped for LMS).
lb
Lower bound value for LMS. During the backward phase, a tensor is swapped in at least lb nodes before the operation that uses it in the graph. Default 1.
ub
Upper bound value for LMS. Default 10000.
fuse_swapins
Fuse "close" swap-in operations into one operation, which can improve performance. Default False.
ctrld_strategy
Strategy for choosing the control dependency operation of each swap-in operation: chain_rule or direct_order. The chain_rule strategy starts from a forward operation, walks forward through the graph, and uses a corresponding backward operation as the control dependency. The direct_order strategy takes a backward operation directly from the topological order as the control dependency. Both strategies use lb and ub to choose the control dependency operation. Although direct_order follows lb and ub more exactly than chain_rule, in experiments it often results in a smaller maximum batch size than chain_rule. Default chain_rule.
swap_branches
If True, LMS swaps tensors in branches in the forward phase. Default False.
branch_threshold
If swap_branches is enabled and the topological-sort distance between the consuming operation and generating operation of a tensor is greater than branch_threshold, then swap the tensor. Default 0.
debug
Debug mode for LMS. Default False.
debug_level
Debug level for LMS (1 or 2). Default 1.
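If you pass many of these options around, it can help to collect them in one place. The following sketch is a hypothetical LMSOptions dataclass, not part of TFLMS, whose defaults mirror the values listed above. It rejects values outside the documented choices for ctrld_strategy, debug_level, and n_tensors, and, as an additional assumption, requires lb to be no greater than ub. changed_kwargs() returns only the options you changed; passing that dictionary to LMS or LMSHook as keyword arguments assumes they accept these names, as the parameter list implies.

```python
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Optional, Set


@dataclass
class LMSOptions:
    """Hypothetical holder for the optional LMS/LMSHook parameters.

    Defaults mirror the documented defaults; this class is not part of TFLMS.
    """
    starting_scope: Optional[str] = None
    starting_op_names: Optional[Set[str]] = None
    excl_scopes: Set[str] = field(default_factory=set)
    incl_scopes: Set[str] = field(default_factory=set)
    excl_types: Set[str] = field(default_factory=set)
    incl_types: Set[str] = field(default_factory=set)
    n_tensors: int = -1  # -1: all reachable tensors, 0: LMS off
    lb: int = 1
    ub: int = 10000
    fuse_swapins: bool = False
    ctrld_strategy: str = "chain_rule"
    swap_branches: bool = False
    branch_threshold: int = 0
    debug: bool = False
    debug_level: int = 1

    def __post_init__(self) -> None:
        if self.ctrld_strategy not in ("chain_rule", "direct_order"):
            raise ValueError(f"unknown ctrld_strategy: {self.ctrld_strategy!r}")
        if self.debug_level not in (1, 2):
            raise ValueError("debug_level must be 1 or 2")
        if self.n_tensors < -1:
            raise ValueError("n_tensors must be -1, 0, or a positive count")
        if self.lb > self.ub:  # assumption: lb above ub is a configuration mistake
            raise ValueError("lb must not exceed ub")

    def changed_kwargs(self) -> Dict[str, Any]:
        """Return only the options that differ from the documented defaults."""
        defaults = asdict(LMSOptions())
        return {k: v for k, v in asdict(self).items() if v != defaults[k]}


opts = LMSOptions(n_tensors=40, lb=3)
print(opts.changed_kwargs())  # {'n_tensors': 40, 'lb': 3}
```

Passing only the changed values keeps the call short and leaves every other option at the library's own default.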

Scaling tips

TensorFlow limits the amount of memory that can be allocated on the CUDA host (CPU) side. This limit is often too low to serve as tensor swap space when you swap a large amount of data, or when you use multiple GPUs in a multi-tower fashion with one tower per GPU, as described in the TensorFlow documentation. If you do not raise the limit, you get out-of-memory errors such as: Allocator (cuda_host_bfc) ran out of memory trying to allocate. Note that the error names the cuda_host_bfc allocator rather than a GPU allocator.

A good starting value for the TF_CUDA_HOST_MEM_LIMIT_IN_MB environment variable is four times the memory capacity of each GPU, multiplied by the number of GPUs used. For example, if a system has four 16 GB GPUs and a training run uses all four, set TF_CUDA_HOST_MEM_LIMIT_IN_MB to 262144 (4 x 16384 MB per GPU x 4 GPUs = 262144 MB) and adjust from there as needed.
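The rule of thumb above is simple arithmetic. The sketch below computes it and exports the result. cuda_host_mem_limit_mb is a hypothetical helper, and setting the variable from Python before importing tensorflow is a cautious assumption about when TensorFlow reads it; exporting it in the shell that launches training avoids the question entirely.

```python
import os


def cuda_host_mem_limit_mb(gpu_mem_mb: int, num_gpus: int, factor: int = 4) -> int:
    """Starting value for TF_CUDA_HOST_MEM_LIMIT_IN_MB (hypothetical helper):
    factor x memory of one GPU in MB x number of GPUs used."""
    return factor * gpu_mem_mb * num_gpus


# Four 16 GB GPUs, all used in the training run.
limit = cuda_host_mem_limit_mb(gpu_mem_mb=16 * 1024, num_gpus=4)

# Assumption: set this before tensorflow is imported so it is in place
# when TensorFlow sets up its CUDA host allocator.
os.environ["TF_CUDA_HOST_MEM_LIMIT_IN_MB"] = str(limit)
print(limit)  # 262144
```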

Multi GPU scaling tips

To achieve better scaling performance with LMS on multiple GPUs, update the training script to use PowerAI Distributed Deep Learning (DDL) and run it with the ddlrun command. If you are running on a single system without an InfiniBand setup, you must also add the --mpiarg -pami_noib parameter to the ddlrun command line, for example:

ddlrun --mpiarg -pami_noib -H host1 python train_model.py
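If you launch training from a Python wrapper, the command line can be assembled as a list of arguments. The ddlrun_command helper below is hypothetical; it uses only the options shown in the example above (-H and --mpiarg -pami_noib) and adds the latter when the single system has no InfiniBand.

```python
import shlex
from typing import List


def ddlrun_command(script: str, host: str, infiniband: bool) -> List[str]:
    """Build a ddlrun command line for a single host (hypothetical helper)."""
    cmd = ["ddlrun"]
    if not infiniband:
        # Required on a single system without an InfiniBand setup.
        cmd += ["--mpiarg", "-pami_noib"]
    cmd += ["-H", host, "python", script]
    return cmd


cmd = ddlrun_command("train_model.py", host="host1", infiniband=False)
print(shlex.join(cmd))  # ddlrun --mpiarg -pami_noib -H host1 python train_model.py
# To launch it: import subprocess; subprocess.run(cmd, check=True)
```

Keeping the command as a list avoids shell quoting issues when it is passed to subprocess.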

For more information about using ddlrun, see /opt/DL/ddl/doc/README.md and /opt/DL/ddl-tensorflow/doc/README.md. You can also find information in the blog post "Improved Ease of Use for DDL in PowerAI".

Performance tuning LMS

After you enable LMS graph modification in your code, find the combination of tuning parameters that gives the fastest training time and best accuracy with your model. The goal of performance tuning is to swap out enough tensors for training to run without out-of-memory errors, but not so many that the extra swapping communication overhead degrades performance.

The two tuning parameters to focus on are n_tensors and lb. Because n_tensors controls how many tensors are swapped, the higher n_tensors is set, the lower the peak GPU memory usage. The lb parameter controls how early a tensor is swapped back in before it is used. A low lb value can make training on the GPU pause while it waits for the swap-in to finish, which degrades performance. A higher lb value gives the swap-in time to finish before the tensor is needed, so training runs without pausing. The downside of swapping in too early is that more tensors are in GPU memory at any point in time, which raises peak GPU memory usage.
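The following toy model illustrates the lb trade-off described above. It is not how TFLMS schedules swaps: it assumes one swapped tensor is consumed per backward step, that each swap-in starts exactly lb steps before its consumer, and that a step stalls when a tensor's transfer takes more steps than lb. The sizes and transfer times are made-up inputs.

```python
from typing import List, Tuple


def simulate_swap_in(sizes_mb: List[int], transfer_steps: List[int],
                     lb: int) -> Tuple[int, int]:
    """Toy model of the lb trade-off (not the TFLMS scheduler).

    Backward step i consumes swapped tensor i. Tensor i is swapped in at
    step i - lb and freed after step i. Returns (peak MB of swapped-in
    tensors resident at once, number of steps that wait on a transfer).
    """
    n = len(sizes_mb)
    peak = 0
    for t in range(-lb, n):
        resident = sum(sizes_mb[i] for i in range(n) if i - lb <= t <= i)
        peak = max(peak, resident)
    # A transfer longer than the lb-step head start makes its consumer wait.
    stalls = sum(1 for steps in transfer_steps if steps > lb)
    return peak, stalls


for lb in (1, 2, 3):
    print(lb, simulate_swap_in([100] * 6, [1, 2, 2, 3, 1, 2], lb))
# 1 (200, 4)
# 2 (300, 1)
# 3 (400, 0)
```

In this toy setting, raising lb from 1 to 3 removes every stall but doubles the peak memory held by swapped-in tensors, which is the balance the tuning aims for.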

Tuning therefore comes down to finding the balance between n_tensors and lb that gives the best performance for a given model. To start, consider setting n_tensors to -1, which swaps all reachable tensors, and lb to the default of 1, which is the latest possible swap-in. If tf.logging verbosity is set to tf.logging.INFO, LMS logs the number of tensors it swapped. Running first with n_tensors=-1 is a useful way to find this maximum value, which you can then adjust downward. If your model has branches, as some UNet models do, consider setting swap_branches=True and tuning branch_threshold as well. While you tune, it is often useful to expose the LMS parameters in the training script as command-line parameters or configuration file properties.
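Exposing the LMS parameters on the command line lets each tuning run proceed without a code change. The sketch uses argparse from the Python standard library. The parse_lms_args helper is hypothetical, the flag names simply reuse the parameter names documented above, and passing the resulting dictionary as keyword arguments to LMS or LMSHook is an assumption based on the parameter list.

```python
import argparse
from typing import Any, Dict, List, Optional


def parse_lms_args(argv: Optional[List[str]] = None) -> Dict[str, Any]:
    """Parse the LMS tuning knobs from the command line (hypothetical helper)."""
    parser = argparse.ArgumentParser(description="Training with tunable LMS options")
    parser.add_argument("--n_tensors", type=int, default=-1,
                        help="tensors to swap; -1 swaps all reachable tensors")
    parser.add_argument("--lb", type=int, default=1,
                        help="lower bound on how many nodes before use a tensor is swapped in")
    parser.add_argument("--swap_branches", action="store_true",
                        help="also swap tensors in forward-phase branches")
    parser.add_argument("--branch_threshold", type=int, default=0,
                        help="swap a branch tensor when its topological distance exceeds this")
    return vars(parser.parse_args(argv))


lms_kwargs = parse_lms_args(["--n_tensors", "120", "--lb", "3"])
print(lms_kwargs)
# In the training script (assumed call form):
#   tf.logging.set_verbosity(tf.logging.INFO)  # LMS logs the swapped-tensor count
#   LMS({'adam_optimizer'}, **lms_kwargs).run(graph=tf.get_default_graph())
```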

By default, LMS analyzes your graph to find the starting operations from which it searches for tensor swap candidates. You can bypass this analysis by placing your starting operations in a named scope and passing that scope as the starting_scope parameter, or by passing the names of the starting operations as the starting_op_names parameter. This can speed up repeated runs of LMS during tuning. In addition, if you set debug=True and debug_level=1, LMS prints the name and type of each starting operation it finds. You can pass these names to the starting_op_names parameter on subsequent runs.
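One way to reuse the starting operations between tuning runs is to record them in a file. In the sketch below, the JSON file, both helper functions, and the op name conv1/Conv2D are hypothetical; you would copy the real names from the LMS debug output. When no file exists yet, the helper returns debug=True and debug_level=1 so that LMS reports the starting operations it finds.

```python
import json
import os
import tempfile
from typing import Any, Dict, List


def save_starting_ops(path: str, op_names: List[str]) -> None:
    """Record starting-operation names copied from the LMS debug output."""
    with open(path, "w") as f:
        json.dump(sorted(op_names), f)


def lms_discovery_kwargs(path: str) -> Dict[str, Any]:
    """Skip LMS's starting-op analysis when the names are already known."""
    if os.path.exists(path):
        with open(path) as f:
            return {"starting_op_names": json.load(f)}
    # First run: let LMS find the starting ops and print their names and types.
    return {"debug": True, "debug_level": 1}


ops_file = os.path.join(tempfile.mkdtemp(), "lms_starting_ops.json")
first_run = lms_discovery_kwargs(ops_file)     # {'debug': True, 'debug_level': 1}
save_starting_ops(ops_file, ["conv1/Conv2D"])  # hypothetical op name
later_runs = lms_discovery_kwargs(ops_file)    # {'starting_op_names': ['conv1/Conv2D']}
```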

We recommend that you tune training on a single GPU before you enable your code for multi-GPU training with DDL.

LMS with saved models

Both TensorFlow and Keras offer various ways to save models. Some methods save the model or graph definition, and some save only the weights. Whether you need to enable large model support on a loaded model depends on several factors, such as whether you load the model for further training or for inference, and how the model was saved.

If TensorFlow MetaGraphs or SavedModels are saved after LMS adds swapping nodes to the model, the loaded model contains those swapping nodes. If only the model weights are saved and then restored onto a model that is built in code, the model has LMS swapping nodes only if LMS is run on it again.

Keras models that are saved with tf.keras.models.save_model do not contain LMS swapping nodes. If the loaded model requires swapping, pass the LMSKerasCallback to the tf.keras.models.load_model API. For example:

import tensorflow as tf
from tensorflow.contrib.lms import LMSKerasCallback

lms_callback = LMSKerasCallback()
# filepath is the path of the saved Keras model.
model = tf.keras.models.load_model(filepath, callbacks=[lms_callback])