PyTorch 2.0 Release: Summary of Updates and Improvements

tl;dr

The PyTorch 2.0 release keeps the same user experience as its predecessor while making fundamental changes that improve PyTorch's performance.

Introduction

The next-generation version, PyTorch 2.0, was announced at the PyTorch Conference on December 2, 2022.

This release keeps the same user experience as its predecessor while making fundamental changes at the compiler level, yielding faster performance and better support for dynamic shapes and distributed computing.

Stable Features

  • Accelerated PyTorch 2 Transformers: PyTorch 2.0 includes a new high-performance implementation of the PyTorch Transformer API, known as Accelerated PT2 Transformers, which aims to make training and deployment of state-of-the-art Transformer models affordable across the industry. It adds high-performance support for training and inference using a custom kernel architecture for scaled dot product attention (SDPA). Accelerated PyTorch 2 Transformers integrate with torch.compile() (a beta feature), so models also benefit from the additional acceleration of PT2 compilation for inference or training.
  • Multiple Scaled Dot Product Attention (SDPA) Custom Kernels: To cover different hardware models and Transformer use cases, multiple SDPA custom kernels are supported, with kernel-selection logic that picks the highest-performance kernel for a given model and hardware type (see the sketch after this list).
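
A minimal sketch of how this surfaces in code: torch.nn.functional.scaled_dot_product_attention dispatches to the fastest eligible kernel automatically, and the PyTorch 2.0-era torch.backends.cuda.sdp_kernel context manager can constrain the choice, for example to benchmark one kernel in isolation. The tensor shapes below are illustrative.

```python
import torch
import torch.nn.functional as F

# Batched multi-head attention shapes: (batch, heads, seq_len, head_dim)
q = torch.randn(2, 8, 128, 64)
k = torch.randn(2, 8, 128, 64)
v = torch.randn(2, 8, 128, 64)

# Dispatches to the highest-performance eligible kernel (FlashAttention,
# memory-efficient attention, or the math fallback) for this input/hardware.
out = F.scaled_dot_product_attention(q, k, v)

# On CUDA, kernel choice can be constrained for benchmarking. FlashAttention
# requires fp16/bf16 inputs and a supported GPU.
if torch.cuda.is_available():
    qh, kh, vh = (t.cuda().half() for t in (q, k, v))
    with torch.backends.cuda.sdp_kernel(enable_flash=True,
                                        enable_math=False,
                                        enable_mem_efficient=False):
        out = F.scaled_dot_product_attention(qh, kh, vh)
```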

Beta Features

  • torch.compile: The torch.compile API is a beta feature in PyTorch 2.0. It compiles models for high-performance inference and training: the API wraps your model and returns a compiled model that can run faster than the original, uncompiled one (see the sketch after this list).
  • Scaled dot product attention: The scaled_dot_product_attention function is part of torch.nn.functional in PyTorch 2.0. It is used to implement attention mechanisms in Transformer models.
  • MPS backend: The Metal Performance Shaders (MPS) backend provides GPU-accelerated PyTorch training on Mac platforms. This release adds support for the top 60 most-used ops, bringing coverage to over 300 operators.
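
A minimal sketch of the torch.compile workflow, with an availability check for the MPS backend; the model and shapes below are illustrative.

```python
import torch
import torch.nn as nn

# Wrap an ordinary model; torch.compile returns an optimized callable
# with the same interface.
model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 10))
compiled = torch.compile(model)

x = torch.randn(32, 64)
out = compiled(x)  # first call triggers compilation; later calls reuse it
print(out.shape)   # torch.Size([32, 10])

# On Apple silicon, the MPS backend exposes the GPU as a regular device:
if torch.backends.mps.is_available():
    model_mps = model.to("mps")
```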

PyTorch 2.0 also introduces new prototype features and technologies: TensorParallel, DTensor, and 2D parallel on the distributed side, and TorchDynamo, AOTAutograd, PrimTorch, and TorchInductor, the four components that make up the torch.compile stack.
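
These compiler components surface through torch.compile's backend argument; a minimal sketch of inspecting and selecting them (the compiled function is illustrative):

```python
import torch
import torch._dynamo

# TorchDynamo captures Python frames into graphs, AOTAutograd traces the
# backward pass ahead of time, PrimTorch lowers ops to a small primitive
# set, and TorchInductor generates the final kernels. "inductor" is the
# default backend for torch.compile.
print(torch._dynamo.list_backends())

fn = torch.compile(lambda x: torch.sin(x) + torch.cos(x), backend="inductor")
print(fn(torch.randn(8)))
```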

In addition to the above features, PyTorch 2.0 includes several other updates and improvements across inference, performance, and training-optimization features on GPUs and CPUs.