

# Towards Agile Development of Efficient Deep Learning Operators

Keren Zhou & Philippe Tillet

### **Deep Neural Networks (DNNs)**





#### **Computer Vision**



#### **Recommendation Systems**



#### **Speech Recognition**

```python
import torch

a = torch.randn(64, 32)
b = torch.randn(32, 64)
c = torch.randn(64, 64)
d = torch.mm(a, b)  # matrix multiplication kernel
e = c + d           # element-wise addition kernel
```

| Model                                                    | Graph                                                          | Kernel                                            | Device                                            |
|----------------------------------------------------------|----------------------------------------------------------------|---------------------------------------------------|---------------------------------------------------|
| <ul><li>PyTorch</li><li>TensorFlow</li><li>JAX</li></ul> | <ul><li>XLA/HLO</li><li>TVM/Relay</li><li>PyTorch/fx</li></ul> | <ul><li>CUDA</li><li>HIP</li><li>OpenCL</li></ul> | <ul><li> GPU</li><li> CPU</li><li> FPGA</li></ul> |










```cuda
__global__
void mm(float *a, float *b, float *c) {
    // Hand-written kernel: the programmer manages tiles, threads and memory explicitly
    float *a_tile;
    float *b_tile;
    ...
}
```


### A Large Number of Tensor Operators

- → Linear: Fused, Bilinear, Attention
- → Convolution: Depthwise, Dilated, Transposed
- → Normalization: Batch, Layer
- → Embedding
- → Sparse: SDDMM, SpMM
- → Pooling: Max/Min/Avg, Adaptive
- → Loss: NLL, BCE
- → Recurrent: LSTM, GRU

- TensorFlow: > 400 operators
- PyTorch: > 200 operators

### Various Data Types

- → Common tensor data types
  - Float64
  - Float32
  - Float16
  - BFloat16
  - Int64
  - Int32
  - Int16
  - Int8
  - Bool

For performance critical kernels: #Implementations ≈ #Data types × #Kernels
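To make the estimate concrete, the Python sketch below enumerates (operator, data type) pairs. The nine data types come from the list above; `OPS` is a hypothetical handful of operators picked from the categories above, and not every pair is meaningful (a `bool` convolution, for instance), so the product is an upper bound rather than an exact count.

```python
from itertools import product

# The nine data types listed above
DTYPES = ["float64", "float32", "float16", "bfloat16",
          "int64", "int32", "int16", "int8", "bool"]

# Hypothetical subset of operators, one or two per category above
OPS = ["linear", "conv2d", "batch_norm", "layer_norm",
       "embedding", "max_pool", "nll_loss", "lstm"]


def specializations(ops, dtypes):
    """Every (operator, data type) pair a hand-written library may need to cover."""
    return list(product(ops, dtypes))


impls = specializations(OPS, DTYPES)
print(len(impls))  # 8 operators x 9 data types = 72
```

With more than 200 PyTorch operators, the same product already exceeds 1,800 candidate kernels, before any shape- or architecture-specific tuning.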

#### **Handwritten Code**

- → Low flexibility
  - Fine-tune for every shape/data type/algorithm
  - Employ assembly instructions
  - ...
- → **High** performance
  - Apply sophisticated instruction/operator scheduling
  - Simplify code
  - ...

#### Handwritten Code is a Pain

- → For the company
  - Hard to recruit new machine learning engineers
  - Difficult to maintain libraries
- → For researchers
  - A black box
    - They want to understand how kernels work
    - They want to validate new ideas quickly and at scale

### Python-like Code

- → **High** flexibility
  - Build upon existing operators
  - No need to recompile
  - ...
- → Low performance
  - Not fine-tuned for specific shapes
  - Extra memory traffic for intermediate results
  - ...

Can we design a language to achieve both high performance and flexibility?

## **Triton**

A Programming Model for Next-Generation Deep Learning Systems

## **Programming Models for DNNs**






## Inefficiencies of Existing PyTorch V1 Operators

- → Individual kernels
  - Can be slow
  - Can run out of memory
- → Graph compilers
  - Do not support custom data structures
    - lists/trees of tensors
    - block-sparse tensors
  - Do not support custom precision formats
  - Automatic kernel fusion is limited

Solution: employ Triton as the kernel language behind PyTorch V2

## Triton is Designed to Achieve Both High Flexibility and Performance

- → Flexibility
  - A small core set of operations (~40 interface functions and ~20 core functions)
  - Can be composed to express almost all existing PyTorch operators (TorchInductor)
  - SPMD rather than SIMT
- → Performance
  - JIT-generated kernels
  - Handwritten PTX code
  - Many passes to combine, simplify, and schedule operations

#### **Triton Design**

- → PyTorch compatible
  - Tensors are stored on-chip rather than off-chip
  - Custom data structures using tensors of pointers
- → Python syntax
  - All standard Python control-flow structures (for/if/while) are supported
  - Python code is lowered to Triton IR

## Write GPU Kernels Using Triton

#### **GPU-accelerated Application Overview**

- → CPU and GPU execute asynchronously
- → CPU dispatches commands to GPU



## **Terminology**

- → Parallelism
  - Grid
    - One per kernel launch
  - Block/Warp/Thread
- → Memory
  - Global
    - Visible to all threads
  - Shared
    - Private to each block
  - Local
    - Private to each thread

#### **CUDA vs Triton**

|                 | CUDA                 | Triton        |
|-----------------|----------------------|---------------|
| Memory          | Global/Shared/Local  | Automatic     |
| Parallelism     | Threads/Blocks/Warps | Mostly Blocks |
| Tensor Core     | Manual               | Automatic     |
| Vectorization   | .8/.16/.32/.64/.128  | Automatic     |
| Async SIMT      | Supported            | Limited       |
| Device Function | Supported            | Not Available |

Using Triton, you only need to know that a program is divided into multiple blocks

## **Vector Addition (Single Block)**

- → Z[:] = X[:] + Y[:]
  - Without boundary check

```python
import torch
import triton
import triton.language as tl
```

```python
N = 1024
x = torch.randn(N, device='cuda')
y = torch.randn(N, device='cuda')
z = torch.empty(N, device='cuda')  # output buffer
```

#### **Vector Addition (Boundary Check)**

- → Z[:] = X[:] + Y[:]
  - With boundary check
- → `program_id()`
  - Get the block id
- → mask
  - If `mask[idx]` is false, do not load the data at address `pointer[idx]`
- → `triton.cdiv(N, 1024)`
  - Ceiling division: equal to (N - 1) // 1024 + 1 for N ≥ 1

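```python
import torch
import triton
import triton.language as tl


@triton.jit
def add(z_ptr, x_ptr, y_ptr, N):
    # same as torch.arange, shifted to this block's range
    offsets = tl.arange(0, 1024) + tl.program_id(0) * 1024
    # create 1024 pointers to X, Y, Z
    x_ptrs = x_ptr + offsets
    y_ptrs = y_ptr + offsets
    z_ptrs = z_ptr + offsets
    # load 1024 elements of X, Y; lanes with offsets >= N are masked off
    x = tl.load(x_ptrs, mask=offsets < N)
    y = tl.load(y_ptrs, mask=offsets < N)
    # do computations
    z = x + y
    # write back 1024 elements of Z
    tl.store(z_ptrs, z, mask=offsets < N)


N = 192311
x = torch.randn(N, device='cuda')
y = torch.randn(N, device='cuda')
z = torch.empty(N, device='cuda')
grid = (triton.cdiv(N, 1024), )
add[grid](z, x, y, N)
```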
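The launch above can be mimicked on the CPU to see what `program_id`, the mask and `triton.cdiv` do. This is a plain-Python sketch, not Triton: `cdiv` and `add_program` are hypothetical helpers, and each call to `add_program` stands in for one Triton program (block) of 1024 lanes.

```python
TILE = 1024  # lanes per program, as in tl.arange(0, 1024)


def cdiv(n, d):
    """Ceiling division, the value triton.cdiv(n, d) returns."""
    return (n + d - 1) // d


def add_program(pid, z, x, y, n):
    """One emulated program: handles offsets [pid * TILE, (pid + 1) * TILE)."""
    offsets = [pid * TILE + i for i in range(TILE)]  # tl.arange(0, 1024) + pid * 1024
    mask = [o < n for o in offsets]                   # offsets < N
    for o, m in zip(offsets, mask):
        if m:                                         # masked-off lanes never touch memory
            z[o] = x[o] + y[o]
    return sum(mask)                                  # number of active lanes


N = 192311
x = [float(i) for i in range(N)]
y = [2.0 * i for i in range(N)]
z = [0.0] * N
grid = cdiv(N, TILE)
active = [add_program(pid, z, x, y, N) for pid in range(grid)]
print(grid, active[-1])  # 188 823
```

For N = 192311 the grid has 188 programs: the first 187 use all 1024 lanes, and the last one has only 823 active lanes, which is exactly the case the mask protects against.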

#### **Vector Addition (Custom Tile Size)**

- → Z[:] = X[:] + Y[:]
  - Each block computes TILE elements
- → @triton.autotune
  - Select the best config based on the execution time
  - We don't want to build complex autotune policies into Triton

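```python
import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[  # example candidate tile sizes
        triton.Config({'TILE': 256}),
        triton.Config({'TILE': 512}),
        triton.Config({'TILE': 1024}),
    ],
    key=['N'],  # re-tune whenever N changes
)
@triton.jit
def add(z_ptr, x_ptr, y_ptr, N, TILE: tl.constexpr):
    # same as torch.arange, shifted to this block's tile
    offsets = tl.arange(0, TILE) + tl.program_id(0) * TILE
    # create TILE pointers to X, Y, Z
    x_ptrs = x_ptr + offsets
    y_ptrs = y_ptr + offsets
    z_ptrs = z_ptr + offsets
    # load TILE elements of X, Y (lanes past N are masked off)
    x = tl.load(x_ptrs, mask=offsets < N)
    y = tl.load(y_ptrs, mask=offsets < N)
    # do computations
    z = x + y
    # write back TILE elements of Z
    tl.store(z_ptrs, z, mask=offsets < N)


N = 192311
x = torch.randn(N, device='cuda')
y = torch.randn(N, device='cuda')
z = torch.empty(N, device='cuda')
# the grid depends on the TILE picked by the autotuner
grid = lambda meta: (triton.cdiv(N, meta['TILE']), )
add[grid](z, x, y, N)
```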
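The idea behind `@triton.autotune` (benchmark each candidate config and keep the fastest) fits in a few lines. The sketch below is a CPU stand-in, not Triton's implementation: `autotune` and `run_vector_add` are hypothetical helpers, and wall-clock timing of a pure-Python loop replaces GPU benchmarking.

```python
import time


def autotune(run, configs, repeats=3):
    """Time run(cfg) for every config and return the fastest one plus all timings."""
    timings = {}
    for cfg in configs:
        best = float("inf")
        for _ in range(repeats):              # best-of-N to reduce timing noise
            start = time.perf_counter()
            run(cfg)
            best = min(best, time.perf_counter() - start)
        timings[cfg] = best
    return min(timings, key=timings.get), timings


def run_vector_add(tile, n=50_000):
    """Hypothetical tiled vector add: one 'program' per tile, with a boundary check."""
    x = list(range(n))
    y = list(range(n))
    z = [0] * n
    for start in range(0, n, tile):
        end = min(start + tile, n)            # the last tile may be partial
        z[start:end] = [a + b for a, b in zip(x[start:end], y[start:end])]
    return z


best_tile, timings = autotune(run_vector_add, configs=[256, 512, 1024])
print(best_tile, timings)
```

Keeping the policy this simple matches the slide's point that Triton does not want complex autotuning logic built in. Triton's real autotuner additionally caches the winner for each distinct value of the arguments named in `key`, so the search runs once per problem size rather than on every launch.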

# **Optimizing GPU Kernels**

#### **NVIDIA GA100 Architecture & Programming Challenges**

- → Multiple compute units
- → Multiple memory spaces
- → Multiple data types
- → Thread synchronization/divergence
- → Tensor cores



#### **Techniques for Optimizing a GEMM Kernel**

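| Implementation                | Share of peak       | Language             | Optimizations added |
|-------------------------------|---------------------|----------------------|---------------------|
| Vanilla                       | 1%-10% fp32 peak    |                      |                     |
| NVIDIA CUDA Programming Guide | 30%-50% fp32 peak   | C++/C                | + global memory coalescing<br>+ shared memory |
| CUTLASS                       | 80%-90% tf32 peak   | C++ templates & PTX  | + vectorization<br>+ shared memory bank conflict reduction<br>+ thread layout autotuning<br>+ async shared memory transfer<br>+ multi-stage shared memory<br>+ tf32 tensor cores |
| cuBLAS                        | ~90% tf32 peak      | SASS                 | + register bank conflict reduction<br>+ control code optimization |

Implementation difficulty increases from top to bottom.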
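Much of the gap between the vanilla kernel and the tiled ones comes from data reuse: staging tiles of A and B in shared memory means each value fetched from global memory feeds many multiply-adds. The Python sketch below counts global-memory loads under a simplified model (no caches, tile sizes that divide the matrix dimensions); the function names are hypothetical.

```python
def global_loads_naive(M, N, K):
    """Every C[i, j] reads a full row of A and a full column of B from global memory."""
    return 2 * M * N * K


def global_loads_tiled(M, N, K, BM, BN):
    """Each BM x BN tile of C streams a BM x K panel of A and a K x BN panel of B once."""
    assert M % BM == 0 and N % BN == 0, "simplified model: tiles must divide M and N"
    num_tiles = (M // BM) * (N // BN)
    return num_tiles * (BM * K + K * BN)


naive = global_loads_naive(1024, 1024, 1024)
tiled = global_loads_tiled(1024, 1024, 1024, BM=128, BN=128)
print(naive // tiled)  # 128
```

In this model the reduction factor is 2 / (1/BM + 1/BN), so larger output tiles cut global traffic further, limited in practice by shared-memory and register capacity.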

### **Utilizing Tensor Cores - Layout**

- → For each warp, values must be loaded into tiles with a specific layout to perform matrix multiplications
  - Each data type can have multiple layouts
  - Different data types (e.g., fp16 vs fp64) have different layouts



## **Utilizing Tensor Cores - Memory Swizzling**

- → Swizzle tiles (T) when loading them from global memory into shared memory to avoid bank conflicts
- → Simple padding does not work because we need to read multiple tiles on different rows

Without swizzling:

| Phase   |    |    |    |    |    |    |
|---------|----|----|----|----|----|----|
| Phase 0 | T0 | T1 | T2 | T3 | T4 | T5 |
| Phase 1 | T0 | T1 | T2 | T3 | T4 | T5 |
| Phase 2 | T0 | T1 | T2 | T3 | T4 | T5 |

With swizzling:

| Phase   |      |    |    |    |    |    |
|---------|------|----|----|----|----|----|
| Phase 0 | T0   | T1 | T2 | T3 | T4 | T5 |
| Phase 1 | T1   | T0 | T3 | T2 | T5 | T4 |
| Phase 2 | Tn-1 | T3 | T0 | T5 | T2 | T7 |
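Why XOR-based swizzling helps can be checked with a little arithmetic. The sketch below assumes 32 four-byte banks (so a 128-byte row of shared memory covers every bank once), 16-byte chunks (eight per row), and the common `chunk ^ row` swizzle; it illustrates the principle and is not necessarily the exact layout Triton emits. `physical_chunk` and `conflict_degree` are hypothetical names.

```python
# Assumed geometry: 32 banks x 4 bytes, so one 128-byte shared-memory row spans all
# banks and holds eight 16-byte chunks; a chunk's position in the row is its bank group.
CHUNKS = 8


def physical_chunk(row, col, swizzle):
    """XOR swizzle: permute the chunk position inside a row by the row index."""
    return col ^ (row % CHUNKS) if swizzle else col


def conflict_degree(col, rows=8, swizzle=False):
    """Read logical chunk `col` from 8 consecutive rows; how many reads share a bank group?"""
    groups = [physical_chunk(r, col, swizzle) for r in range(rows)]
    return max(groups.count(g) for g in groups)


print(conflict_degree(0), conflict_degree(0, swizzle=True))  # 8 1
```

Reading one logical column of chunks from eight rows hits the same banks eight times without swizzling (an 8-way conflict) and distinct banks with it, while each row is still just a permutation of its own chunks, so no shared memory is spent on padding.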

## **Utilizing Tensor Cores - ldmatrix & stmatrix**

- → Each thread provides a pointer to a 128-bit row of data in shared memory
- → Each row is split across four threads to match the arrangement expected by tensor cores

Fragment layout after `ldmatrix` for an 8x8 tile of 16-bit values (each lane receives two adjacent values in its register `d0`):

|      | col0-1  | col2-3  | col4-5  | col6-7  |
|------|---------|---------|---------|---------|
| row0 | lane 0  | lane 1  | lane 2  | lane 3  |
| row1 | lane 4  | lane 5  | lane 6  | lane 7  |
| row2 | lane 8  | lane 9  | lane 10 | lane 11 |
| row3 | lane 12 | lane 13 | lane 14 | lane 15 |
| row4 | lane 16 | lane 17 | lane 18 | lane 19 |
| row5 | lane 20 | lane 21 | lane 22 | lane 23 |
| row6 | lane 24 | lane 25 | lane 26 | lane 27 |
| row7 | lane 28 | lane 29 | lane 30 | lane 31 |
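The fragment layout follows a simple rule that can be written down in code. Assuming a single 8x8 tile of 16-bit values (`ldmatrix` with `.x1`), lane `L` receives row `L // 4`, columns `2 * (L % 4)` and `2 * (L % 4) + 1`, packed into its 32-bit register `d0`. The helper names below are hypothetical.

```python
def owner_lane(row, col):
    """Lane whose register d0 receives element (row, col) of the 8x8 tile."""
    return 4 * row + col // 2


def lane_elements(lane):
    """The two adjacent 16-bit elements (row, col) that land in `lane`'s d0."""
    row, pair = divmod(lane, 4)
    return [(row, 2 * pair), (row, 2 * pair + 1)]


print(lane_elements(5))  # [(1, 2), (1, 3)]
```

Every 128-bit row therefore ends up spread over exactly four consecutive lanes, and the 32 lanes together cover all 64 elements once.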




### **Element-wise Operators**

- → Triton and Torch both achieve peak bandwidth
- → Researchers can write fused element-wise operators easily using Triton
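Element-wise operators are bound by memory bandwidth, so fusion pays off by cutting global-memory traffic. A simplified Python model (unary ops, no caching; `bytes_moved` is a hypothetical helper) makes the arithmetic explicit:

```python
def bytes_moved(n, num_ops, elem_bytes=4, fused=False):
    """Global-memory bytes for a chain of `num_ops` unary element-wise ops on n elements.

    Unfused: every op reads its input and writes its output.
    Fused: one kernel reads the input once and writes the final result once,
    keeping intermediates in registers.
    """
    kernels = 1 if fused else num_ops
    return kernels * 2 * n * elem_bytes


n = 1 << 20  # 1M float32 elements
print(bytes_moved(n, 3), bytes_moved(n, 3, fused=True))  # 25165824 8388608
```

Under this model, fusing three float32 element-wise ops moves a third of the bytes, so a bandwidth-bound fused kernel can be expected to run up to roughly three times faster than the unfused sequence.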



#### **Fused Softmax**

- → Triton kernels can keep data on-chip throughout the entire softmax
- → PyTorch JIT could, in theory, do that, but in practice does not
- → The native PyTorch op is designed to work for every input shape and is slower in the cases we care about
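The Python below sketches what one program of a row-wise fused softmax computes for a single row (names are hypothetical, and plain lists stand in for on-chip tensors). The row is read once; max, exponentials, sum and normalization all operate on values already held on-chip; the result is written once. A naive composition of separate max, subtract, exp, sum and divide operators would instead write and re-read intermediates between steps.

```python
import math


def softmax_row(row):
    """What one program of a row-wise fused softmax computes for a single row."""
    m = max(row)                             # subtract the row max for numerical stability
    exps = [math.exp(v - m) for v in row]    # exponentials stay "on-chip"
    s = sum(exps)
    return [e / s for e in exps]             # normalized row, written out once


def softmax(matrix):
    """One emulated program per row."""
    return [softmax_row(r) for r in matrix]


out = softmax([[1.0, 2.0, 3.0], [1000.0, 1000.0, 1000.0]])
```

Without the max subtraction, `math.exp(1000.0)` on the second row would raise `OverflowError`.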



## **Matrix Multiplication**

- → It takes <25 lines of code to write a Triton kernel on par with cuBLAS
- → Arbitrary ops can be "fused" before/after the GEMM while the data is still on-chip, leading to large speedups over PyTorch
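The fusion pattern can be sketched in plain Python: a blocked matrix multiplication where each output tile accumulates locally, and an epilogue (ReLU here, chosen only as an example) is applied before the single write to C. This is a CPU illustration with hypothetical names and toy tile sizes, not the Triton matmul kernel itself.

```python
def matmul_relu_tiled(a, b, bm=2, bn=2, bk=2):
    """Blocked relu(A @ B) on nested lists; each (i0, j0) tile plays one program."""
    M, K, N = len(a), len(b), len(b[0])
    c = [[0.0] * N for _ in range(M)]
    for i0 in range(0, M, bm):                  # program id along M
        for j0 in range(0, N, bn):              # program id along N
            rows = range(i0, min(i0 + bm, M))   # boundary check for partial tiles
            cols = range(j0, min(j0 + bn, N))
            acc = {(i, j): 0.0 for i in rows for j in cols}  # accumulator kept local
            for k0 in range(0, K, bk):          # march over K one chunk at a time
                for i in rows:
                    for j in cols:
                        for k in range(k0, min(k0 + bk, K)):
                            acc[i, j] += a[i][k] * b[k][j]
            for (i, j), v in acc.items():       # fused epilogue (ReLU), then one store
                c[i][j] = max(v, 0.0)
    return c


a = [[1.0, -2.0, 3.0], [0.0, 1.0, -1.0], [2.0, 2.0, 2.0]]
b = [[1.0, 0.0, -1.0], [2.0, 1.0, 0.0], [-1.0, 1.0, 1.0]]
c = matmul_relu_tiled(a, b)
```

Because the activation is applied to the accumulator tile, the unfused version's extra round trip (write A @ B, read it back, write relu of it) disappears.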



#### **Fused Attention (Flash Attention)**

- → From the author: Triton is easier to understand and experiment with than CUDA
- → Triton forward + backward is slightly slower than CUDA forward + backward
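The trick that lets fused attention keep everything on-chip is the online softmax: scores for one query are processed block by block while a running maximum, denominator and weighted sum are rescaled as new blocks arrive. The pure-Python sketch below (hypothetical names; one query, no 1/√d scaling, no masking) compares the blockwise result with the direct formula. It illustrates the technique from the FlashAttention work rather than reproducing any Triton kernel.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def attention_direct(q, keys, values):
    """Softmax-weighted sum of values, materializing all scores at once."""
    scores = [dot(q, k) for k in keys]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    total = sum(w)
    return [sum(wi * v[j] for wi, v in zip(w, values)) / total
            for j in range(len(values[0]))]


def attention_online(q, keys, values, block=2):
    """Same result, visiting keys/values one block at a time (online softmax)."""
    m, l = float("-inf"), 0.0               # running max and running denominator
    acc = [0.0] * len(values[0])            # running unnormalized output
    for start in range(0, len(keys), block):
        scores = [dot(q, k) for k in keys[start:start + block]]
        m_new = max(m, max(scores))
        scale = math.exp(m - m_new)         # rescale earlier partial sums (0.0 on the first block)
        l *= scale
        acc = [a * scale for a in acc]
        for s, v in zip(scores, values[start:start + block]):
            p = math.exp(s - m_new)
            l += p
            acc = [a + p * vj for a, vj in zip(acc, v)]
        m = m_new
    return [a / l for a in acc]


q = [1.0, 0.5]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 2.0], [0.5, 0.5]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
reference = attention_direct(q, keys, values)
blockwise = attention_online(q, keys, values, block=2)
```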





#### Kernl

- → Run PyTorch transformer models several times faster on GPU with a single line of code
- → The first OSS inference engine written in Triton





### **New Challenges With Hopper**

- Tensor Memory Accelerator (TMA)
  - Transfer large blocks of data between global memory and shared memory
- Distributed Shared Memory
  - Direct communication between shared memory on different SMs
- Thread Block Cluster
  - Grid -> Cluster -> Block -> Warp
- FP8 Data Types and Mode (Transformer Engine)
  - Native FP8 tensor cores

# Triton-MLIR (Triton V2)

#### **Goals**

- → Make Triton more robust
- → Use existing infrastructure to avoid reinventing the wheel
- → Support more backends

#### **Features**

- → MLIR (Multi-level intermediate representation)
  - Triton dialect
  - TritonGPU dialect
- → Clean layout concepts
  - Distributed, Sliced, Blocked, Shared, DotOperand
  - Adopted by CUTLASS (CuTe)
- → Low overhead runtime
  - Cache and fetch kernels using efficient signatures
- → Debugging
  - triton.language.print
- → Profiler interface
  - Kernel launch hooks
  - Compilation hooks

## **Hierarchical Design**



## Multiple Frontends and Backends (In Progress)



## Contributors

Anthropic

Da Yan

Meta

Shintaro Iwasaki

Microsoft

Ian Bearman

NVIDIA

Dongdong Li, Qingyi Liu, Chunwei Yan, Jun Yang, Chenggang Zhao, Ben Zhang, Goostavz Zhu

## **Takeaways**

- → Triton is designed to achieve both high performance and flexibility
- → Triton V2 will be more robust than Triton V1
- → Triton will soon support backends other than NVIDIA GPUs

## Thank You

Visit openai.com for more information.