多查询注意(MQA)只使用单个键值头,可大大加快解码器的推理速度。然而,MQA 可能会导致质量下降,而且仅仅为了加快推理速度而训练一个单独的模型可能并不可取。
提出了一种将现有的多头语言模型检查点向上训练成带有 MQA 的模型的方法,只需使用原始预训练计算量的 5%;
引入分组查询注意(GQA),这是一种多查询注意的广义化,它使用中间(多于一个,少于查询头数)键值头的数量。
我们的研究表明,经过向上训练的 GQA 可以达到接近多查询头注意力的质量,而且速度与 MQA 相当。
自回归解码器推理是 Transformer 模型的一个严重瓶颈,这是因为在每个解码步骤中加载解码器权重和所有注意键与值会造成内存带宽开销(Shazeer,2019;Pope 等人,2022;de Jong 等人,2022)。通过多查询注意(Shazeer,2019 年),加载键和值所产生的内存带宽可以大幅降低,多查询注意使用多个查询头,而键和值头则使用单个。
这项工作包含两项贡献,旨在加快大型语言模型的推理速度。首先,我们对南加州大学做出了同等贡献。谷歌研究中心的工作表明,使用多头注意力(MHA)的语言模型检查点可以通过向上训练(Komatsuzaki 等人,2022 年)来使用 MQA,只需原始训练计算量的一小部分。这为获得快速多查询和高质量 MHA 检查点提供了一种经济有效的方法。其次,我们提出了分组查询注意(GQA),它是多查询头注意和多查询注意之间的一个插值,每个查询头子组只有一个键和值头。我们的研究表明,经过向上训练的 GQA 可以达到接近多头注意力的质量,同时速度几乎与多查询注意力一样快。
从多头模型生成多查询模型分为两个步骤:首先是转换检查点,其次是额外的预训练,以使模型适应新的结构。图 1 显示了将多头检查点转换为多查询检查点的过程。键头和值头的投影矩阵平均汇集成单一投影矩阵,我们发现这比选择单一键头和值头或从头开始随机初始化新的键头和值头效果更好。
这里相当于把 Look Up 得到的 Keys 向量进行 Mean Pooling 转换为 1 个 Keys 即单头。然后,在相同的预训练配方上,对转换后的检查点进行原始训练步骤的一小部分 α 的预训练。
分组查询注意将查询头分为 G 组,每组共享一个键头和值头。GQA-G 指的是有 G 个组的分组查询。GQA-1 只有一个组,因此也只有一个键头和值头,相当于 MQA,而 GQA-H 的组数等于查询头数,相当于 MHA。图 2 显示了分组查询注意和多头/多查询注意的比较。在将多头检查点转换为 GQA 检查点时,我们通过平均池化该组中的所有原始头来构建每个组的键和值头。
中间组的数量会导致插值模型的质量比 MQA 高,但速度比 MHA 快,正如我们将展示的那样,这代表了一种有利的权衡。从 MHA 到 MQA,可以将 H 个键和值头减少到单个键和值头,从而减少键值缓存的大小,并因此将需要加载的数据量减少 H 倍。GQA 可以让我们在模型规模增大时,保持相同的带宽和容量下降比例。此外,由于 KV 缓存会随着模型维度的增加而增加,而模型 FLOP 和参数则会随着模型维度的平方而增加,因此大型模型因注意力而产生的内存带宽开销相对较小。最后,大型模型的标准分片是按模型分区的数量复制单键和值头(Pope 等,2022 年);GQA 消除了这种分区带来的浪费。因此,我们希望 GQA 能为大型模型提供特别好的权衡。我们注意到,GQA 并未应用于编码器自注意层;编码器表示是并行计算的,因此内存带宽通常不是主要瓶颈。
Configurations
所有模型都基于 T5.1.1 架构(Raffel 等人,2020 年),使用 JAX(Bradbury 等人,2018 年)、Flax(Heek 等人,2020 年)和 Flaxformer1 实现。在我们的主要实验中,我们考虑了具有多头注意力的 T5 Large 和 XXL,以及具有多查询和分组查询注意力的 T5 XXL 的升级训练版本。我们使用 Adafactor 优化器,其超参数和学习率安排与 T5 相同(Raffel 等人,2020 年)。我们将 MQA 和 GQA 应用于解码器自注意和交叉注意,但不包括编码器自注意。
Uptraining
预训练。预训练模型从公开的 T5.1.1 检查点初始化。键头和值头被平均池化到适当的 MQA 或 GQA 结构中,然后使用原始预训练设置和数据集(Raffel et al.) 在 α = 0.05 的情况下,训练耗时约 600 TPUv3 芯片 / 日。
Data
我们在摘要数据集 CNN/每日邮报(Nallapati 等人,2016 年)、arXiv 和 PubMed(Cohan 等人,2018 年)、MediaSum(Zhu 等人,2021 年)和 Multi-News(Fabbri 等人,2019 年);翻译数据集 WMT 2014 English-to-German;以及问题解答数据集 TriviaQA(Joshi 等人,2017 年)上进行了评估。我们没有对 GLUE(Wang 等人,2019 年)等流行的分类基准进行评估,因为自回归推理不太适用于这些任务。
Fine-tuning
为了进行微调,我们在所有任务中使用 0.001 的恒定学习率、128 的批量大小和 0.1 的 Dropout Ratio。CNN/Daily Mail 和 WMT 的输入长度为 512,输出长度为 256。其他摘要数据集的输入长度为 2048,输出长度为 512。最后,TriviaQA 使用输入长度 2048 和输出长度 32。我们训练直到收敛,然后选择 dev 性能最高的检查点。我们使用贪婪解码进行推理。
Timing
我们报告了 xprof(谷歌,2020 年)测量的每个 TPUv4 芯片每个样本的时间。在时序实验中,我们使用 8 个 TPU,每个 TPU 最多可容纳 32 个批次,并对每个模型分别进行并行化优化。
图 3 显示了 MHA T5-Large 和 T5-XXL,以及上训练比例为 α = 0.05 的 MQA 和 GQA-8 XXL 模型在所有数据集上的平均性能与平均推理时间的函数关系。我们发现,相对于 MHA 模型,较大的上训练 MQA 模型提供了有利的权衡,与 MHA-Large 模型相比,其推理质量更高、速度更快。此外,GQA 还能显著提高质量,性能接近 MHA-XXL,速度接近 MQA。表 1 包含了所有数据集的全部结果。
与 MHA-Large 具有更高的质量和更快的速度的 MHA 相比,Uptrained MQA 产生了有利的权衡,GQA 实现了更好的性能,速度增益相似,质量与 MHA-XXL 相当。所有任务的平均性能作为 T5-Large 和 T5-XXL 每个样本的平均推理时间的函数,多头注意力为 5% 的预训练 T5-XXL,MQA 和 GQA-8 注意力。
本节通过实验来研究不同建模选择的效果。我们对具有代表性的任务子样本进行了性能评估: CNN/每日邮报(短篇摘要)、MultiNews(长篇摘要)和 TriviaQA(问题解答)。
Checkpoint conversion
图 4 比较了不同方法在检查点转换方面的性能。平均池化似乎效果最好,然后选择一个头部,然后随机初始化。直观地说,结果按从预训练模型中保留信息的程度排序。
T5-Large 的不同检查点转换方法的性能比较,比例为 α = 0.05。 “平均”均值池键和值头,“First”选择第一个头,“Random”从头开始初始化头。
Uptraining steps
图 5 显示了使用 MQA 和 GQA 的 T5 XXL 的性能随向上训练比例的变化情况。首先,我们注意到 GQA 在转换后已经达到了合理的性能,而 MQA 则需要向上训练才能发挥作用。MQA 和 GQA 都能从 5% 的上行训练中获益,而从 10% 的上行训练中收益递减。
Number of groups
图 6 展示了 GQA 组的数量对推理速度的影响。对于较大的模型,来自 KV 缓存的内存带宽开销限制较小(Shazeer,2019 年),而键值大小的减少则由于头部数量的增加而更加明显。因此,从 MQA 开始增加组数最初只会导致适度的减速,随着我们更接近 MHA,成本会越来越高。我们选择了 8 组作为有利的中间点。
这项工作的重点是通过减少加载键和值带来的内存带宽开销(Williams 等人,2009 年),在解码器质量和推理时间之间实现更好的权衡。Shazeer (2019) 首次提出通过多查询关注来减少这种开销。后续工作表明,多查询关注对长输入特别有帮助(Pope 等人,2022 年;de Jong 等人,2022 年)。拉贝(2023 年)独立开发了 GQA,并公开实施。其他研究还探讨了如何分组注意力头以提高计算效率(Park 等人,2020 年;Luo 等人,2022 年;Ni 等人,2023 年),但没有特别关注键值头,因为键值头决定了内存带宽开销。
已经提出了许多其他方法来减少键和值的内存带宽开销,以及参数。Flash attention (Dao et al., 2022) 构建了注意力计算,以避免具体化二次注意力分数,减少内存和加速训练。量化(Dettmers 等人,2022;Frantar 等人,2022)通过降低精度来降低权重和激活的大小,包括键和值。模型蒸馏 (Hinton et al., 2015; Gou et al., 2021) 相反,使用从较大模型生成的数据来微调较小的模型,以给定精度降低模型大小。层稀疏交叉注意(de Jong et al., 2022)消除了大多数交叉注意层,这对较长的输入构成了主要费用。推测采样 (Chen et al., 2023; Levianathan et al., 2022) 通过提出多个具有较小模型的令牌来改进内存带宽瓶颈,然后由更大的模型并行评分。
最后,我们提出的上采样过程受到 Komatszaki 等人的启发。 (2022),它将标准 T5 检查点升级为稀疏激活的专家混合模型。
语言模型的推理成本很高,主要是由于加载键和值的内存带宽开销。多查询注意力以降低模型容量和质量为代价减少了这种开销。我们建议将多头注意力模型转换为具有少量原始预训练计算的多查询模型。此外,我们引入了分组查询注意力,这是一种多查询和多头注意力的插值,它以与多查询注意力相当的速度实现了接近多头的质量。
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
bs, slen, n_kv_heads, head_dim = x.shape
if n_rep == 1: # MHA
return x
return ( # MQA / GQA
x[:, :, :, None, :]
.expand(bs, slen, n_kv_heads, n_rep, head_dim)
.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
)
目的:这个函数用于在 n_kv_heads 小于 n_heads 时重复键/值头。n_rep 代表重复的次数,如果 n_rep 为 1,则直接返回 x,否则,沿着新维度扩展张量并重新调整形状以重复头。n_rep = 3 效果如下:
输入张量形状 => 输出张量形状 Bsz x SeqLen x KV_heads x HeadDim
torch.Size([2, 3, 2, 4]) => torch.Size([2, 3, 6, 4])
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
model_parallel_size = fs_init.get_model_parallel_world_size()
self.n_local_heads = args.n_heads // model_parallel_size
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size # 此处
self.n_rep = self.n_local_heads // self.n_local_kv_heads # 此处 几个组
self.head_dim = args.dim // args.n_heads
n_kv_heads: 如果 n_kv_heads 为 None 则退化为 MHA,否则根据 n_kv_heads 参数定义
model_parallel_size: 指定模型并行的规模,即模型被分割为多少个部分。local_heads 和 local_kv_heads 这些参数会受到该参数的影响,因为它们决定了每个设备上实际计算的头数。
n_rep: 代表 group 的数量,因为每一个分组共用向量,所以 repeak_kv 会就形状恢复
head_dim: 通过 embedding_dim 和 heads 计算
self.wq = ColumnParallelLinear(
args.dim,
args.n_heads * self.head_dim,
bias=False,
gather_output=False,
init_method=lambda x: x,
)
self.wk = ColumnParallelLinear(
args.dim,
self.n_kv_heads * self.head_dim, # 初始化为单个组内的一份
bias=False,
gather_output=False,
init_method=lambda x: x,
)
self.wv = ColumnParallelLinear(
args.dim,
self.n_kv_heads * self.head_dim, # 初始化为单个组内的一份
bias=False,
gather_output=False,
init_method=lambda x: x,
)
self.wo = RowParallelLinear(
args.n_heads * self.head_dim,
args.dim,
bias=False,
input_is_parallel=True,
init_method=lambda x: x,
)
传统的模型 QKV 是同 Size 的,但是由于引入了 Group 的概念,Q 尺寸不变为 n_heads * head_dim,但是 KV 变成 n_kv_heads * head_dim,后续通过 ✖️ n_rep 即 Group 数复制为与 Q 同尺寸,从而继续执行传统的 QKV 计算,因此 GQA 在 Attention 的计算方式上还是一样的,区别是减少了读取 KV 的 IO,同时减少了 KV-Cahce 的换存量。
self.cache_k = torch.zeros(
(
args.max_batch_size,
args.max_seq_len,
self.n_local_kv_heads,
self.head_dim,
)
).cuda()
self.cache_v = torch.zeros(
(
args.max_batch_size,
args.max_seq_len,
self.n_local_kv_heads,
self.head_dim,
)
).cuda()
这里是为 KV Cache 预留的空间,BSZ x SeqLen x KV_HEADS x HeadDim 也是我们计算 KV_Cache 缓存量的方法。
def forward(
self,
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor],
):
bsz, seqlen, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
self.cache_k = self.cache_k.to(xq)
self.cache_v = self.cache_v.to(xq)
self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
keys = self.cache_k[:bsz, : start_pos + seqlen]
values = self.cache_v[:bsz, : start_pos + seqlen]
# repeat k/v heads if n_kv_heads < n_heads # 单个组扩展为完整head
keys = repeat_kv(keys, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
values = repeat_kv(values, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
keys = keys.transpose(1, 2)
values = values.transpose(1, 2)
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
if mask is not None:
scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen)
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim)
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
return self.wo(output)
下面拆解下前向传播的过程:
通过 QKV 矩阵获取对应的矩阵参数
通过 wq、wk、wv 获取 QKV 并通过 view 转换为 mulit_head 的形式,最后对 qk 应用 ROPE 旋转位置编码,引入 token 位置信息:
更新 KV-Cache
首先通过 self.cache 获取原始的 cache,再将本次 lookup 得到的心得 k/v 添加值 cache 中用于下一次 token 生成。
Repeat KV
按照 Grouped-query 的图示维度是无法直接矩阵乘法的,因为维度不匹配,所以需要把对应的 keys、values 进行 repeak 操作,这也是最上面介绍的 repeak_kv 的作用。
QKV Attention
基础的 Attention 计算操作,包括 QK Matmal 计算、Scale 进行规划化、Softmax 归一化以及最终归一化权重与 Values 加权求和得到最终的输出,维度为 BSZ x SeqLen x EmbDim
前两天分享了 Gemma-2 的技术报告,其中有同学问到了关于 QKV 维度的问题:
下面我们结合 Gemma-1、Gemma-2 的模型结构再细化一下 GQA 的概念:
从图中可以看到 Gemma-1 的 QKV 都是 [3072x4096] 的,因此通过 wq wk wv 转换后,其可以直接 matmul,因为尺寸是一致的,但是 Gemma-2 的 q 是 [3584x4096],而 k,v 是 [3584 x 2048] 的,这是因为 Num heads = 16,而 Num KV heads = 8,根据源码:
self.n_rep = self.n_local_heads // self.n_local_kv_heads # 此处 几个组
我们可知 Gemma-2 的 n_rep = 16 / 8 = 2,即 Gemma-2 GQA 选取的 Group 数为 2,与 Gemma-2 报告中的匹配,结合报告中的实验可知,Gemma-2 在这里选择了极致的效率,即最小的 Group 换取最快的性能: