1659 words
8 minutes
GPU架构与Triton算子编写

🚀 GPU 架构与 Triton 算子开发速查手册#

一、 Triton 编程模型:SPMD 与 pid#

Triton 采用 SPMD (Single Program, Multiple Data) 模型。开发者编写的是宏观的 Block (块) 级别代码,屏蔽了底层 Thread (线程) 的繁琐同步。

1. 核心抽象对应关系#

Triton 概念CUDA / 硬件概念调度位置说明
Grid (网格)Grid整个 GPU算子启动的总任务池,包含成百上千个 Program。
Program (pid)Thread Block (线程块)单个 SMTriton 编程的第一视角。一个 Program 会被整体分配到一个 SM 上执行。注意一个 SM 可以装进多个 Program。
(被隐藏)Warp (线程束)SM 内部32个物理线程,SM 调度和延迟隐藏的最小单位。
(被隐藏)Thread (线程)ALU (计算核心)真正执行标量计算的微观单元。

2. 算子基本骨架 (pid, offsets, mask)#

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
# 1. 拿到当前 Block 的编号 (相当于发给当前 SM 的工单号)
pid = tl.program_id(axis=0)
# 2. 计算当前 Block 需要处理的数据全局索引 (Offsets)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
# 3. 内存越界保护 (Mask) - 极其重要!防止最后一个 Block 读写溢出
mask = offsets < n_elements
# 4. 加载、计算、存储
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
tl.store(out_ptr + offsets, x + y, mask=mask)

二、 Pytorch 到 GPU 的工程发射塔#

在宿主机 (CPU) 调用 GPU 算子时,需要注意三个核心工程机制:

1. 边界对齐:triton.cdiv#

  • 用途:计算需要启动多少个 Block (Grid 的大小)。
  • 逻辑triton.cdiv(n_elements, BLOCK_SIZE) 等价于向上取整除法 N/B\lceil N / B \rceil。必须配合 Kernel 内的 mask 使用,确保尾部不足一个 BLOCK_SIZE 的数据也能被安全处理。

2. 编译期常量:tl.constexpr#

  • 用途:标记像 BLOCK_SIZE 这样决定了底层 SRAM (共享内存) 和寄存器分配大小的变量。
  • 机制:Triton JIT 编译器必须在运行前知道这些变量的确切值,从而为特定的 BLOCK_SIZE 编译出极致优化的机器码。如果传入不同的值,Triton 会重新触发编译。

3. CPU/GPU 异步执行 (Asynchronous Execution)#

  • 机制add_kernel[grid](...) 只是 CPU 向 GPU 的 Command Queue 里扔了一个任务工单,CPU 会瞬间返回 (耗时几微秒)。此时返回的 output Tensor 往往还没有被填入计算结果。
  • 同步约束:如果需要评估算子真实耗时,必须在前后加上 torch.cuda.synchronize() 强行让 CPU 等待 GPU 清空计算队列。

三、 GPU 内存墙与核心硬件指标 (关键数值 📊)#

所有的算子优化,本质上都是在进行“如何把数据从极慢的 DRAM 搬到极快的 SRAM”的微操。

1. 内存层级速查表#

内存层级物理位置作用域/共享范围 (🎯重点)容量量级 (典型值)访问延迟 (典型值)
DRAM (全局显存)GPU 板载全局所有 SM 共享几十 ~ 上百 GB (如A100 80GB)🔴 400 ~ 800 周期
L2 CacheGPU 芯片内全局所有 SM 共享几十 MB (如A100 40MB)🟡 ~ 100 ~ 200 周期
SRAM (Shared Memory)SM 车间内部仅限同一个 Block 内部的所有 Thread 共享 (Block 间绝对隔离)~ 100~200 KB / SM
(A100: 164KB/SM)
🟢 ~ 20 ~ 30 周期
Registers (寄存器)计算核心旁仅限单个 Thread 私有~ 256 KB / SM⚡ ~ 1 ~ 4 周期

