MLA(1):从代码角度学习和彻底理解 DeepSeek MLA 算法
从代码角度深入理解 DeepSeek MLA 算法。从代码角度详细解析 MLA(Multi-head Latent Attention)算法的核心思想、ROPE 位置编码的兼容性问题,以及如何通过矩阵吸收来优化 KV Cache。
转载提示:本文转载自 bbruceyuan 原文。原作者:bbruceyuan。原文发布日期:2025-02-05。本文已按站点规范移除原文中的推广/导流内容,仅保留技术分析与示例。 在阅读本文之前,强烈建议先阅读原始paper和苏剑林的解读 缓存与效果的极限拉扯:从MHA、MQA、GQA到MLA,先在脑子中对于 MLA(Multi-head Latent Attention)有一个大概的印象,然后通过阅读代码理解下面三个问题MLA 大概是个什么东西,核心目标是节约 kv cache可以理解为效果更好的 Group Query Self Attenion。为什么 MLA 算法和 ROPE 位置编码不兼容,MLA 算法是通过维度分离来实现位置编码也就是说一部分 Q 和 K 专门做位置编码,称为 $q_{rope}$ 和 $k_{rope}$,一部分不做位置编码,称为 $q_{nope}$ 和 $k_{nope}$两部分做 concat 后去计算 $q = concat(q_{rope}, q_{nope})$ 和 $k = concat(k_{rope}, k_{nope})$ 的 attention weight;然后计算 qkv 最终的输出,具体见 代码;MLA 既然能节约 kv cache,那么具体是怎么通过矩阵吸收来做工程实现这部分Paper以及官方开源实现没有给出 如果不想看文字,可以看B站手把手教学视频: Part 1: 从零复现 DeepSeek MLA 算法-无矩阵吸收版 Part 2: 从零手撕 DeepSeek MLA 算法-矩阵吸收版 也欢迎关注我的 github repo: LLMs-zero-to-hero原始的 MLA 算法 一图胜千言,先通过 DeepSeek-V2/3 对应的官方模型图理解 MLA 算法。 llms-zero-to-hero-deepseek-v3-model-architecture 一些前置函数,主要是两个函数,一个 RMSNorm 的实现以及 ROPE 函数的实现,因为本次博客不涉及位置编码的解读,因此可以先简单把 ROPE 理解为一个函数,应用这个函数之后这部分张量(Tensor)就带有了位置编码的作用。class DeepseekV2RMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) class DeepseekV2RotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base inv_freq = 1.0 / ( self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) ) self.register_buffer("inv_freq", inv_freq, persistent=False) # 较小索引位置对应较低频率 # 较大的索引位置有较高的频率 # Build here to make `torch.jit.trace` work. self._set_cos_sin_cache( seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtyp…
正在初始化 WebAssembly 引擎…
首次编译原生模块可能需要数秒
就绪后,页面交互将以接近原生的速度运行