融合算子
融合算子本质上是:把计算图中多个连续的小算子合并成一个更大的算子或一个 GPU kernel,让它们一次性执行,减少中间结果写回显存、减少 kernel launch、减少 layout 转换,从而提高推理性能。
在 TensorRT、TVM、XLA、TorchInductor、ONNX Runtime、TensorRT-LLM 里,算子融合都是核心优化。
1. 为什么要融合算子
GPU 推理不只是“算力瓶颈”,很多时候是:
1 | 读显存 → 算一点 → 写显存 → 再读显存 → 再算一点 → 再写显存 |
例如:
1 | Conv → BatchNorm → ReLU |
如果不融合,大概是:
1 | kernel 1: Conv |
问题有两个:
第一,中间 tensor 反复读写显存。
第二,每个 kernel 都有 launch overhead。
如果融合成:
1 | FusedConvBNReLU |
就可以变成:
1 | kernel: Conv + BN + ReLU |
这通常比单纯减少 FLOPs 更重要,因为现代 GPU 的 Tensor Core 算力很强,很多小算子瓶颈在 memory bandwidth 和 kernel launch,而不是纯计算。
2. 融合算子的几种类型
融合不是一种东西,而是一组不同层级的优化。
2.1 Weight folding:权重折叠
最典型的是:
1 | Conv + BatchNorm |
推理阶段 BatchNorm 的均值、方差、scale、bias 都是固定的,所以可以提前合并进 Conv 的 weight 和 bias。
原始 Conv:
$$
y = W * x + b
$$
BatchNorm:
$$
z = \gamma \frac{y - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta
$$
代入 Conv 后:
$$
z =
\frac{\gamma}{\sqrt{\sigma^2 + \epsilon}} (W * x + b - \mu) + \beta
$$
可以重新定义:
$$
W’ = \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}} W
$$
$$
b’ = \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}}(b - \mu) + \beta
$$
于是:
$$
z = W’ * x + b’
$$
这样 BatchNorm 这个算子在推理图中可以直接消失。
这个融合的特点是:不需要新 kernel,只是改权重。
它是最稳定、最常见的融合之一。
2.2 Elementwise fusion:逐元素算子融合
例如:
1 | Add → ReLU |
逐元素算子通常计算很轻,但会产生大量显存读写。比如:
1 | y = x + bias |
不融合时:
1 | 读 x |
融合后:
1 | 读 x |
也就是把:
1 | z = gelu(x + bias) |
放到一个 kernel 里完成。
Transformer 里非常常见:
1 | GEMM → BiasAdd → GELU |
尤其是 MLP 部分:
1 | Linear → Bias → GELU → Linear |
常见融合是:
1 | Linear + Bias + GELU |
2.3 Epilogue fusion:矩阵乘法尾部融合
这个在 LLM 里很关键。
GEMM 本身通常由 cuBLASLt、CUTLASS 或 TensorRT tactic 执行。GEMM 计算完后,很多时候会立刻做:
1 | bias add |
所谓 epilogue fusion,就是把这些操作塞进 GEMM 的输出阶段。
例如普通路径:
1 | C = A @ B |
融合路径:
1 | E = GELU(A @ B + bias) |
底层不是先把 C 写回显存,再启动一个 kernel 做 bias/GELU,而是在 GEMM accumulator 还在寄存器里时直接完成后处理,然后只写最终结果。
这种融合对 LLM 特别重要,因为 Transformer 中大量时间花在 GEMM 和 GEMM 后面的轻量算子上。
2.4 Vertical fusion:纵向融合
纵向融合指把一条链上的连续算子融合:
1 | A → B → C → D |
融合成:
1 | Fused(A,B,C,D) |
典型例子:
1 | Conv → BatchNorm → ReLU |
它主要减少中间 tensor。
2.5 Horizontal fusion:横向融合
横向融合指多个并行的小算子合并。
例如 Transformer 的 Q、K、V 投影:
1 | Q = X Wq |
不融合时是三个 GEMM:
1 | MatMul(X, Wq) |
可以融合成一个大 GEMM:
1 | [Q, K, V] = X [Wq, Wk, Wv] |
也就是把三个权重矩阵在输出维度拼起来:
1 | Wqkv = concat(Wq, Wk, Wv) |
然后一次性算:
1 | QKV = X @ Wqkv |
这种融合会提高 GEMM 尺寸,提升 GPU 利用率,同时减少 kernel 数量。
LLM 里这很常见:
1 | Q projection |
融合成:
1 | QKV projection |
3. TensorRT 里常见的融合模式
TensorRT 会在构建 engine 时扫描计算图,识别固定 pattern,然后进行融合。
常见有:
1 | Conv + Bias + Activation |
对于 CNN:
1 | Conv → BN → ReLU |
对于 Transformer:
1 | QKV projection fusion |
对于 TensorRT-LLM,还会有更专门的融合:
1 | RMSNorm fusion |
4. 融合到底省了什么
可以用一个简单例子看。
假设有三个逐元素算子:
1 | y1 = x + a |
假设 tensor 有 $N$ 个元素。
不融合时大概显存访问:
1 | Add: |
忽略 cache,大约是:
1 | 读写次数 ≈ 8N |
融合后:
1 | y3 = relu((x + a) * b) |
显存访问变成:
1 | read x, read a, read b, write y3 |
大约:
1 | 读写次数 ≈ 4N |
中间的 y1、y2 不再落到显存。
在 GPU 上,很多 activation、normalization、elementwise 操作本身 FLOPs 很少,但是 tensor 很大,所以显存读写是主要开销。融合后性能提升通常来自减少 memory traffic,而不是减少数学运算。
5. Conv + BN + ReLU 的融合细节
这是最经典的融合。
训练图:
1 | Conv → BatchNorm → ReLU |
推理阶段可以分两步。
第一步,BN 折叠到 Conv 权重:
1 | Conv + BN → Conv' |
第二步,ReLU 作为 Conv 的 activation epilogue:
1 | Conv' + ReLU → FusedConvReLU |
最终:
1 | Conv → BN → ReLU |
变成:
1 | FusedConvReLU |
这种融合非常适合推理,因为 BatchNorm 的统计量固定。训练时通常不能这么简单融合,因为 BN 的 mean/variance 依赖当前 batch,并且要反向传播。
6. LayerNorm 为什么也适合融合
LayerNorm 的计算大概是:
$$
\mu = \frac{1}{H}\sum_{i=1}^{H} x_i
$$
$$
\sigma^2 = \frac{1}{H}\sum_{i=1}^{H}(x_i - \mu)^2
$$
$$
y_i = \gamma_i \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta_i
$$
朴素实现可能拆成很多 kernel:
1 | reduce mean |
融合后可以用一个或少数几个 kernel 完成:
1 | FusedLayerNorm |
好处是:
1 | 1. 减少多次读写 hidden states |
LLM 里每层都有 LayerNorm 或 RMSNorm,所以这类融合收益很明显。
RMSNorm 更简单:
$$
y_i = \gamma_i \frac{x_i}{\sqrt{\frac{1}{H}\sum_j x_j^2 + \epsilon}}
$$
因为没有减均值,计算更轻,更容易融合。
7. Attention fusion
标准 attention 逻辑大致是:
1 | Q = XWq |
不融合时会产生多个中间矩阵:
1 | Q, K, V |
问题是 attention score 的形状通常很大:
1 | batch × heads × seq_len × seq_len |
如果 seq_len 很长,中间矩阵显存开销巨大。
融合 attention 的目标是:
1 | 不把完整 S 和 P 长期写回显存 |
FlashAttention 这类思想就是典型代表:把 attention 分块,在 SRAM/shared memory/register 中完成局部计算,避免完整 materialize attention matrix。
TensorRT-LLM 的 attention fusion、paged KV cache attention、本质上也围绕这个方向优化:
1 | 减少 HBM 读写 |
8. 融合和量化的关系
量化模型里常见结构:
1 | Quantize → Dequantize → MatMul → Quantize → Dequantize → Add |
如果不优化,Q/DQ 节点会很多。
TensorRT 会尝试识别 Q/DQ pattern,把量化 scale 融入 kernel 中。例如:
1 | INT8 input |
可以融合成一个 quantized GEMM kernel:
1 | Dequant + MatMul + Scale + Activation + Quant |
也就是说,量化不是简单把 tensor 变小,还需要把 scale、zero point、dequant、requant 等操作融合进主算子,否则额外开销会吃掉收益。
9. 融合不是越多越好
有些融合会失败,或者不值得融合。
常见原因:
1 | 1. 中间结果被多个下游节点复用 |
例如:
1 | A → B → C |
如果 B 的输出同时给 C 和 D,那 A+B+C 的融合可能会破坏 D 的输入复用,TensorRT 需要判断是否值得。
还有一种情况是融合后 kernel 太复杂:
1 | 更多寄存器 |
所以编译器通常有 cost model 或 tactic benchmarking,不是机械地把所有东西都融合。
10. TensorRT 是怎么决定融合的
TensorRT 大概经历这些阶段:
1 | 1. 解析 ONNX / Network Definition |
例如图里有:
1 | Conv |
TensorRT 可以重写成:
1 | Convolution layer with fused scale and activation |
而对于 Transformer:
1 | MatMul |
可能变成:
1 | fused GEMM epilogue |
如果内置融合不支持,就可能需要 plugin。
11. Python backend 不会自动做这种融合
如果用 Triton Python backend 包 FireRedASR2-LLM,那么:
1 | Triton Python backend |
它本身不会把:
1 | Linear → Bias → GELU |
自动融合成 TensorRT kernel。
真正发生融合的地方可能是:
1 | PyTorch eager: 少量库内融合,比如 cuDNN/cuBLAS 内部 |
所以如果目标是融合算子加速,通常需要把模型或模型子图交给一个 compiler/runtime,而不是只放进 Python backend。
12. 对 FireRedASR2-LLM 的实际意义
的模型大概率包括:
1 | 音频预处理 |
可以分成几块看:
ASR encoder
如果是标准 Transformer/Conformer encoder,可能有:
1 | Conv/Subsampling |
可优化点:
1 | LayerNorm/RMSNorm fusion |
Qwen2-7B-Instruct
LLM 部分更适合 TensorRT-LLM 优化:
1 | RMSNorm fusion |
Python 预处理/后处理
这些部分不容易被 TensorRT 自动融合。比如:
1 | kaldiio 读 cmvn |
更适合做工程优化,例如缓存、批处理、异步流水线。
13. 如何判断融合是否有效
可以看三个指标:
1 | 1. kernel 数量是否减少 |
常用工具:
1 | trtexec verbose 日志 |
直观判断:
如果优化前看到很多小 kernel:
1 | add_kernel |
优化后变成少数 fused kernel:
1 | fused_bias_gelu |
一般说明融合生效了。
14. 一个简化例子
未融合:
1 | Input |
推理优化后:
1 | Input |
如果第二个 MatMul 后面还有 residual add:
1 | MatMul |
可能继续变成:
1 | FusedResidualLayerNorm |
最终 Transformer block 里的许多小算子会被压缩成几个大 kernel。
15. 总结
融合算子的核心不是“数学上少算了多少”,而是:
1 | 减少 kernel launch |
对于 CNN,典型是:
1 | Conv + BN + ReLU |
对于 Transformer/LLM,典型是:
1 | QKV projection fusion |
对当前 FireRedASR2-LLM 来说,Python backend 只是服务化;真正想靠融合提速,需要进一步考虑 torch.compile、ONNX/TensorRT、TensorRT-LLM,或者把 ASR encoder 和 Qwen2 LLM 部分拆出来分别优化。
