融合算子本质上是:把计算图中多个连续的小算子合并成一个更大的算子或一个 GPU kernel,让它们一次性执行,减少中间结果写回显存、减少 kernel launch、减少 layout 转换,从而提高推理性能。

在 TensorRT、TVM、XLA、TorchInductor、ONNX Runtime、TensorRT-LLM 里,算子融合都是核心优化。


1. 为什么要融合算子

GPU 推理不只是“算力瓶颈”,很多时候是:

1
读显存 → 算一点 → 写显存 → 再读显存 → 再算一点 → 再写显存

例如:

1
Conv → BatchNorm → ReLU

如果不融合,大概是:

1
2
3
4
5
6
7
8
9
10
11
12
13
kernel 1: Conv
input 从显存读入
weight 从显存读入
conv 输出写回显存

kernel 2: BatchNorm
conv 输出再从显存读入
BN 参数读入
BN 输出写回显存

kernel 3: ReLU
BN 输出再从显存读入
ReLU 输出写回显存

问题有两个:

第一,中间 tensor 反复读写显存。
第二,每个 kernel 都有 launch overhead。

如果融合成:

1
FusedConvBNReLU

就可以变成:

1
2
3
4
kernel: Conv + BN + ReLU
input/weight 读入
在寄存器或 shared memory 中完成 BN 和 ReLU
最终输出只写一次

这通常比单纯减少 FLOPs 更重要,因为现代 GPU 的 Tensor Core 算力很强,很多小算子瓶颈在 memory bandwidth 和 kernel launch,而不是纯计算。


2. 融合算子的几种类型

融合不是一种东西,而是一组不同层级的优化。

2.1 Weight folding:权重折叠

最典型的是:

1
Conv + BatchNorm

推理阶段 BatchNorm 的均值、方差、scale、bias 都是固定的,所以可以提前合并进 Conv 的 weight 和 bias。

原始 Conv:

$$
y = W * x + b
$$

BatchNorm:

$$
z = \gamma \frac{y - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta
$$

代入 Conv 后:

$$
z =
\frac{\gamma}{\sqrt{\sigma^2 + \epsilon}} (W * x + b - \mu) + \beta
$$

可以重新定义:

$$
W’ = \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}} W
$$

$$
b’ = \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}}(b - \mu) + \beta
$$

于是:

$$
z = W’ * x + b’
$$

这样 BatchNorm 这个算子在推理图中可以直接消失。

这个融合的特点是:不需要新 kernel,只是改权重
它是最稳定、最常见的融合之一。


2.2 Elementwise fusion:逐元素算子融合

例如:

1
2
3
4
Add → ReLU
Mul → Add → Sigmoid
BiasAdd → GELU
Residual Add → LayerNorm

逐元素算子通常计算很轻,但会产生大量显存读写。比如:

1
2
y = x + bias
z = gelu(y)

不融合时:

1
2
3
4
5
读 x
读 bias
写 y
读 y
写 z

融合后:

1
2
3
读 x
读 bias
直接写 z

也就是把:

1
z = gelu(x + bias)

放到一个 kernel 里完成。

Transformer 里非常常见:

1
2
3
GEMM → BiasAdd → GELU
GEMM → BiasAdd
ResidualAdd → LayerNorm

尤其是 MLP 部分:

1
Linear → Bias → GELU → Linear

常见融合是:

1
Linear + Bias + GELU

2.3 Epilogue fusion:矩阵乘法尾部融合

这个在 LLM 里很关键。

GEMM 本身通常由 cuBLASLt、CUTLASS 或 TensorRT tactic 执行。GEMM 计算完后,很多时候会立刻做:

1
2
3
4
5
bias add
activation
scaling
residual add
quantize / dequantize

所谓 epilogue fusion,就是把这些操作塞进 GEMM 的输出阶段。

例如普通路径:

1
2
3
C = A @ B
D = C + bias
E = GELU(D)

融合路径:

1
E = GELU(A @ B + bias)

底层不是先把 C 写回显存,再启动一个 kernel 做 bias/GELU,而是在 GEMM accumulator 还在寄存器里时直接完成后处理,然后只写最终结果。

这种融合对 LLM 特别重要,因为 Transformer 中大量时间花在 GEMM 和 GEMM 后面的轻量算子上。


