LLM 大模型训练-推理显存占用分析
根据模型的参数大小,预估模型训练和推理过程中的显存占用情况,包括参数占用显存大小、优化器占用显存大小...KV Cache 和 中间激活值的计算方式
转载提示:本文转载自 bbruceyuan 原文。原作者:bbruceyuan。原文发布日期:2024-10-06。本文已按站点规范移除原文中的推广/导流内容,仅保留技术分析与示例。阅读完本文可以收获什么?训练和推理的时候,占用显存的内容到底有哪里?xB 的模型,预估推理需要多少显存?要全参数训练一个 xB 的模型,需要多少显存?为什么混合精度训练可以节约显存?使用 deepspeed 后,每个卡占用的显存是多少?基础知识1 Byte = 8 bits, 1 KB = 1024 Bytes, 1 MB = 1024 KB, 1 GB = 1024 MB1 float64 = 8 Bytes = 64 bits, 这是双精度浮点数1 float32 = 4 Bytes = 32 bits,这是单精度浮点数,也称为 fp321 float16 = 2 Bytes = 16 bits,这是半精度浮点数,也称为 fp161 bf16 = 2 Bytes = 16 bits,这是 Brain Floating Point 格式,也称为 bf16显存占用 强烈推荐阅读大佬的分析,本文强参考于 https://zhuanlan.zhihu.com/p/624740065推理 模型推理一共有两部分内容会占用 GPU 显存,模型参数和 KV cache。其中假设模型的参数是 $\theta$,那么推理时候占用的显存是 $2 \Phi$,这是因为现在 HuggingFace 中大部分模型的参数都保存为 BF16,如果没有特殊的必要,不会用 fp32 加载。KV cache,假设输入序列的长度为 s ,输出序列的长度为 n ,以float16来保存KV cache,那么KV cache的峰值显存占用大小为 $b(s+n)h∗l∗2∗2=4blh(s+n)$ 。这里第一个2 表示K/V cache,第二个2表示float16占2个bytes。 粗略的预估,模型推理需要的内存为: 1.2 倍的模型参数内存 = $1.2 \times 2 \Phi = 2.4 \Phi$,以 7B 模型为例,那么推理需要的内存大概是: 16.8 GB。 相对精确的预估:按照上面的公式进行计算 注意⚠️:推理的时候并不需要保存激活值,看到有的博客说需要保存激活值是错的。训练 一般来说,现在都用混合精度训练,因此所有的分析都按照混合精度训练进行分析,而且按照 AdamW 优化起进行分析。训练的时候 GPU 显存占用一共包括 4 个部分:模型参数,梯度,优化器状态,激活值。 假设模型参数为 $\Phi$。模型参数:fp32 参数 + bf16 参数 = $(4 + 2 )\Phi$ = $6\Phi , bytes$ 。梯度分为两种情况:是否开启 gradient accumulation开启梯度累积:要同时保持 fp32 和 bf16 = $6\Phi , bytes$没有开启梯度累积:保持 bf16 即可,占用显存为 $2\Phi , bytes$,但反向传播需要变成 fp32 计算,因此峰值还是需要 $4\Phi$。优化器一节动量 fp32 和二阶动量 fp32,一共为 $(4 + 4) \Phi = 8\Phi , bytes$激活值(bf16): $(34bsh + 5bs^2a)\ast l , bytes$ 因此在训练中,单卡需要的内存为: $$20 \Phi + (34bsh + 5bs^2a)\ast l$$假设使用 DeepSpeed 训练 如果使用 DeepSpeed,那么应该怎么计算每张卡需要的显存呢?ZeRO1,切分优化器(一阶动量 + 二阶动量 + fp32 参数副本) / 卡数 + 梯度 + bf16 参数 + 激活值需要注意:fp32 参数保存在优化器ZeRO2,切分梯度(一阶动量 + 二阶动量 + fp32 参数副本 + 梯度) / 卡数 + bf16 参数 + 激活值ZeRO3,切分模型参数 - (一阶动量 + 二阶动量 + fp32 参数副本 + 梯度 + bf16 参数) / 卡数 + 激活值 具体可以看图(图是按照 Zero-offload 说的 $16\Phi$ 预估的,但是 ZeRO-Infinity 说需要 $20\Phi$,但是差别应该不是特别大): 大模型训练推理时候的显存占用计算-20241006094016066.webpFAQ 疑问 1 🤔:除去 kvcache 和激活值,模型参数部分到底占用 $16\Phi ; or ; 18\Phi ; or ; 20\Phi$ 是比较有争议的并且非常让人疑惑,为什么会出现这个情况?试图解答:英伟达论文和 deepspeed 的实现中,都是按照 $20\Phi$实现的,也就是说会同…
正在初始化 WebAssembly 引擎…
首次编译原生模块可能需要数秒
就绪后,页面交互将以接近原生的速度运行