2. SRAM 的黄金法则 (Shared Memory)#

  • 作用域铁律:SRAM 是Block 级的资产。tl.load 的本质就是将 DRAM 中的数据协同搬运到了当前 Block 专属的 SRAM 工作台上 (即 Tiling 切块机制)。
  • 容量红线:一个 SM 的 SRAM 通常不超过 200KB。如果你定义的 BLOCK_SIZE 过大,导致单 Block 需要的 SRAM 超过了 SM 总容量,程序会直接崩溃 (OOM) 或发生寄存器溢出 (Spilling),导致性能暴跌。

四、 延迟隐藏 (Latency Hiding) 与调度机制 (🎯核心精髓)#

GPU 不是靠“缓存 (Cache)”来解决 DRAM 读取慢的问题,而是靠海量任务的零开销瞬间切换来掩盖等待时间。

1. Block 调度器 (宏观):提供高 Occupancy (占用率)#

  • 目标:尽可能多地向同一个 SM 里塞入 Block。
  • 机制:只要 SM 的 SRAM 和寄存器没被填满,全局调度器就会把多个 Block 同时派发给同一个 SM (驻留/Co-residency)。这为微观的 Warp 调度提供了充足的“弹药”。

2. Warp 调度器 (微观):在【整个 SM 尺度】下执行切换#

  • 目标:让 ALU (计算核心) 永远不闲着。
  • 机制 (极度重要):Warp 调度器无视 Block 的边界。它在一个拥有所有驻留 Block 的“Warp 大池子”里巡逻。
    • 如果 Block A 的所有 Warp 都在等显存读取 (卡死)。
    • 调度器会瞬间 (0 周期) 把 ALU 交给 Block B 或 Block C 里数据已经就绪的 Warp 去做数学计算。
  • 推论:为了实现完美的延迟隐藏,我们不需要把单个 Block 搞得无限大;我们只需要让 Block 大小适中,确保一个 SM 能同时吞下多个 Block,由 Warp 调度器在跨 Block 的尺度上缝合等待时间。

五、 工程实践:如何寻找最优 BLOCK_SIZE#

不要手动硬猜,直接使用 Triton 的自动调优 (Auto-Tuning) 装饰器。

@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE': 64}, num_warps=2),
triton.Config({'BLOCK_SIZE': 128}, num_warps=4),
triton.Config({'BLOCK_SIZE': 256}, num_warps=4),
triton.Config({'BLOCK_SIZE': 512}, num_warps=8),
triton.Config({'BLOCK_SIZE': 1024}, num_warps=8),
],
key=['n_elements'], # 输入规模改变时触发重新基准测试
)
@triton.jit
def add_kernel(...):
...

调优博弈论

  • BLOCK_SIZE 太小 ❌:无法发挥显存合并访问 (Memory Coalescing) 的高带宽优势,且 Grid 调度开销大。
  • BLOCK_SIZE 太大 ❌:单 Block 榨干 SRAM,SM 只能装下 1 个 Block,导致 Warp 池子枯竭,Warp 调度器无法施展“延迟隐藏”,计算核心大量时间处于空转等待。
  • autotune 寻找的是 ✅:在“单次搬砖效率”和“车间高占用率 (Occupancy)”之间的黄金平衡点。

六、 进阶学习资源路线 (DL 开发者向)#

在掌握上述底层硬件直觉后,如果需要进一步提升算子编写能力,建议查阅以下资源:

  1. 理论心智模型建立Making Deep Learning Go Brrrr From First Principles (Horace He) —— 透彻理解 Roofline Model 和算术强度 (Arithmetic Intensity)。
  2. GPU 架构科普Which GPU(s) to Get for Deep Learning (Tim Dettmers) —— 以选购显卡为引,深度拆解 Tensor Cores 和 Memory Bandwidth 的重要性。
  3. 实操视频课程CUDA MODE (YouTube/GitHub) —— 为 PyTorch 开发者量身定制的从 Python 走向底层的硬核教程。
  4. 实战刷题Triton Puzzles (Sasha Rush) —— 用填空题的方式让你亲手手撕 SRAM Tiling 和 pid 偏移量计算。