2.4 Vertical fusion:纵向融合

纵向融合指把一条链上的连续算子融合:

1
A → B → C → D

融合成:

1
Fused(A,B,C,D)

典型例子:

1
2
3
Conv → BatchNorm → ReLU
MatMul → BiasAdd → GELU
Add → LayerNorm

它主要减少中间 tensor。


2.5 Horizontal fusion:横向融合

横向融合指多个并行的小算子合并。

例如 Transformer 的 Q、K、V 投影:

1
2
3
Q = X Wq
K = X Wk
V = X Wv

不融合时是三个 GEMM:

1
2
3
MatMul(X, Wq)
MatMul(X, Wk)
MatMul(X, Wv)

可以融合成一个大 GEMM:

1
[Q, K, V] = X [Wq, Wk, Wv]

也就是把三个权重矩阵在输出维度拼起来:

1
Wqkv = concat(Wq, Wk, Wv)

然后一次性算:

1
QKV = X @ Wqkv

这种融合会提高 GEMM 尺寸,提升 GPU 利用率,同时减少 kernel 数量。

LLM 里这很常见:

1
2
3
Q projection
K projection
V projection

融合成:

1
QKV projection

3. TensorRT 里常见的融合模式

TensorRT 会在构建 engine 时扫描计算图,识别固定 pattern,然后进行融合。

常见有:

1
2
3
4
5
6
7
8
9
10
Conv + Bias + Activation
Conv + BatchNorm + Activation
MatMul + Add
MatMul + Bias + GELU
Scale + Activation
ElementWise chain
Padding + Conv
Shuffle/Transpose 消除或合并
Quantize/Dequantize 融合
Attention 相关融合

对于 CNN:

1
2
3
Conv → BN → ReLU
Conv → Add → ReLU
DepthwiseConv → PointwiseConv 部分场景

对于 Transformer:

1
2
3
4
5
6
QKV projection fusion
Bias + GELU fusion
Residual + LayerNorm fusion
Attention fusion
GEMM epilogue fusion
Q/DQ quantization fusion

对于 TensorRT-LLM,还会有更专门的融合:

1
2
3
4
5
6
7
RMSNorm fusion
RoPE fusion
QKV GEMM fusion
Masked multi-head attention fusion
Paged KV cache attention
GEMM + activation fusion
MoE router/top-k/dispatch 相关融合

4. 融合到底省了什么

可以用一个简单例子看。

假设有三个逐元素算子:

1
2
3
y1 = x + a
y2 = y1 * b
y3 = relu(y2)

假设 tensor 有 $N$ 个元素。

不融合时大概显存访问:

1
2
3
4
5
6
7
8
Add:
read x, read a, write y1

Mul:
read y1, read b, write y2

ReLU:
read y2, write y3

忽略 cache,大约是:

1
读写次数 ≈ 8N

融合后:

1
y3 = relu((x + a) * b)

显存访问变成:

1
read x, read a, read b, write y3

大约:

1
读写次数 ≈ 4N

中间的 y1y2 不再落到显存。

在 GPU 上,很多 activation、normalization、elementwise 操作本身 FLOPs 很少,但是 tensor 很大,所以显存读写是主要开销。融合后性能提升通常来自减少 memory traffic,而不是减少数学运算。


5. Conv + BN + ReLU 的融合细节

这是最经典的融合。

训练图:

1
Conv → BatchNorm → ReLU

推理阶段可以分两步。

第一步,BN 折叠到 Conv 权重:

1
Conv + BN → Conv'

第二步,ReLU 作为 Conv 的 activation epilogue:

1
Conv' + ReLU → FusedConvReLU

最终:

1
Conv → BN → ReLU

变成:

1
FusedConvReLU

这种融合非常适合推理,因为 BatchNorm 的统计量固定。训练时通常不能这么简单融合,因为 BN 的 mean/variance 依赖当前 batch,并且要反向传播。


6. LayerNorm 为什么也适合融合

LayerNorm 的计算大概是:

$$
\mu = \frac{1}{H}\sum_{i=1}^{H} x_i
$$

$$
\sigma^2 = \frac{1}{H}\sum_{i=1}^{H}(x_i - \mu)^2
$$

