
PyTorch 2.0 Compiler Improves Model Training Speed

The PyTorch Foundation recently released PyTorch version 2.0, a 100% backward compatible update. The main API contribution of the release is torch.compile, a function that compiles deep learning models to speed up training. Internal benchmarks on 163 open-source AI projects showed that compiled models ran on average 43% faster during training.

Plans for the 2.0 release were announced at the PyTorch Conference in December 2022. Besides the new compile function, the release also includes performance improvements for Transformer-based models, such as large language models and diffusion models, via a new implementation of scaled dot product attention (SDPA). Training on Apple silicon is accelerated by an improved Metal Performance Shaders (MPS) backend, which now implements 300 operations. Beyond the core release, the domain libraries, including TorchAudio, TorchVision, and TorchText, were updated with new beta features. Overall, the 2.0 release includes over 4,500 commits from 428 developers since the 1.13.1 release. According to the PyTorch Foundation blog,
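For readers unfamiliar with SDPA, the computation itself is compact: softmax(Q K^T / sqrt(d_k)) V. The following is a plain-Python reference of that math only, assuming the default configuration of torch.nn.functional.scaled_dot_product_attention (no attention mask, no dropout, scale 1/sqrt(d_k)). It is not PyTorch's implementation, which dispatches to fused kernels; the helper names here (matmul, transpose, softmax, sdpa_reference) are hypothetical and exist only for illustration.

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Multiply an (n x m) matrix by an (m x p) matrix."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(m: Matrix) -> Matrix:
    return [list(col) for col in zip(*m)]


def softmax(row: List[float]) -> List[float]:
    # Subtract the row maximum before exponentiating for numerical stability.
    peak = max(row)
    exps = [math.exp(x - peak) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def sdpa_reference(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """softmax(Q K^T / sqrt(d_k)) V, with no mask and no dropout (hypothetical helper)."""
    d_k = len(q[0])
    scores = matmul(q, transpose(k))                          # query-key similarities
    scaled = [[s / math.sqrt(d_k) for s in row] for row in scores]
    weights = [softmax(row) for row in scaled]                # each row sums to 1
    return matmul(weights, v)                                 # weighted sum of values


# When every key is identical, each query attends equally to all values,
# so the output is the mean of the value rows.
q = [[1.0, 2.0]]
k = [[0.5, 0.5], [0.5, 0.5]]
v = [[1.0, 0.0], [3.0, 4.0]]
print(sdpa_reference(q, k, v))  # [[2.0, 2.0]]
```

The attention-weight matrix produced in the middle of this computation is exactly the kind of intermediate that fused kernels avoid writing out to memory, which ties back to the memory-bandwidth constraints discussed below.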

We are excited to announce the release of PyTorch® 2.0 which we highlighted during the PyTorch Conference on 12/2/22! PyTorch 2.0 offers the same eager-mode development and user experience, while fundamentally changing and supercharging how PyTorch operates at compiler level under the hood with faster performance and support for Dynamic Shapes and Distributed.

In his keynote speech at the PyTorch Conference 2022, PyTorch co-creator Soumith Chintala pointed out that, as GPU compute capacity has grown, many existing PyTorch workloads have become constrained by memory bandwidth or by PyTorch framework overhead rather than by raw compute. Previously, the PyTorch team had addressed performance problems by writing some of their core components in C++; Chintala described PyTorch as "basically a C++ codebase," and said that he "hates" contributing to the C++ components.

The new compile feature is based on four underlying components written in Python:

  • TorchDynamo - performs graph acquisition by rewriting the Python bytecode of deep learning models at runtime, extracting sequences of PyTorch operations into blocks of computational graphs
  • AOTAutograd - performs "ahead of time" automatic differentiation for the backprop step
  • PrimTorch - canonicalizes the more than 2,000 PyTorch operators down to a fixed set of around 250 primitive operators
  • TorchInductor - generates fast hardware-specific backend code for accelerators
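One way to see what TorchDynamo hands to the later stages is a custom backend: torch.compile accepts a callable that receives the captured torch.fx.GraphModule plus example inputs and returns a callable to run in its place. The sketch below is a minimal instance of that pattern, assuming PyTorch 2.0 or later is installed; the names inspect_backend, TinyMLP, and captured_graphs are hypothetical. Because this backend simply returns gm.forward, it skips AOTAutograd, PrimTorch, and TorchInductor entirely and runs the captured graph as-is.

```python
from typing import Callable, List

import torch
import torch.fx
from torch import nn

# Hypothetical list used to keep every graph TorchDynamo hands to the backend.
captured_graphs: List[torch.fx.GraphModule] = []


def inspect_backend(gm: torch.fx.GraphModule,
                    example_inputs: List[torch.Tensor]) -> Callable:
    """Hypothetical custom backend: record and print the captured FX graph,
    then run it unchanged (no AOTAutograd, PrimTorch, or TorchInductor)."""
    captured_graphs.append(gm)
    print(gm.code)      # Python source regenerated from the captured graph
    return gm.forward   # execute the graph eagerly


class TinyMLP(nn.Module):
    """Hypothetical two-layer model used only for this illustration."""

    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(8, 16)
        self.fc2 = nn.Linear(16, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(torch.relu(self.fc1(x)))


model = TinyMLP()
# The one-line change: wrap the model. Compilation happens on the first call.
compiled_model = torch.compile(model, backend=inspect_backend)

x = torch.randn(2, 8)
out = compiled_model(x)
print(out.shape)  # torch.Size([2, 4])
```

Dropping the backend argument, as in torch.compile(model), selects the default "inductor" backend, which passes the captured graph through the AOTAutograd and TorchInductor stages instead of running it unchanged.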

To demonstrate the performance improvements and ease of use of the compile function, the PyTorch team identified 163 open-source deep learning projects to benchmark. These covered a wide variety of tasks, including computer vision, natural language processing, and reinforcement learning. The team made no changes to the code besides the one-line call to the compile function. This single change worked in 93% of the projects, and the compiled models trained on average 43% faster on NVIDIA A100 GPUs.
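The headline figures can be turned into rough absolute numbers with a little arithmetic. This is a back-of-the-envelope sketch, not a reproduction of the benchmark: both percentages are rounded, so the exact project count cannot be recovered, and it assumes that "43% faster" means a 1.43x speedup in training throughput. Under that reading, each training step takes about 70% as long, a saving of roughly 30% of wall-clock time rather than 43%.

```python
# Figures quoted in the article (both are rounded percentages).
projects = 163
success_rate = 0.93
# Assumption: "43% faster" is read as a 1.43x speedup in training throughput.
speedup = 1.43

working = round(projects * success_rate)  # about 152 of the 163 projects
step_time = 1 / speedup                   # a compiled step takes ~0.70x as long
time_saved = 1 - step_time                # ~30% less wall-clock time per step

print(f"~{working} projects compiled; step time x{step_time:.2f}, "
      f"saving ~{time_saved:.0%} per step")
```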

In a Hacker News discussion about the release, one user noted:

A big lesson I learned from PyTorch vs other frameworks is that productivity trumps incremental performance improvement. Both Caffe and MXNet marketed themselves for being fast, yet apparently being faster here and here by some percentage simply didn't matter that much. On the other hand, once we make a system work and make it popular, the community will close the performance gap sooner than competitors expect. Another lesson is probably old but worth repeating: investment and professional polishing [matter] to open source projects.

The PyTorch code and version 2.0 release notes are available on GitHub.