“融合乘加”(Fused Multiply-Add,简称 FMA)是计算机体系结构和数值计算中一项非常重要的技术。它指的是一条单一的硬件指令,能够同时执行一次乘法和一次加法运算,其数学形式为:
其中 a
, b
, c
是操作数(通常是浮点数),result
是结果。
一、关键特点和优势
-
更高的精度:
-
这是 FMA 最重要的优势。
-
在传统的分开执行乘法和加法时(
temp = a * b; result = temp + c;
),乘法a * b
的结果temp
会先被四舍五入到目标浮点格式(如单精度 float 或双精度 double),然后再将这个经过舍入的值加到c
上,结果再次被舍入。 -
FMA 指令的关键在于:它在执行
(a * b) + c
的过程中,只进行一次最终的舍入操作。 中间乘积a * b
在内部使用更高的精度(通常是目标精度的两倍)或范围来表示,然后再与c
相加,最后才将和舍入到目标精度。 -
结果: 这显著减少了舍入误差,使得计算结果更接近数学上的精确值
(a * b) + c
。这在科学计算、金融建模、图形渲染等对精度要求高的领域至关重要。
-
-
更高的性能:
-
减少指令数: FMA 将原本需要两条指令(乘法指令 + 加法指令)完成的操作合并成一条指令。
-
提高吞吐量: 现代处理器通常有专门执行 FMA 的硬件单元(执行流水线)。一条 FMA 指令通常可以在一个时钟周期内完成(或者具有与单条乘法或加法指令相似的延迟)。这比执行两条独立的指令快得多。
-
减少延迟: 虽然单条 FMA 指令的延迟可能略高于单条乘法或加法指令,但它仍然远低于两条指令(乘法+加法)的总延迟,尤其是考虑到指令调度、发射和结果等待的开销。
-
提高指令级并行性: 更少的指令意味着处理器可以更容易地同时执行更多其他指令,充分利用处理器的资源。
-
-
更高的能效:
-
执行一条指令通常比执行两条指令消耗的能量更少,尤其是在高性能计算环境中,FMA 能显著降低功耗。
-
二、应用场景
FMA 是现代数值计算和并行计算的基石,广泛应用于:
-
点积计算: 向量点积
sum(a_i * b_i)
的核心就是大量的乘加操作。 -
矩阵乘法: 矩阵乘法的每个元素计算本质上就是点积。
-
多项式求值: 使用霍纳法则求值时,每一步都是乘加操作。
-
信号处理: 滤波器(如 FIR、IIR)计算涉及大量乘积累加。
-
计算机图形学: 3D 变换(矩阵-向量乘法)、光照计算等。
-
物理模拟和科学计算: 求解微分方程、线性代数运算等。
-
机器学习: 神经网络训练和推理中的矩阵运算、卷积等核心计算极度依赖 FMA。
1.具体场景:
场景 | FMA 是否更高效? | 原因 |
---|---|---|
只做乘法 a * b | ❌ 否 | 普通乘法更轻便,FMA 多此一举 |
做乘加 a * b + c | ✅ 是 | FMA 合并操作更快更准 |
做大规模点积、卷积等 | ✅ 是 | FMA 能显著提升计算效率和精度 |
2.为什么大规模的点积、卷积会使用 FMA?
因为这两种运算的核心结构都包含了 “乘加” 的模式,即:
每一步都在重复执行:
这正是 FMA(Fused Multiply-Add) 最擅长的事情!
✅ 1、点积:典的 FMA 场景
点积计算公式是:
伪代码:
float result = 0.0;
for (int i = 0; i < n; ++i) {
result += a[i] * b[i];
}
每一步都在做:
result = result + (a[i] * b[i])
👉 完美匹配 FMA(a[i], b[i], result)
!
使用 FMA,可以让这个过程更快、更准,因为它:
-
减少了运算次数(一条指令代替两条)
-
减少了舍入误差(只舍入一次)
-
利用现代 CPU/GPU 的 SIMD(Single Instruction, Multiple Data) 并行 FMA 指令(如 AVX, NEON, CUDA)
✅ 2、卷积:也是连续的乘加
卷积(1D/2D)数学上也是这样:
伪代码(简化):
for (int i = 0; i < output_size; ++i) {
float sum = 0.0;
for (int j = 0; j < kernel_size; ++j) {
sum += input[i + j] * kernel[j];
}
output[i] = sum;
}
这和点积本质上一样,是一个滑动窗口点积运算 —— 每次都可以用 FMA 来优化。
3.CUDA中,我们需要显式调用FMA吗? 还是GPU会自动使用它?
通常不需要你显式调用 FMA,CUDA 编译器(NVCC)会自动使用它,只要你的代码里存在乘加模式,比如:
c += a * b;
只要:
-
数据是浮点数(特别是 FP16、TF32、FP32)
-
你的 GPU 支持 Tensor Core 或 FMA
-
编译器优化开启(默认就是)
✅ NVCC 会自动把它转成 fused multiply-add 指令,例如 fma.rn.f32
(即 fused multiply-add, round to nearest, float32)。
编译器自动使用FMA的例子:
给你一个简单的 kernel 函数:
__global__ void fma_test(float* a, float* b, float* c) {
int i = threadIdx.x;
c[i] = a[i] * b[i] + c[i];
}
编译后(nvcc -arch=sm_75 -ptx
),会生成类似这样的 PTX:
fma.rn.f32 %f3, %f1, %f2, %f3;
表示已启用 fused multiply-add。
(关于 什么是PTX ,见最后的附加。)
你也可以手动使用 CUDA 提供的 fma()
函数:
如果你想更明确地使用 FMA 指令,可以这样写:
__device__ float my_kernel(float a, float b, float c) {
return fmaf(a, b, c); // 调用的是 fused multiply-add(float 版本)
}
支持的类型:
-
fma()
:double -
fmaf()
:float -
fmahf()
:half
三、概括
硬件支持
-
CPU:
-
x86/x86-64: Intel Haswell (2013) 及以后的 CPU 支持 FMA3 指令集扩展。一些 AMD CPU 也支持 FMA3/FMA4。
-
ARM: ARMv8-A (AArch64) 及更高版本(如 Cortex-A53/A57/A72/A76/X1/X2, Neoverse 系列)普遍支持 FMA。ARM NEON (SIMD) 和 SVE/SVE2 指令集也包含 FMA 指令。
-
PowerPC: POWER ISA 很早就支持 FMA。
-
-
GPU:
-
现代 GPU (NVIDIA CUDA cores, AMD Stream Processors) 的 ALU 核心通常设计为在每个时钟周期执行一条 FMA 指令,这是其强大浮点性能的关键。
-
软件支持
-
编译器(如 GCC, Clang, MSVC)可以识别代码中的乘加模式(例如
a*b + c
)并自动生成 FMA 指令(需要开启相应的优化选项,如-mfma
,-march=native
等)。 -
程序员也可以使用编译器提供的特定函数(intrinsics)直接调用 FMA 指令(例如
_mm_fmadd_ps
,vfmaq_f32
等)。
补充说明:
-
在 C/C++ 中,FMA 可以通过
std::fma(a, b, c)
实现(需要<cmath>
)。 -
在底层汇编中,指令可能叫
FMA
,VFMA
,FMLA
(ARM)等。 -
在现代编译器中,开启优化(如
-O2
或-ffast-math
)时,编译器可能会自动将乘加转换为 FMA(前提是硬件支持)。
总结
Fused Multiply-Add (FMA) 是一项通过将乘法和加法操作合并到一条硬件指令中来实现的核心技术。它提供了显著更高的计算精度(通过减少舍入误差)和大幅提升的性能(通过减少指令数、提高吞吐量和并行度),同时还能提高能效。它是现代 CPU、GPU 和数值密集型应用(科学计算、AI、图形学等)高性能和高精度计算的基石。当你看到 (a * b) + c
这样的表达式在现代硬件上运行时,它很可能被编译成一条高效的 FMA 指令。
附加:
💡 PTX是什么?
Parallel Thread Execution,它是 NVIDIA 为 CUDA GPU 定义的一种中间汇编语言,介于你的 C++ CUDA 代码和底层 GPU 机器码之间。
🧩 类比理解:
阶段 | 类比于 CPU 的流程 | CUDA 中的术语 |
---|---|---|
高级语言 | C/C++、Python | CUDA C++ |
中间代码/汇编 | x86 汇编 | PTX |
机器码/指令集 | x86 指令(二进制) | SASS(GPU指令集) |
PTX 是 NVIDIA GPU 的虚拟 ISA(指令集架构),和 Java 的 bytecode 有点像。
🧰 PTX 有什么用?
-
NVCC 编译器会先生成 PTX
-
当你运行
nvcc
编译 CUDA 程序时,它先生成.ptx
文件(PTX 汇编) -
然后 GPU 驱动会把 PTX 转成适配具体硬件的 SASS(机器码)
-
-
跨架构兼容性
-
PTX 是跨架构的(类似 JVM 的 bytecode),同一个 PTX 可以在多个 GPU 上运行(由驱动再翻译)
-
可以延迟编译成最终 GPU 的指令(Just-In-Time 编译)
-
-
调试和性能分析
-
查看你的代码有没有被优化为
fma.rn.f32
等指令 -
检查内存访问是否合理
-
📄 示例:PTX 代码长什么样?
假设你写了这个 CUDA kernel:
__global__ void add(float* a, float* b, float* c) {
int i = threadIdx.x;
c[i] = a[i] + b[i];
}
使用命令编译出 PTX:
nvcc -arch=sm_75 -ptx add.cu
生成的 add.ptx
可能包含:
ld.global.f32 %f1, [%r1];
ld.global.f32 %f2, [%r2];
add.f32 %f3, %f1, %f2;
st.global.f32 [%r3], %f3;
这段 PTX 是浮点加载、加法和存储操作。
📦 PTX 和 SASS 的区别
类型 | 解释 | 谁生成它 |
---|---|---|
PTX | 虚拟汇编,跨 GPU 架构 | nvcc 生成 |
SASS | 真正的 GPU 指令集 | 驱动(JIT 编译)根据 PTX 生成 |
你可以用 cuobjdump --dump-sass
或 nvdisasm
查看 SASS。
🧠 总结一句话:
PTX 是 CUDA 程序编译后生成的虚拟汇编语言,它是连接 CUDA 代码和 GPU 硬件之间的桥梁。