RoPE 位置编码在超出一定的序列长度后,模型生成的 PPL 指标会爆炸,因此直接外推的效果很差。Meta 的研究团队在论文《》中提出了“位置线性内插”(Position Interpolation,PI)方案,来扩展 LLM 的 context length。
关于 RoPE 相对位置编码,可参考我的上一篇博客 。
将预测的长文本位置缩放到训练长度范围以内,具体流程如下:
位置线性内插的核心思想是通过缩放位置索引,使得模型能够处理比预训练时更长的序列,而不损失太多的性能。其数学表达式如下所示:
f ′ ( x , m ) = f ( x , m L L ′ ) f'(x, m) = f(x, \frac{mL}{L'}) f′(x,m)=f(x,L′mL)
其中,x 是 token embedding、m 是位置索引,L’ 是扩展后的序列长度,L 是训练时的序列长度。
g ( x m , x n , m − n ) = R e [ ( W q x m ) ( W k x n ) ∗ e i ( m − n ) θ i ] g(x_m, x_n, m - n) = Re[(W_q x_m) (W_k x_n) * e^{i(m - n)\theta_i}] g(xm,xn,m−n)=Re[(Wqxm)(Wkxn)∗ei(m−n)θi]
根据 RoPE 相对位置编码的数学表达式,加上位置线性内插后,m 和 n 同乘上 L/L’,可以表示为:
g ( x m , x n , m − n ) = R e [ ( W q x m ) ( W k x n ) ∗ e i ( m − n ) L L ′ θ i ] g(x_m, x_n, m - n) = Re[(W_q x_m) (W_k x_n) * e^{i(m - n)\frac{L}{L'}\theta_i}] g(xm,xn,m−n)=Re[(Wqxm)(Wkxn)∗ei(m−n)L′Lθi]
通常, θ i = 1000 0 − 2 i / d \theta_i = 10000^{-2i/d} θi=10000−2i/d,那么 L L ′ θ i \frac{L}{L'}\theta_i L′Lθi可以进一步改写为:
L L ′ θ i = L L ′ ( 1000 0 − 2 i / d ) = [ ( L ′ L ) d / 2 i × 10000 ] − 2 i / d \frac{L}{L'}\theta_i = \frac{L}{L'}(10000^{-2i/d}) = [(\frac{L'}{L})^{d/2i} \times 10000]^{-2i/d} L′Lθi=L′L(10000−2i/d)=[(LL′)d/2i×10000]−2i/d
由于L’/L 大于 1,d/2i 也大于 1,因此 ( L ′ L ) d / 2 i > 1 (\frac{L'}{L})^{d/2i} > 1 (LL′)d/2i>1,相当于扩大了 base。这与其他扩大 base 的做法在本质上是相同的。
以HuggingFace 的 transformers 库 models/llama/modeling_llama.py
的实现方式为例。
class LlamaRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
super().__init__()
self.scaling_factor = scaling_factor
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
# ......
self.max_seq_len_cached = max_position_embeddings
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
t = t / self.scaling_factor
freqs = torch.outer(t, self.inv_freq)
# ......
其中,self.scaling_factor
是缩放比 L/L’,inv_freq
是
θ
i
\theta_i
θi。
但位置线性内插方法有一个缺点,插值的方式会导致相邻位置的差异变小(上图中相邻蓝色点的距离),尤其是原先就在训练范围内的相邻位置,因此需要重新训练。训练的步数不用太多,1000 步左右就能很好地应用到长 context 文本上。