理解 KV Cache 与 Prompt Caching:LLM 推理加速的核心机制
这篇文章系统梳理了 KV Cache 与 Prompt Cache 的核心机制,重点解释了 Prefill/Decode 的性能差异以及对 Agent 场景的设计约束。核心问题与价值从自回归生成的重复计算问题出发,说明 KV Cache 为什么是 LLM 推理的基础优化。给出 KV Cache 的显存公式与直觉,帮助读…
转载提示:本文转载自 bbruceyuan 原文。原作者:bbruceyuan。原文发布日期:2026-02-21。本文已按站点规范移除原文中的推广/导流内容,仅保留技术分析与示例。0. 阅读收获 (takeaway) 读完本文,你将了解:KV Cache 的原理以及它为什么对 LLM 推理如此重要Prefill 与 Decode 两个推理阶段的区别Compute Bound 与 Memory Bound 背后的直觉一个很好的问题:Prefill 阶段为什么需要计算所有 token 的 Q?Prompt Caching(前缀缓存)的工作原理 KV Cache 和 Prompt Cache 对于 Agent 设计的影响:Agent 系统中的 Prompt Caching 设计(上):Cache 破坏、Prompt 布局与工具管理 —— 为什么 Agent 更需要 Cache、什么会破坏 Cache、三家工具管理方案对比Agent 系统中的 Prompt Caching 设计(下):上下文管理与子代理架构 —— 上下文压缩、Plan 模式演进、子代理 Cache 友好设计1. 什么是 KV Cache?1.1 Autoregressive 生成的重复计算问题 大语言模型(LLM)的文本生成是 自回归(autoregressive) 的:每次只生成一个 token,然后把这个 token 拼到已有序列后面,再预测下一个。 用伪代码表示就是:# 自回归生成的朴素实现 output_tokens = [] for step in range(max_new_tokens): # 每一步都要把 整个序列 送进模型 logits = model(input_tokens + output_tokens) next_token = sample(logits[-1]) # 只用最后一个位置的 logits output_tokens.append(next_token) Q: 问题出在哪? 每一步生成,模型都要对所有历史 token 重新做 Attention 计算——包括 Q、K、V 矩阵乘法。但对于已经出现过的 token,它们的 K 和 V 其实不会变(因为参数没变、token 没变),唯一在变的只有 "最新生成的那个 token" 对应的 Q、K、V。 这就引出了一个自然的优化思路:能不能把已经算过的 K 和 V 缓存起来,下次直接用?1.2 KV Cache 的核心思想 KV Cache 的核心思想非常直接: 把每一层 Attention 中、每个已生成 token 对应的 K 向量和 V 向量缓存下来。后续生成新 token 时,只需要计算新 token 自己的 Q、K、V,然后将新的 K、V 追加到缓存中,用缓存里的完整 K、V 序列做 Attention。 这样一来,生成第 $t$ 个 token 时,Attention 的计算从 $O(t \times d)$(重算所有 token 的 K、V)降低到 $O(d)$(只算 1 个新 token 的 K、V),避免了绝大部分重复计算。 用带 KV Cache 的伪代码表示:# 带 KV Cache 的生成 kv_cache = {} # 每一层缓存 K, V for step in range(max_new_tokens): if step == 0: # 第一步:处理所有 input tokens,填充 cache logits, kv_cache = model(input_tokens, kv_cache=None) else: # 后续步:只送入上一步生成的 1 个 token logits, kv_cache = model([last_token], kv_cache=kv_cache) next_token = sample(logits[-1]) last_token = next_token 1.3 KV Cache 显存占用 KV Cache 不是免费的——它用显存换计算。随着生成序列变长,KV Cache 占用的显存会线性增长。 具体公式(假设 float16 存储): $$\text{KV Cache 显存} = 4blh(s + n) \text{ bytes}$$ 其中:$b$ = batch size$l$ = Transformer 层数$h$ = hidden size$s$ = 输入序列长度$n$ = 输出序列长度4 = 2(K 和 V)× 2(float16 占 2 bytes) 这个公式的详细推导和具体数值例子,可以参考我之前的文章 LLM 大模型训练-推理显存占用分析。这里只需要记住一个直觉:序列越长,KV Cache 越大。这也是为…
正在初始化 WebAssembly 引擎…
首次编译原生模块可能需要数秒
就绪后,页面交互将以接近原生的速度运行