GEMM
GEMM 是 General Matrix Multiply,中文通常叫通用矩阵乘法。它是深度学习里最核心的计算形式之一。
它的标准形式是:
C = \alpha A B + \beta C
其中:
1 | A: m × k 矩阵 |
如果忽略 $\alpha$、$\beta$,最常见的理解就是:
1 | C = A @ B |
也就是矩阵乘矩阵。
1. GEMM 为什么重要
神经网络里的很多算子,最后都可以转成 GEMM。
例如全连接层 / Linear 层:
1 | y = x @ W + b |
如果:
1 | x: batch × hidden |
那么:
1 | y: batch × output |
这本质就是 GEMM,加上一个 bias。
Transformer 里的 Q、K、V 投影也是 GEMM:
1 | Q = X @ Wq |
MLP 也是 GEMM:
1 | H = GELU(X @ W1 + b1) |
Attention 里也有 GEMM:
1 | scores = Q @ K.transpose(-1, -2) |
所以 LLM 推理的大部分时间都花在矩阵乘法上。
2. GEMM 和 Linear 层的关系
PyTorch 里写:
1 | torch.nn.Linear(in_features, out_features) |
底层大致是:
1 | out = input @ weight.T + bias |
这就是 GEMM + bias add。
比如:
1 | input: [batch_size, hidden_size] |
实际矩阵乘法是:
1 | input @ weight.T |
也就是:
1 | [batch_size, hidden_size] × [hidden_size, out_features] |
在 LLM 里,这种计算极多。
3. GEMM 和卷积的关系
卷积也经常可以转成 GEMM。
比如 2D convolution:
1 | Conv2D(input, kernel) |
可以通过 im2col 把输入局部窗口展开成矩阵,然后把卷积核也展开成矩阵,最后做矩阵乘法。
也就是说:
1 | 卷积 ≈ im2col + GEMM |
当然现代 GPU 上不一定真的显式做 im2col,很多高性能卷积 kernel 会隐式完成这个过程,但核心思想仍然和矩阵乘法密切相关。
4. GEMM 为什么适合 GPU
矩阵乘法非常适合 GPU,因为它有三个特点:
1 | 1. 计算密集 |
例如计算:
1 | C = A @ B |
每个 C[i, j] 都是:
1 | A 的第 i 行 和 B 的第 j 列 的点积 |
不同的 C[i, j] 可以并行计算。GPU 可以把矩阵切成很多小块,也就是 tile,让不同线程块负责不同 tile。
在 NVIDIA GPU 上,FP16/BF16/INT8/FP8 GEMM 还可以用 Tensor Core 加速。这就是为什么 LLM 推理里经常强调:
1 | FP16 GEMM |
这些都是为了让矩阵乘法跑得更快。
5. GEMM 和 TensorRT 的关系
TensorRT 里很多 layer 最终都会映射到高性能 GEMM kernel。
例如:
1 | MatMul |
TensorRT 会做几件事:
1 | 1. 选择合适的 GEMM 实现 |
比如原始图是:
1 | MatMul → Add Bias → GELU |
TensorRT 可能变成:
1 | Fused GEMM + Bias + GELU |
也就是 GEMM 算完后,不把中间结果先写回显存,而是在 GEMM 的输出阶段直接做 bias 和 GELU。
这就是之前说的 GEMM epilogue fusion。
6. 一个具体例子
假设有一个 Linear 层:
1 | x = torch.randn(8, 4096) |
这里:
1 | A = x: 8 × 4096 |
这就是一个 GEMM。
它的计算量大约是:
1 | 2 × m × n × k |
为什么是乘 2?因为矩阵乘法里每一步通常有一次乘法和一次加法:
1 | a * b + acc |
7. GEMM、GEMV、Batched GEMM 的区别
GEMM 是矩阵乘矩阵:
1 | [m, k] × [k, n] → [m, n] |
GEMV 是矩阵乘向量:
1 | [m, k] × [k] → [m] |
Batched GEMM 是一批矩阵乘法:
1 | [batch, m, k] × [batch, k, n] → [batch, m, n] |
LLM prefill 阶段通常 GEMM 很大,GPU 利用率较高。
LLM decode 阶段 batch 小、seq 每次只生成一个 token,很多计算更接近 GEMV 或小 GEMM,所以 GPU 利用率容易下降。这也是为什么 TensorRT-LLM、vLLM、continuous batching、paged attention 这些系统要特别优化 decode 阶段。
8. 一句话总结
GEMM 就是:
1 | 通用矩阵乘法 C = αAB + βC |
它是 Linear、Attention、MLP、部分 Conv 的底层核心计算。深度学习推理加速很大程度上就是在优化 GEMM:让它用 Tensor Core、更低精度、更好的 layout、更少的显存读写,以及把 bias、activation、quant/dequant 等操作融合进 GEMM 的输出阶段。