$$
y_i = \gamma_i \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta_i
$$

朴素实现可能拆成很多 kernel:

1
2
3
4
5
6
7
reduce mean
reduce variance
sub
rsqrt
mul
scale
bias

融合后可以用一个或少数几个 kernel 完成:

1
FusedLayerNorm

好处是:

1
2
3
1. 减少多次读写 hidden states
2. 减少多个 reduction kernel
3. gamma/beta 应用可以和 normalization 合并

LLM 里每层都有 LayerNorm 或 RMSNorm,所以这类融合收益很明显。

RMSNorm 更简单:

$$
y_i = \gamma_i \frac{x_i}{\sqrt{\frac{1}{H}\sum_j x_j^2 + \epsilon}}
$$

因为没有减均值,计算更轻,更容易融合。


7. Attention fusion

标准 attention 逻辑大致是:

1
2
3
4
5
6
7
8
Q = XWq
K = XWk
V = XWv

S = QK^T / sqrt(d)
S = mask(S)
P = softmax(S)
O = PV

不融合时会产生多个中间矩阵:

1
2
3
4
5
Q, K, V
attention scores S
masked scores
softmax probabilities P
output O

问题是 attention score 的形状通常很大:

1
batch × heads × seq_len × seq_len

如果 seq_len 很长,中间矩阵显存开销巨大。

融合 attention 的目标是:

1
2
不把完整 S 和 P 长期写回显存
在 tile 级别完成 QK、mask、softmax、PV

FlashAttention 这类思想就是典型代表:把 attention 分块,在 SRAM/shared memory/register 中完成局部计算,避免完整 materialize attention matrix。

TensorRT-LLM 的 attention fusion、paged KV cache attention、本质上也围绕这个方向优化:

1
2
3
4
减少 HBM 读写
减少中间 attention matrix
优化 KV cache 访问
提升长序列解码吞吐

8. 融合和量化的关系

量化模型里常见结构:

1
Quantize → Dequantize → MatMul → Quantize → Dequantize → Add

如果不优化,Q/DQ 节点会很多。

TensorRT 会尝试识别 Q/DQ pattern,把量化 scale 融入 kernel 中。例如:

1
2
3
4
5
6
INT8 input
INT8 weight
INT32 accumulate
scale
activation
输出 FP16 或 INT8

可以融合成一个 quantized GEMM kernel:

1
Dequant + MatMul + Scale + Activation + Quant

也就是说,量化不是简单把 tensor 变小,还需要把 scale、zero point、dequant、requant 等操作融合进主算子,否则额外开销会吃掉收益。


9. 融合不是越多越好

有些融合会失败,或者不值得融合。

常见原因:

1
2
3
4
5
6
7
8
1. 中间结果被多个下游节点复用
2. shape 动态性太强
3. 算子之间 layout 不兼容
4. 精度约束不同,比如一个必须 FP32,一个可以 FP16
5. 算子太大,融合后 register pressure 过高
6. 融合后 occupancy 下降
7. plugin/backend 不支持该 pattern
8. 控制流或动态 Python 逻辑无法静态分析

例如:

1
2
A → B → C
↘ D

如果 B 的输出同时给 CD,那 A+B+C 的融合可能会破坏 D 的输入复用,TensorRT 需要判断是否值得。

还有一种情况是融合后 kernel 太复杂:

1
2
3
更多寄存器
更低 occupancy
更差 cache 行为

所以编译器通常有 cost model 或 tactic benchmarking,不是机械地把所有东西都融合。


10. TensorRT 是怎么决定融合的

TensorRT 大概经历这些阶段:

1
2
3
4
5
6
7
8
1. 解析 ONNX / Network Definition
2. 构建内部计算图
3. 常量折叠、shape 推导
4. 图 pattern matching
5. 图重写,把多个节点替换成 fused node
6. 为 fused node 选择 tactic/kernel
7. 根据 workspace、precision、shape profile 做 autotuning
8. 生成 engine

例如图里有:

1
2
3
4
5
Conv

BatchNormalization

Relu

TensorRT 可以重写成:

1
Convolution layer with fused scale and activation

而对于 Transformer:

1
2
3
4
5
MatMul

Add

GELU

可能变成:

1
fused GEMM epilogue

如果内置融合不支持,就可能需要 plugin。


