MLA(2):从代码和公式角度理解 DeepSeek MLA 的矩阵吸收 (Projection Absorption)
从代码角度深入理解 DeepSeek MLA 算法。从代码角度详细解析 MLA(Multi-head Latent Attention)算法的核心思想,如何通过矩阵吸收来优化 KV Cache。
转载提示:本文转载自 bbruceyuan 原文。原作者:bbruceyuan。原文发布日期:2025-03-16。本文已按站点规范移除原文中的推广/导流内容,仅保留技术分析与示例。基础原理 这里假设读者对于 MLA有一定的了解,只是不清楚 MLA 算法的实现,关于原版的 MLA 具体实现可以见 从代码角度学习和彻底理解 DeepSeek MLA 算法,视频解读见:完全从零实现DeepSeek MLA算法(MultiHead Latent Attention)-(无矩阵吸收版) deepseek-mla-矩阵吸收之迷-20250316140034131 上面的公式详细的解释了MLA 的计算过程,但这是为了后续代码讲解矩阵吸收回顾使用。 如果不想看文字,可以看B站手把手教学视频: Part 1: 从零复现 DeepSeek MLA 算法-无矩阵吸收版 Part 2: 从零手撕 DeepSeek MLA 算法-矩阵吸收版 欢迎关注我的 github repo: LLMs-zero-to-heroCacheDecompressed (CD) 在原始的官方 huggingface 的实现中(852行开始),kv cache 缓存的是完整的 kv cache,也就是升维之后且应用了 RoPE 位置编码的 kv,而不是压缩后的 $C_t^{KV}$。具体实现见:def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, **kwargs, ): ... # 注意这里的 compressed_kv 是计算出来的 # 实际只要缓存这个就行,不行看是 kv states compressed_kv = self.kv_a_proj_with_mqa(hidden_states) # 此处compressed_kv 对应公式中的 c_t^{KV} compressed_kv, k_pe = torch.split( compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 ) ... # key shape is: (batch, seq_len, num_head, nope_dim + rope_dim) key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) key_states[:, :, :, : self.qk_nope_head_dim] = k_nope key_states[:, :, :, self.qk_nope_head_dim :] = k_pe # value shape is (batch, seq_len, num_head, value_dim) if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, cache_kwargs ) ... 注意代码中 shape 的注释,通过 shape 可以了解缓存的完整的 kv cacheCache Compressed_kv (CC)# CacheCompressed def forward(self, hidden_states_q: torch.Tensor, q_position_ids: torch.LongTensor, compressed_kv: torch.Tensor): ... kv_seq_len = compressed_kv.size(1) compressed_kv, k_pe = torch.split( compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 ) k_pe = k_pe.view(bsz, kv_seq_len, 1, self.qk_rope_head_di…
正在初始化 WebAssembly 引擎…
首次编译原生模块可能需要数秒
就绪后,页面交互将以接近原生的速度运行