About this image ...
1659 words
8 minutes
GPU架构与Triton算子编写
🚀 GPU 架构与 Triton 算子开发速查手册
一、 Triton 编程模型:SPMD 与 pid
Triton 采用 SPMD (Single Program, Multiple Data) 模型。开发者编写的是宏观的 Block (块) 级别代码,屏蔽了底层 Thread (线程) 的繁琐同步。
1. 核心抽象对应关系
| Triton 概念 | CUDA / 硬件概念 | 调度位置 | 说明 |
|---|---|---|---|
| Grid (网格) | Grid | 整个 GPU | 算子启动的总任务池,包含成百上千个 Program。 |
Program (pid) | Thread Block (线程块) | 单个 SM | Triton 编程的第一视角。一个 Program 会被整体分配到一个 SM 上执行。注意一个 SM 可以装进多个 Program。 |
| (被隐藏) | Warp (线程束) | SM 内部 | 32个物理线程,SM 调度和延迟隐藏的最小单位。 |
| (被隐藏) | Thread (线程) | ALU (计算核心) | 真正执行标量计算的微观单元。 |
2. 算子基本骨架 (pid, offsets, mask)
@triton.jitdef add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr): # 1. 拿到当前 Block 的编号 (相当于发给当前 SM 的工单号) pid = tl.program_id(axis=0)
# 2. 计算当前 Block 需要处理的数据全局索引 (Offsets) block_start = pid * BLOCK_SIZE offsets = block_start + tl.arange(0, BLOCK_SIZE)
# 3. 内存越界保护 (Mask) - 极其重要!防止最后一个 Block 读写溢出 mask = offsets < n_elements
# 4. 加载、计算、存储 x = tl.load(x_ptr + offsets, mask=mask) y = tl.load(y_ptr + offsets, mask=mask) tl.store(out_ptr + offsets, x + y, mask=mask)二、 Pytorch 到 GPU 的工程发射塔
在宿主机 (CPU) 调用 GPU 算子时,需要注意三个核心工程机制:
1. 边界对齐:triton.cdiv
- 用途:计算需要启动多少个 Block (Grid 的大小)。
- 逻辑:
triton.cdiv(n_elements, BLOCK_SIZE)等价于向上取整除法 。必须配合 Kernel 内的mask使用,确保尾部不足一个BLOCK_SIZE的数据也能被安全处理。
2. 编译期常量:tl.constexpr
- 用途:标记像
BLOCK_SIZE这样决定了底层 SRAM (共享内存) 和寄存器分配大小的变量。 - 机制:Triton JIT 编译器必须在运行前知道这些变量的确切值,从而为特定的
BLOCK_SIZE编译出极致优化的机器码。如果传入不同的值,Triton 会重新触发编译。
3. CPU/GPU 异步执行 (Asynchronous Execution)
- 机制:
add_kernel[grid](...)只是 CPU 向 GPU 的 Command Queue 里扔了一个任务工单,CPU 会瞬间返回 (耗时几微秒)。此时返回的outputTensor 往往还没有被填入计算结果。 - 同步约束:如果需要评估算子真实耗时,必须在前后加上
torch.cuda.synchronize()强行让 CPU 等待 GPU 清空计算队列。
三、 GPU 内存墙与核心硬件指标 (关键数值 📊)
所有的算子优化,本质上都是在进行“如何把数据从极慢的 DRAM 搬到极快的 SRAM”的微操。
1. 内存层级速查表
| 内存层级 | 物理位置 | 作用域/共享范围 (🎯重点) | 容量量级 (典型值) | 访问延迟 (典型值) |
|---|---|---|---|---|
| DRAM (全局显存) | GPU 板载 | 全局所有 SM 共享 | 几十 ~ 上百 GB (如A100 80GB) | 🔴 400 ~ 800 周期 |
| L2 Cache | GPU 芯片内 | 全局所有 SM 共享 | 几十 MB (如A100 40MB) | 🟡 ~ 100 ~ 200 周期 |
| SRAM (Shared Memory) | SM 车间内部 | 仅限同一个 Block 内部的所有 Thread 共享 (Block 间绝对隔离) | ~ 100~200 KB / SM (A100: 164KB/SM) | 🟢 ~ 20 ~ 30 周期 |
| Registers (寄存器) | 计算核心旁 | 仅限单个 Thread 私有 | ~ 256 KB / SM | ⚡ ~ 1 ~ 4 周期 |
2. SRAM 的黄金法则 (Shared Memory)
- 作用域铁律:SRAM 是Block 级的资产。
tl.load的本质就是将 DRAM 中的数据协同搬运到了当前 Block 专属的 SRAM 工作台上 (即 Tiling 切块机制)。 - 容量红线:一个 SM 的 SRAM 通常不超过 200KB。如果你定义的
BLOCK_SIZE过大,导致单 Block 需要的 SRAM 超过了 SM 总容量,程序会直接崩溃 (OOM) 或发生寄存器溢出 (Spilling),导致性能暴跌。
四、 延迟隐藏 (Latency Hiding) 与调度机制 (🎯核心精髓)
GPU 不是靠“缓存 (Cache)”来解决 DRAM 读取慢的问题,而是靠海量任务的零开销瞬间切换来掩盖等待时间。
1. Block 调度器 (宏观):提供高 Occupancy (占用率)
- 目标:尽可能多地向同一个 SM 里塞入 Block。
- 机制:只要 SM 的 SRAM 和寄存器没被填满,全局调度器就会把多个 Block 同时派发给同一个 SM (驻留/Co-residency)。这为微观的 Warp 调度提供了充足的“弹药”。
2. Warp 调度器 (微观):在【整个 SM 尺度】下执行切换
- 目标:让 ALU (计算核心) 永远不闲着。
- 机制 (极度重要):Warp 调度器无视 Block 的边界。它在一个拥有所有驻留 Block 的“Warp 大池子”里巡逻。
- 如果 Block A 的所有 Warp 都在等显存读取 (卡死)。
- 调度器会瞬间 (0 周期) 把 ALU 交给 Block B 或 Block C 里数据已经就绪的 Warp 去做数学计算。
- 推论:为了实现完美的延迟隐藏,我们不需要把单个 Block 搞得无限大;我们只需要让 Block 大小适中,确保一个 SM 能同时吞下多个 Block,由 Warp 调度器在跨 Block 的尺度上缝合等待时间。
五、 工程实践:如何寻找最优 BLOCK_SIZE?
不要手动硬猜,直接使用 Triton 的自动调优 (Auto-Tuning) 装饰器。
@triton.autotune( configs=[ triton.Config({'BLOCK_SIZE': 64}, num_warps=2), triton.Config({'BLOCK_SIZE': 128}, num_warps=4), triton.Config({'BLOCK_SIZE': 256}, num_warps=4), triton.Config({'BLOCK_SIZE': 512}, num_warps=8), triton.Config({'BLOCK_SIZE': 1024}, num_warps=8), ], key=['n_elements'], # 输入规模改变时触发重新基准测试)@triton.jitdef add_kernel(...): ...调优博弈论:
BLOCK_SIZE太小 ❌:无法发挥显存合并访问 (Memory Coalescing) 的高带宽优势,且 Grid 调度开销大。BLOCK_SIZE太大 ❌:单 Block 榨干 SRAM,SM 只能装下 1 个 Block,导致 Warp 池子枯竭,Warp 调度器无法施展“延迟隐藏”,计算核心大量时间处于空转等待。autotune寻找的是 ✅:在“单次搬砖效率”和“车间高占用率 (Occupancy)”之间的黄金平衡点。
六、 进阶学习资源路线 (DL 开发者向)
在掌握上述底层硬件直觉后,如果需要进一步提升算子编写能力,建议查阅以下资源:
- 理论心智模型建立:Making Deep Learning Go Brrrr From First Principles (Horace He) —— 透彻理解 Roofline Model 和算术强度 (Arithmetic Intensity)。
- GPU 架构科普:Which GPU(s) to Get for Deep Learning (Tim Dettmers) —— 以选购显卡为引,深度拆解 Tensor Cores 和 Memory Bandwidth 的重要性。
- 实操视频课程:CUDA MODE (YouTube/GitHub) —— 为 PyTorch 开发者量身定制的从 Python 走向底层的硬核教程。
- 实战刷题:Triton Puzzles (Sasha Rush) —— 用填空题的方式让你亲手手撕 SRAM Tiling 和
pid偏移量计算。