您的当前位置:首页正文

RoPE + 位置线性内插

2024-11-27 来源:个人技术集锦

RoPE 位置编码在超出一定的序列长度后,模型生成的 PPL 指标会爆炸,因此直接外推的效果很差。Meta 的研究团队在论文《》中提出了“位置线性内插”(Position Interpolation,PI)方案,来扩展 LLM 的 context length。

关于 RoPE 相对位置编码,可参考我的上一篇博客 。

实现方式

将预测的长文本位置缩放到训练长度范围以内,具体流程如下:

  • 第一张图的左侧蓝色区域:这部分是 LLM 预训练的最大序列长度,蓝色点表示输入的位置索引,它们都在 0 - 2048 范围内。
  • 第一张图的右侧粉色区域:这部分是长度外推后的区域,这些位置对于模型来说是“未见过的”,预训练期间没有得到训练。
  • 第二张图蓝色区域:通过位置线性内插的位置,将 0 - 4096 位置区域缩放到 0 - 2048 位置区域,通过这种方式将所有的位置索引映射回模型预训练时的范围,这些范围模型是“见过的”,并且得到训练。例如,位置 600 缩放到 300,位置 3100 缩放到 1550。

位置线性内插的核心思想是通过缩放位置索引,使得模型能够处理比预训练时更长的序列,而不损失太多的性能。其数学表达式如下所示:

f ′ ( x , m ) = f ( x , m L L ′ ) f'(x, m) = f(x, \frac{mL}{L'}) f(x,m)=f(x,LmL)

其中,x 是 token embedding、m 是位置索引,L’ 是扩展后的序列长度,L 是训练时的序列长度。

深入研究

g ( x m , x n , m − n ) = R e [ ( W q x m ) ( W k x n ) ∗ e i ( m − n ) θ i ] g(x_m, x_n, m - n) = Re[(W_q x_m) (W_k x_n) * e^{i(m - n)\theta_i}] g(xm,xn,mn)=Re[(Wqxm)(Wkxn)ei(mn)θi]

根据 RoPE 相对位置编码的数学表达式,加上位置线性内插后,m 和 n 同乘上 L/L’,可以表示为:

g ( x m , x n , m − n ) = R e [ ( W q x m ) ( W k x n ) ∗ e i ( m − n ) L L ′ θ i ] g(x_m, x_n, m - n) = Re[(W_q x_m) (W_k x_n) * e^{i(m - n)\frac{L}{L'}\theta_i}] g(xm,xn,mn)=Re[(Wqxm)(Wkxn)ei(mn)LLθi]

通常, θ i = 1000 0 − 2 i / d \theta_i = 10000^{-2i/d} θi=100002i/d,那么 L L ′ θ i \frac{L}{L'}\theta_i LLθi可以进一步改写为:

L L ′ θ i = L L ′ ( 1000 0 − 2 i / d ) = [ ( L ′ L ) d / 2 i × 10000 ] − 2 i / d \frac{L}{L'}\theta_i = \frac{L}{L'}(10000^{-2i/d}) = [(\frac{L'}{L})^{d/2i} \times 10000]^{-2i/d} LLθi=LL(100002i/d)=[(LL)d/2i×10000]2i/d

由于L’/L 大于 1,d/2i 也大于 1,因此 ( L ′ L ) d / 2 i > 1 (\frac{L'}{L})^{d/2i} > 1 (LL)d/2i>1,相当于扩大了 base。这与其他扩大 base 的做法在本质上是相同的。

代码实现

以HuggingFace 的 transformers 库 models/llama/modeling_llama.py的实现方式为例。

class LlamaRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        super().__init__()
        self.scaling_factor = scaling_factor
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        
        # ......

        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
        t = t / self.scaling_factor
        freqs = torch.outer(t, self.inv_freq)

        # ......

其中,self.scaling_factor 是缩放比 L/L’,inv_freq θ i \theta_i θi

存在问题

但位置线性内插方法有一个缺点,插值的方式会导致相邻位置的差异变小(上图中相邻蓝色点的距离),尤其是原先就在训练范围内的相邻位置,因此需要重新训练。训练的步数不用太多,1000 步左右就能很好地应用到长 context 文本上。

显示全文