LLaMA 源码解读
paperreading
本文字数:552 字 | 阅读时长 ≈ 2 min

LLaMA 源码解读

paperreading
本文字数:552 字 | 阅读时长 ≈ 2 min

1. LLaMA 源码解读

LLaMA 作为一个开源社区中广受欢迎的模型,其源码实现值得深入研究。本文将以 Huggingface 中的实现为例,详细解读 LLaMA 的核心组件和实现细节。

1.1 整体架构

LLaMA 是一个基于 Transformer 架构的因果语言模型,主要由以下几个关键组件构成:

  1. 词向量嵌入层 (Token Embeddings)
  2. 多层 Transformer Decoder 层
  3. 最终的输出层范化 (RMSNorm)

其特色在于:

1.2 核心组件解读

LlamaModel

LlamaModel 是整个模型的基础类,负责初始化和组织各个组件。让我们看看其关键实现:

class LlamaModel(LlamaPreTrainedModel):
    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # 词向量嵌入层
        self.embed_tokens = nn.Embedding(
            config.vocab_size,  # 词表大小
            config.hidden_size, # 词向量维度
            self.padding_idx    # padding token的索引
        )
        
        # Transformer解码器层
        self.layers = nn.ModuleList([
            LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)
        ])
        
        # 最终的范化层
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

这里的关键点是:

  1. embed_tokens: 将输入 token 转换为连续的向量表示
  2. layers: 包含多个 Transformer 解码器层
  3. norm: 最终的 RMSNorm 范化层

LlamaDecoderLayer

每个解码器层包含以下核心组件:

class LlamaDecoderLayer(nn.Module):
    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        
        # 自注意力层
        self.self_attn = LlamaAttention(config=config)
        
        # 前馈神经网络
        self.mlp = LlamaMLP(config)
        
        # 两个RMSNorm层
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

前向传播过程遵循:

  1. 输入先经过 RMSNorm
  2. 进行自注意力计算
  3. 残差连接
  4. 再次 RMSNorm
  5. 经过 MLP 层
  6. 最后一次残差连接

1.3 特色组件详解

RMSNorm

相比传统的 LayerNorm,RMSNorm 移除了均值计算和偏置项,只保留方差的归一化:

class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, hidden_states):
        # 计算均方根
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        return self.weight * hidden_states

这种简化的范化方式不仅计算更快,而且在实践中表现良好。

4月 06, 2025
3月 10, 2025
12月 31, 2024