GEMMGeneral Matrix Multiply,中文通常叫通用矩阵乘法。它是深度学习里最核心的计算形式之一。

它的标准形式是:

C = \alpha A B + \beta C

其中:

1
2
3
4
A: m × k 矩阵
B: k × n 矩阵
C: m × n 矩阵
α, β: 标量

如果忽略 $\alpha$、$\beta$,最常见的理解就是:

1
C = A @ B

也就是矩阵乘矩阵。


1. GEMM 为什么重要

神经网络里的很多算子,最后都可以转成 GEMM。

例如全连接层 / Linear 层:

1
y = x @ W + b

如果:

1
2
x: batch × hidden
W: hidden × output

那么:

1
y: batch × output

这本质就是 GEMM,加上一个 bias。

Transformer 里的 Q、K、V 投影也是 GEMM:

1
2
3
Q = X @ Wq
K = X @ Wk
V = X @ Wv

MLP 也是 GEMM:

1
2
H = GELU(X @ W1 + b1)
Y = H @ W2 + b2

Attention 里也有 GEMM:

1
2
scores = Q @ K.transpose(-1, -2)
output = softmax(scores) @ V

所以 LLM 推理的大部分时间都花在矩阵乘法上。


2. GEMM 和 Linear 层的关系

PyTorch 里写:

1
torch.nn.Linear(in_features, out_features)

底层大致是:

1
out = input @ weight.T + bias

这就是 GEMM + bias add。

比如:

1
2
3
input:  [batch_size, hidden_size]
weight: [out_features, hidden_size]
output: [batch_size, out_features]

实际矩阵乘法是:

1
input @ weight.T

也就是:

1
2
[batch_size, hidden_size] × [hidden_size, out_features]
= [batch_size, out_features]

在 LLM 里,这种计算极多。


3. GEMM 和卷积的关系

卷积也经常可以转成 GEMM。

比如 2D convolution:

1
Conv2D(input, kernel)

可以通过 im2col 把输入局部窗口展开成矩阵,然后把卷积核也展开成矩阵,最后做矩阵乘法。

也就是说:

1
卷积 ≈ im2col + GEMM

当然现代 GPU 上不一定真的显式做 im2col,很多高性能卷积 kernel 会隐式完成这个过程,但核心思想仍然和矩阵乘法密切相关。


4. GEMM 为什么适合 GPU

矩阵乘法非常适合 GPU,因为它有三个特点:

1
2
3
1. 计算密集
2. 数据访问模式规则
3. 可以切成很多 tile 并行计算

例如计算:

1
C = A @ B

每个 C[i, j] 都是:

1
A 的第 i 行 和 B 的第 j 列 的点积

不同的 C[i, j] 可以并行计算。GPU 可以把矩阵切成很多小块,也就是 tile,让不同线程块负责不同 tile。

在 NVIDIA GPU 上,FP16/BF16/INT8/FP8 GEMM 还可以用 Tensor Core 加速。这就是为什么 LLM 推理里经常强调:

1
2
3
4
FP16 GEMM
BF16 GEMM
INT8 GEMM
FP8 GEMM

这些都是为了让矩阵乘法跑得更快。


5. GEMM 和 TensorRT 的关系

TensorRT 里很多 layer 最终都会映射到高性能 GEMM kernel。

例如:

1
2
3
4
5
6
MatMul
FullyConnected
Linear
QKV projection
MLP projection
Attention score calculation

TensorRT 会做几件事:

1
2
3
4
5
1. 选择合适的 GEMM 实现
2. 选择数据 layout
3. 使用 Tensor Core
4. 融合 bias、activation、scale 等 epilogue
5. 针对具体 shape 选择 tactic

比如原始图是:

1
MatMul → Add Bias → GELU

TensorRT 可能变成:

1
Fused GEMM + Bias + GELU

也就是 GEMM 算完后,不把中间结果先写回显存,而是在 GEMM 的输出阶段直接做 bias 和 GELU。

这就是之前说的 GEMM epilogue fusion


6. 一个具体例子

假设有一个 Linear 层:

1
2
3
x = torch.randn(8, 4096)
w = torch.randn(4096, 11008)
y = x @ w

这里:

1
2
3
A = x: 8 × 4096
B = w: 4096 × 11008
C = y: 8 × 11008

这就是一个 GEMM。

它的计算量大约是:

1
2
3
2 × m × n × k
= 2 × 8 × 11008 × 4096
≈ 721 million FLOPs

为什么是乘 2?因为矩阵乘法里每一步通常有一次乘法和一次加法:

1
a * b + acc

7. GEMM、GEMV、Batched GEMM 的区别

GEMM 是矩阵乘矩阵:

1
[m, k] × [k, n] → [m, n]

GEMV 是矩阵乘向量:

1
[m, k] × [k] → [m]

Batched GEMM 是一批矩阵乘法:

1
[batch, m, k] × [batch, k, n] → [batch, m, n]

LLM prefill 阶段通常 GEMM 很大,GPU 利用率较高。
LLM decode 阶段 batch 小、seq 每次只生成一个 token,很多计算更接近 GEMV 或小 GEMM,所以 GPU 利用率容易下降。这也是为什么 TensorRT-LLM、vLLM、continuous batching、paged attention 这些系统要特别优化 decode 阶段。


8. 一句话总结

GEMM 就是:

1
通用矩阵乘法 C = αAB + βC

它是 Linear、Attention、MLP、部分 Conv 的底层核心计算。深度学习推理加速很大程度上就是在优化 GEMM:让它用 Tensor Core、更低精度、更好的 layout、更少的显存读写,以及把 bias、activation、quant/dequant 等操作融合进 GEMM 的输出阶段。