11. Python backend 不会自动做这种融合

如果用 Triton Python backend 包 FireRedASR2-LLM,那么:

1
2
Triton Python backend
只是调用 Python/PyTorch 推理代码

它本身不会把:

1
Linear → Bias → GELU

自动融合成 TensorRT kernel。

真正发生融合的地方可能是:

1
2
3
4
5
6
PyTorch eager: 少量库内融合,比如 cuDNN/cuBLAS 内部
torch.compile: 可能做图捕获和 fusion
ONNX Runtime: 可能做图优化
TensorRT: build engine 时做 fusion
TensorRT-LLM: 专门做 LLM fusion
自定义 CUDA/CUTLASS kernel: 手动 fusion

所以如果目标是融合算子加速,通常需要把模型或模型子图交给一个 compiler/runtime,而不是只放进 Python backend。


12. 对 FireRedASR2-LLM 的实际意义

的模型大概率包括:

1
2
3
4
5
6
音频预处理
CMVN
ASR encoder
LLM decoder / Qwen2
tokenizer / decode
后处理

可以分成几块看:

ASR encoder

如果是标准 Transformer/Conformer encoder,可能有:

1
2
3
4
5
Conv/Subsampling
LayerNorm
Self-Attention
FFN
Residual

可优化点:

1
2
3
4
5
LayerNorm/RMSNorm fusion
QKV GEMM fusion
Attention fusion
Bias+GELU fusion
Conv+BN+Activation fusion

Qwen2-7B-Instruct

LLM 部分更适合 TensorRT-LLM 优化:

1
2
3
4
5
6
7
RMSNorm fusion
RoPE fusion
QKV fusion
GQA/MQA attention fusion
KV cache optimization
GEMM epilogue fusion
FP16/INT8/INT4 quantization

Python 预处理/后处理

这些部分不容易被 TensorRT 自动融合。比如:

1
2
3
4
kaldiio 读 cmvn
tokenizer
音频切分
字符串后处理

更适合做工程优化,例如缓存、批处理、异步流水线。


13. 如何判断融合是否有效

可以看三个指标:

1
2
3
1. kernel 数量是否减少
2. HBM read/write 是否减少
3. latency / throughput 是否改善

常用工具:

1
2
3
4
5
trtexec verbose 日志
TensorRT engine layer info
Nsight Systems 看 kernel launch 数量
Nsight Compute 看 memory throughput / occupancy
PyTorch profiler 看 eager 下算子数量

直观判断:

如果优化前看到很多小 kernel:

1
2
3
4
5
add_kernel
mul_kernel
relu_kernel
layernorm_kernel
transpose_kernel

优化后变成少数 fused kernel:

1
2
3
fused_bias_gelu
fused_layernorm
fused_attention

一般说明融合生效了。


14. 一个简化例子

未融合:

1
2
3
4
5
6
7
8
9
10
11
Input

MatMul

Add bias

GELU

Dropout / Identity in inference

MatMul

推理优化后:

1
2
3
4
5
Input

FusedMatMulBiasGELU

MatMul

如果第二个 MatMul 后面还有 residual add:

1
2
3
4
5
MatMul

Add residual

LayerNorm

可能继续变成:

1
FusedResidualLayerNorm

最终 Transformer block 里的许多小算子会被压缩成几个大 kernel。


15. 总结

融合算子的核心不是“数学上少算了多少”,而是:

1
2
3
4
5
6
7
减少 kernel launch
减少中间 tensor
减少 HBM 读写
提高 cache/register/shared memory 利用
把后处理塞进主算子 epilogue
把多个小 GEMM 合成大 GEMM
让 Tensor Core 更充分工作

对于 CNN,典型是:

1
Conv + BN + ReLU

对于 Transformer/LLM,典型是:

1
2
3
4
5
6
QKV projection fusion
GEMM + Bias + GELU
Residual + LayerNorm/RMSNorm
Attention fusion
RoPE fusion
Quant/Dequant fusion

对当前 FireRedASR2-LLM 来说,Python backend 只是服务化;真正想靠融合提速,需要进一步考虑 torch.compile、ONNX/TensorRT、TensorRT-LLM,或者把 ASR encoder 和 Qwen2 LLM 部分拆出来分别优化。