
TensorFlow DTensor: Unified API for Distributed Deep Network Training


The recently released TensorFlow v2.9 introduces a new API for model-, data-, and space-parallel (aka spatially tiled) deep network training. DTensor aims to decouple sharding directives from model code by providing higher-level utilities to partition the model and batch parameters between devices. The work is part of a recent wave of efforts (e.g. GPipe, TF Mesh, GShard, DeepSpeed, Fairscale, ColossalAI) to reduce the development time needed to build large-scale training workloads.

For large (language) models, test loss falls predictably with the number of network parameters, dataset size, and compute time. Hence, task-level improvements have depended heavily on deep network size in recent years. Utilizing an ever-increasing number of accelerators has also demanded a large amount of engineering work due to the distributed nature of such training platforms (e.g. GPUs and FPGAs come with limited in-device memory). DTensor provides a device-agnostic abstraction on top of training units, with a mesh configuration for devices and a layout for tensor placement.
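As a minimal sketch of the mesh and layout concepts, the snippet below builds a one-dimensional device mesh and two layouts on it. The logical-device setup, mesh name, and tensor shapes are illustrative choices, not taken from the article:

```python
import tensorflow as tf
from tensorflow.experimental import dtensor

# Assumption for this sketch: split one physical CPU into four logical
# devices so the example runs on a single host.
tf.config.set_logical_device_configuration(
    tf.config.list_physical_devices("CPU")[0],
    [tf.config.LogicalDeviceConfiguration()] * 4)

# A 1-D mesh: the "batch" dimension spans four devices.
mesh = dtensor.create_mesh([("batch", 4)], device_type="CPU")

# Layouts describe placement independently of concrete devices:
replicated = dtensor.Layout.replicated(mesh, rank=2)           # full copy per device
sharded = dtensor.Layout(["batch", dtensor.UNSHARDED], mesh)   # rows split 4 ways

# Materialize a tensor directly with the sharded placement.
t = dtensor.call_with_layout(tf.ones, sharded, shape=[8, 3])
print(dtensor.fetch_layout(t))
```

Note that the tensor keeps its global shape (8, 3); the layout only determines how its rows are distributed across the mesh.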

As an example, the illustration below shows the placement of a tensor on two different mesh configurations with three different layouts. In the second mesh case, it is possible to opt for column- or row-wise sharding by indicating which dimension should be "unsharded" and instead replicated across devices.

Fig-1: Illustration showing placement of a tensor on two different mesh configurations.

By creating separate mesh and layout objects, DTensor brings the flexibility to adopt different training topologies without hardcoding device configurations. As an example, it offers a straightforward way to implement spatial partitioning of tensors along any dimension without a specialized API for computer vision applications (unlike TensorFlow TPUEstimator spatial partitioning). It should be noted that the device mesh API can be used with TF virtual devices (via the logical device mechanism), so different sharding scenarios can be experimented with using the DTensor API on a single host.
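The re-sharding flexibility described above can be sketched as follows: the same global tensor is first sharded row-wise, then relaid out column-wise, without any change to surrounding code. The 2x3 mesh, its dimension names, and the logical-device setup are assumptions for illustration:

```python
import tensorflow as tf
from tensorflow.experimental import dtensor

# Assumption: carve one physical CPU into six logical devices (the
# virtual-device mechanism mentioned above) to try sharding locally.
tf.config.set_logical_device_configuration(
    tf.config.list_physical_devices("CPU")[0],
    [tf.config.LogicalDeviceConfiguration()] * 6)

mesh = dtensor.create_mesh([("x", 2), ("y", 3)], device_type="CPU")

# Row-wise sharding: first tensor dimension split across mesh dim "x".
row_sharded = dtensor.Layout(["x", dtensor.UNSHARDED], mesh)
# Column-wise sharding: second tensor dimension split across "y".
col_sharded = dtensor.Layout([dtensor.UNSHARDED, "y"], mesh)

t = dtensor.call_with_layout(tf.ones, row_sharded, shape=[4, 6])
# Re-shard the same global tensor; model code is unaffected.
t = dtensor.relayout(t, col_sharded)
```

Swapping one layout object for another is all that is needed to move between data-parallel and model-parallel style placements.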

Although the DTensor API is experimental, it already supports direct Keras integration. In the code snippet below, fully replicated weight layouts are provided to a dense layer:

# `devices` holds eight device names, e.g. from dtensor.local_devices("CPU")
mesh = dtensor.create_mesh([("batch", 8)], devices=devices)
kernel_layout = dtensor.Layout.replicated(mesh, rank=2)
bias_layout = dtensor.Layout.replicated(mesh, rank=1)

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu',
                        kernel_layout=kernel_layout,
                        bias_layout=bias_layout),
])

DTensor provides a drop-in replacement for most tensor operations, so it can also be used with the tf.function and tf.GradientTape APIs. However, the current TensorFlow version does not support Keras's built-in training loop with DTensor; a custom loop must be written to train DTensor-sharded models. DTensor also supports single- and multi-client training jobs, so multiple processes can use the API natively.
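A custom loop of the kind described above might look like the following sketch: a tiny linear model with a replicated DTensor variable, a tf.GradientTape step wrapped in tf.function, and batch data sharded along the mesh's "batch" dimension. The model, shapes, learning rate, and logical-device setup are all hypothetical:

```python
import tensorflow as tf
from tensorflow.experimental import dtensor

# Assumption: two logical CPU devices stand in for real accelerators.
tf.config.set_logical_device_configuration(
    tf.config.list_physical_devices("CPU")[0],
    [tf.config.LogicalDeviceConfiguration()] * 2)

mesh = dtensor.create_mesh([("batch", 2)], device_type="CPU")

# Replicated weights for a tiny linear model (hypothetical).
w_layout = dtensor.Layout.replicated(mesh, rank=2)
w = dtensor.DVariable(dtensor.call_with_layout(tf.zeros, w_layout, shape=[4, 1]))

@tf.function
def train_step(x, y, lr=0.1):
  # Forward and backward pass run as SPMD across the mesh.
  with tf.GradientTape() as tape:
    pred = tf.matmul(x, w)
    loss = tf.reduce_mean(tf.square(pred - y))
  grad = tape.gradient(loss, w)
  w.assign_sub(lr * grad)
  return loss

# Input batch sharded along the "batch" mesh dimension (data parallelism).
x_layout = dtensor.Layout(["batch", dtensor.UNSHARDED], mesh)
x = dtensor.call_with_layout(tf.ones, x_layout, shape=[8, 4])
y = dtensor.call_with_layout(tf.ones, x_layout, shape=[8, 1])
loss = train_step(x, y)
```

With w initialized to zero and targets of one, the first-step loss is mean((0 - 1)^2) = 1.0; subsequent steps would update w toward the targets.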

Additional information can be found in the DTensor overview documentation. The TensorFlow website also provides examples of low-level distributed training and Keras training.
