探秘Transformer系列之(12)--- 多头自注意力
探秘Transformer系列之(12)--- 多头自注意力目录
[*]探秘Transformer系列之(12)--- 多头自注意力
[*]0x00 概述
[*]0x01 研究背景
[*]1.1 问题
[*]1.2 根源
[*]1.3 解决方案
[*]0x02 原理
[*]2.1 架构图
[*]偏置
[*]权重矩阵
[*]\(W^O\)矩阵
[*]2.2 设计思路
[*]子空间&分治
[*]ensemble&融合
[*]缓解稀疏
[*]2.3 计算
[*]计算流程
[*]计算强度
[*]2.4 效果
[*]2.5 融合方式
[*]2.6 分析
[*]2.7 优点
[*]0x03 实现
[*]3.1 定义
[*]3.2 运算逻辑
[*]输入
[*]投影
[*]切分数据
[*]逻辑角度
[*]物理角度
[*]小结
[*]调整维度
[*]为每个头计算注意力
[*]单独分组
[*]并行
[*]融合每个头的Z
[*]forward()函数
[*]3.3 调用
[*]编码器
[*]解码器
[*]0x04改进
[*]4.1 MOHSA
[*]4.2 MoH
[*]4.3 DCMHA
[*]研究背景
[*]动机
[*]思路
[*]0xFF 参考
0x00 概述
MHSA(多头自注意力) 是 Transformer 模型的核心模块。Transformer本质上是一个通用的可微计算机,集多种优秀特性于一身。
[*]Transformer 类似消息传递的架构具有通用性(即完整性)和强大功能(即效率),能够涵盖许多现实世界的算法,因此Transformer具备非常强大的表现力(在前向传播中)。
[*]通过反向传播和梯度下降,Transformer可以持续不断的优化。
[*]因为Transformer的计算图是浅而宽的,而且自注意力机制让我们在处理序列数据时,能够并行计算序列中的每个元素,所以Transformer能够更好地映射到我们的高并行计算架构(比如GPU)来进行高效计算。
[*]多头注意力机制通过并行运行多个自注意力层并综合结果,能同时捕捉输入序列在不同子空间的信息,增强了模型的表达能力。这种特性使得Transformer可以更好地理解数据中的复杂模式和语义信息,在自然语言处理、计算机视觉等多领域都能出色应用,泛化能力强。
多头注意力机制就是蛋糕上的樱桃。多头注意力机制的巧妙之处在于,它能够通过并行运行多个具有独特视角的注意力头来同时处理数据,使得模型能够从多个角度分析输入序列,捕捉丰富的特征和依赖关系。类似于一组专家分析复杂问题的各个方面。或者像同时有多个视角在看同一个东西,每个视角都能看到一些不同的细节。下图形象化的解释了多头注意力运行机制,Query、Key和Value 被分为不同的Head,并在每个Head中独立计算自注意力。
0x01 研究背景
1.1 问题
迄今为止,注意力机制看起来很美好,但是也暴露出来了一些缺陷:
比如,模型在编码时,容易会过度的将注意力集中于当前的位置,而忽略了其它位置的信息,从而错过某些重要的依赖关系或特征。用程序化的语言来说,因为Q、K、V都来自输入X,在计算\(QK^T\)时,模型容易关注到自身的位置上,即\(QK^T\)对角线上的激活值会明显比较大,这样会削弱模型关注其它高价值位置上的能力,限制了模型的理解和表达能力。
再比如,注意力机制是使用Q去找相关的K,但是”相关“可以有不同形式和定义,比如一项事物往往有多个方面,应该综合利用各方面的信息/特征,从多个角度进行衡量。比如下面句子中就有字体大小,背景颜色,字体颜色,加粗/下划线/斜线这几个不同的强调维度,需要多方考虑。
另外,人类注意力机制本身就是天然可以同时处理多个方面的信息的。设想你在一个拥挤的公交车上看书,你的大脑能自动关注到书的内容,同时也可以留意周围的环境声,譬如有人叫你的名字或是公交车到站播报声。
而迄今为止,在我们的学习历程中,当前的Transformer注意力机制只是注重事物的单独方面,而非注意多个方面。
1.2 根源
Embedding 才是多头注意力背后的真正内在成因。Embedding 是人类概念的映射,或者说是表达人类概念的途径或者方法。人类的概念是一个及其复杂的系统,因为概念需要有足够的内部复杂度才能应对外部世界的复杂度。比如对于一个词来说,其就有语义逻辑、语法逻辑、上下文逻辑、在全句中位置逻辑、分类逻辑等多种维度。而且,词与词之间的关系还不仅仅限于语义上的分类所导致的定位远近这么简单。一个词所代表的事物与其他词所代表的事物之间能产生内在联系的往往有成百上千上万种之多。
或者说,概念是被配置为能够跨任务工作的向量,是去除非本质信息,保留最确定性的结果。在这种基础上,存储在长期记忆中的单个概念向量可以通过不同的函数进行投影,以用于不同特定领域的任务。每个任务其实可以认为是一个独立的向量空间。比如对于上面的例子,字体和颜色就是两个不同的子空间(低维空间)。
而目前注意力只注重单独某个向量空间,势必导致虽然最终生成的向量可以在该空间上有效将人类概念进行映射,但是无法有效反映外部丰富的世界。因此,我们需要一种可以允许模型在不同的子空间中进行信息选择的机制。
1.3 解决方案
多头注意力就是研究人员给出的解决方案。多头注意力可以理解为高维向量被拆分或者转化为H份低维向量,并在H个低维空间里求解各自的注意力。这样模型就可以从不同角度来分析和理解输入信息,最终输出包含有不同子空间中的编码表示信息,从而增强模型的表达能力。Transformer论文中对于多注意力机制的论述如下。
Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions.
多头注意力机制基于自注意力机制基础上进行扩展。在传统的自注意力机制中,你只能使用一组查询(Q)、键(K)和值(V)来计算注意力权重。但是,在多头注意力机制中,你可以使用多组不同的Q、K和V来进行计算。每个注意力头都有自己独立的一组Q、K和V,多组Q、K和V通过独立的线性变换来生成。
不同的Q去查找不同方面的相关性,比如某个Q去捕捉语法依赖,另一个Q去捕捉语义依赖,这样每个注意力头可以关注文本中不同的方面和特征,才能不仅抓住主旨,同时也能理解各个词汇间的关联,进而从多角度捕捉上下文和微妙之处,并行地学习多组自注意力权重。最后,多个注意力头的结果会被拼接在一起,并通过另一个线性变换进行整合,得到最终的输出。多头注意力机制具体如下图所示。其中,D 表示 hidden size,H 表示 Head 个数,L 表示当前是在序列的第 L 个 Token。
https://img2024.cnblogs.com/blog/1850883/202503/1850883-20250308125413019-1260747646.jpg
针对上方句子的例子,我们使用多头注意力就是同时关注字体和颜色等多方面信息,每个注意力头关注不同的表示子空间,这样即可以有效定位网页中强调的内容,也可以灵活选择文字中的各种关系和特征,从而提取更丰富的信息。模型最终的“注意力”实际上是来自不同“表示子空间”的注意力的综合,均衡单一注意力机制可能产生的偏差。
有两个比较确切的例子,可以让大家对多头自注意力有直观的感受。
[*]例子1是从专家的专家角度来看。一个团队合作完成一个软件项目,每个团队成员负责自己擅长的领域。产品经理负责整体项目规划和需求分析;项目经理负责项目把控;前端开发工程师负责与用户界面相关的工作;后端工程师负责服务器逻辑和数据库管理;测试工程师负责项目质量保证。每个团队成员用自己的专业能力独立的对项目付出不同的贡献,最终将各自的成果整合在一起,形成一个完整的软件产品。
[*]例子2更倾向于从合作的角度来看。在橄榄球领域内有一种说法,一场比赛要看四遍,第一遍从总体上粗略看,第二遍从进攻球员角度看,第三遍从防守球员角度看,第四遍则综合之前的理解再总体看一遍。但是这样要看四遍。不如让几个人一起来看一遍比赛,观看过程中,有人负责从从进攻球员角度看,有人负责从防守球员角度看,有人负责总体把握,有人负责看重点球员,有人看教练部署,最终有人将不同的意见和见解整合起来,形成对比赛的完整理解。
0x02 原理
2.1 架构图
多头注意力机制是自注意力机制的变体,多头注意力的架构及公式如下图,h 个 Scale Dot-Product Attention(左)并行组合为 Multi-Head Attention(右)。每个Scaled Dot-Product Attention 结构对输入上下文特征单独做了 一次 上下文信息融合。在此基础之上,我们把多个这样的特征融合操作并联起来,得到多个独立的输出特征张量,再把这些张量联接(concatenate)起来。
上图中,\(W^Q\),\(W^K\),\(W^V\) 这三个矩阵列数可以不同,但是行数都是\(d_{model}\)。\(d_{model}\)为多头注意力机制模块输入与输出张量的通道维度,h为head个数。论文中h=8,因此\(d_k=d_v=d_{model}/h=64\),\(d_{model}=512\)。
偏置
\(W^Q\),\(W^K\),\(W^V\)这三个投影层以及最后的投影层\(W^O\)(Z * Output_weights)可以选择添加或者不添加偏置。
举例:根据LLaMA3源码来看,其没有加入bias。
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
model_parallel_size = fs_init.get_model_parallel_world_size()
self.n_local_heads = args.n_heads // model_parallel_size
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads
self.wq = ColumnParallelLinear(
args.dim,
args.n_heads * self.head_dim,
bias=False, # 没有偏置
gather_output=False,
init_method=lambda x: x,
)
self.wk = ColumnParallelLinear(
args.dim,
self.n_kv_heads * self.head_dim,
bias=False,# 没有偏置
gather_output=False,
init_method=lambda x: x,
)
self.wv = ColumnParallelLinear(
args.dim,
self.n_kv_heads * self.head_dim,
bias=False,# 没有偏置
gather_output=False,
init_method=lambda x: x,
)
self.wo = RowParallelLinear(
args.n_heads * self.head_dim,
args.dim,
bias=False,# 没有偏置
input_is_parallel=True,
init_method=lambda x: x,
)另外,PaLM: Scaling Language Modeling with Pathways 这篇论文里提到,如果对全连接层以及 layer norm 不加偏置项,可以提高训练的稳定性。
No Biases – No biases were used in any of the dense kernels or layer norms. We found this to result in increased training stability for large models.
权重矩阵
如果是Scaled Dot-Product Attention,即单头注意力机制,其要学的参数其实就是三个矩阵 \(W^Q,W^K,W^V\),这个参数量往往不多,且容易是稀疏矩阵。当语义逐渐复杂后,容易因为参数量达到容量上限而造成模型性能不足。
多头就意味着需要把词嵌入分成若干的块,即每个字都转换为若干512/H维度的信息。然后我们将这些块分配到不同的头上,每个头将独立地进行注意力计算。对于每个头得到的Q、K和V,我们都需要分别进行线性变换。计算 Q、K 和 V 的过程还是一样,不过现在执行变换的权重矩阵从一组\((W^Q, W^K, W^V)\)变成了多组:\((W_0^Q, W_0^K, W_0^V)\),\((W_1^Q, W_1^K, W_1^V)\),....\((W_h^Q, W_h^K, W_h^V)\)。通过这些权重矩阵的转换,我们就可以让多组关注不同的上下文的 Q、K 和 V。
多头注意力机制通过更多的权重矩阵来增加了模型的容量,使得模型能够学习到更复杂的表示。在多头注意力中,每个注意力头只关注输入序列中的一个独立子空间,不同头(角度)有不同的关注点,综合多个头可以让模型就能够更全面地理解输入数据。或者这么理解:不同的注意力头可以学习到序列中不同位置之间的不同依赖关系,组合多头注意力可以捕捉多种依赖关系,提供更丰富、更强大的表示。从而使得多头的Q、K、V权重可以在参数量相同的情况提升模型的表达能力。
这些自注意力“头”的关注点并非预设,而是从随机开始,通过处理大量数据并自我学习,自然而然地识别出各种语言特征。它们学习到的一些特征我们能够理解,有些则更加难以捉摸。
\(W^O\)矩阵
上面的操作相当于把一个进程拆分成8个独立的子进程进行操作,每个进程处理原始Embedding的1/n。最终每个进程得到的向量长度是原来embedding长度的1/n。怎样把不同注意力头的输出合起来呢?系统会在d这个维度,通过 Concat 方式把8个子进程的结果串联起来,直接拼接成一个长向量。此时 Concat 后的矩阵实际上并不是有机地融合 8 个“小Embedding”,而只是简单地做了矩阵的前后链接 Concat。这就带来了几个问题:
[*]多个头直接拼接的操作, 相当于默认了每个头或者说每个子空间的重要性是一样的, 在每个子空间里面学习到的相似性的重要度是一样的,即这些头的权重是一样的。然而,各个头的权重事实上肯定不同,如何有机融合?或者说,如何调整不同头之间的权重比例?
[*]自注意力机制模块会接到全连接网络,FFN需要的输入是一个矩阵而不是多个矩阵。而且因为有残差连接的存在,多头注意力机制的输入和输出的维度应该是一样的。
综上,我们需要一个压缩、转换和融合的手段,把 8 个小的语义逻辑子空间有机地整合成一个总体的 Embedding,而且需要把多头注意力的输出恢复为原 Embedding 的维度大小,即512维的向量长度。但是有机融合是个复杂的情况,只凭借人力难以做好。因此研发人员提出来把融合直接做成可学习、可训练的。即设定一个可学习参数,如果它觉得某个头重要, 那干脆让那个头对应的可学习参数大些,输出的矩阵大些,这就类似于增加了对应头的权重。
最终就得到是\(W^O\)方案。利用\(W^O\) 对多头的输出进行压缩和融合来提升特征表征和泛化能力。\(W^{O}\)类似 \(W^{Q}\),\(W^{K}\),$W^{V} \(,也是在模型训练阶段一同训练出来的权重矩阵(右上角 O 意为输出 Output 的意思)。\)W^O$操作前后,维度没有变化。即最终输出的结果和输入的词嵌入形状一样。
2.2 设计思路
我们来反推或者猜测一下Transformer作者的设计思路大致为:以分治+融合的模式对数据进行加工。分治是对数据进行有差别的对待,而融合是做数据融合。
子空间&分治
Embedding
前面提到,Embedding 才是多头背后的真正内在成因。那么让我们再看看这个 Embedding 中的语义逻辑子空间。我们假设有8个注意力头,每个注意头都有自己的可学习权重矩阵\(W_i^Q\), \(W_i^K\)和\(W_i^V\)。$W^{Q} \(,\)W{K}$,$W$ 均是 Transformer 大模型在训练阶段时,通过海量的对照语料训练集训练出来的,他们是专门用来拆解每个 token 在 Embedding 空间中的逻辑细分子空间用的。
通过这些权重矩阵可以把原始高维向量分解成 8 个细分的 Embedding 向量,每个向量对应到一个细分语义逻辑子空间(语义逻辑、语法逻辑、上下文逻辑、分类逻辑等)。实际上便是把 Attention 机制分割在 Embedding 中的不同细分逻辑子空间中来运作了。每个注意力头互相独立的关注到不同的子空间上下文,同时考虑诸多问题,从而获得更丰富的特征信息。
特征提取
Transformer的多头注意力应该也借鉴了CNN中同一卷积层内使用多个卷积核的思想。CNN中使用了不同的卷积核来关注图像中的不同特征,学习不同的信息。然后CNN中逐通道卷积最后沿着通道求和做特征融合。
Transformer的角色定位是特征抽取器或者万能函数逼近器。我们期望捕捉更多的模式,从而利于下游多样的任务微调时,一旦这类模式有用,就可以激活出来让下游任务可以学习到。所以Transformer使用多头对一个向量切分不同的维度来捕捉不同的模式,让模型可能从多种维度去理解输入句子的含义。单个概念向量可以通过不同的函数进行投影,以用于不同特定领域的任务。然后也会接着一个特征融合过程。映射到不同子空间其实就是在模仿卷积神经网络以支持多通道模式的输出。
ensemble&融合
上面重点说的是将输入切分,然后提取不同子空间的信息。接下来我们从另一个方面来解释,多头的核心思想就是ensemble。
大量学术论文证明,很难只依靠单个头就可以同时捕捉到语法/句法/词法信息,因此需要多头。但是多头中每个头的功能不同,有的头可能识别不到啥信息,有的头可能主要识别位置信息,有的头可能主要识别语法信息,有的头主要识别词法信息。multi-head的作用就是为了保证这些pattern都能够被抽取出来。
我们可以把MHA的多个attention计算视为多个独立的小模型,每个head就像是一个弱分类器,最终整体的concat计算相当于把来自多个小模型的结果进行了融合,从而让最后得到的embedding关注多方面信息。而且,单头容易只关注自身的注意力权重,多头(需要让其有一定的头的基数)无疑是通过多次投票降低这种概率,这样效果比较好也是比较符合直觉的。做个比喻来说,这就好像是八个有不同阅读习惯的翻译家一同翻译同一个句子,他们每个人可能翻译时阅读顺序和关注点都有所不同,综合他们八个人的意见,最终得出来的翻译结果可能会更加准确。
缓解稀疏
通过观察大量样本的attention矩阵我们发现,其实几乎每一个token在全句中的注意力都是稀疏的,即每个token只关注非常有限个其他token,其余注意力基本可以看成是0(softmax无法严格为0)。
稀疏就意味着我们用较小的矩阵就可以来合较大的稀疏矩阵,其效果差不多,但是计算量却小很多。因此就不如把Q、K和V切分成多个小段,计算多次注意力矩阵,再再以某种方式整合,这样一来计算量其实跟直接 算单个注意力差不多,但这样模型融合的效果应该至少不差于单个注意力,甚至可能更好,因此有了多头注意力。
2.3 计算
计算流程
多头注意力的计算流程就是把高维向量切分为若干份低维向量,在若干低维空间内分别求解各自的Scaled Dot-Product Attention(点积自注意力)。总体流程分为:切分,计算,拼接,融合四部分,这里涉及很多步骤和矩阵运算,我们用一张大图把整个过程表示出来。
[*]输入依然是原始的Q,K 和 V。
[*]切分。每个注意头都有自己的可学习权重矩阵\(W_i^Q\), \(W_i^K\)和\(W_i^V\)。输入的Q、K和V经过这些权重矩阵进行多个线性变换后得到 N 组Query,Key 和 Value。这些组Q、K和V可以理解为把输入的高维向量线性投影到比较低的维度上。每个新形成的Q在本质上都要求不同类型的相关信息,从而允许注意力模型在上下文向量计算中引入更多信息。此处对于下图的标号1。
[*]计算。每个头都使用 Self-Attention 计算得到 N 个向量。每个头可以专注学习输入的不同部分,从而使模型能够关注更多的信息。此处对于下图的标号2。
[*]拼接。我们的目标是创建一个单一的上下文向量作为注意力模型的输出。因此,由单个注意头产生的上下文向量被拼接为一个向量。此处对于下图标号3。
[*]融合。使用权重矩阵\(W^O\)以确保生成的上下文向量恢复为原 Embedding 的维度大小。这即是降维操作,也是融合操作。此处对于下图的标号4。
计算强度
我们以下图为基础来思考计算强度,D 表示 hidden size,H 表示 Head 个数,L 表示当前是在序列的第 L 个 Token。
[*]当 Batch Size 为 1 时,图中红色、紫色、蓝色虚线框处的矩阵乘法全部为矩阵乘向量,是 Memory Bound(内存受限操作),算术强度不到 1。
[*]当 Batch Size 大于 1 时(比如 Continuous Batching):
[*]红色和蓝色虚线框部分:因为是权重乘以激活,所以不同的请求之间可以共享 Weight。这里变成矩阵乘矩阵,并且 Batch Size 越大,算术强度越大,也就越趋近于 Compute Bound(FFN 层也类似)。
[*]紫色虚线框部分:这里 Q、K 和 V 的 Attention 计算,是激活乘以激活,所以不同的请求之间没有任何相关性。即使 Batching,这里也是 Batched 矩阵乘向量,并且因为序列长度可能不同,这里不同请求的矩阵乘向量是不规则的。也就是说,这里算术强度始终不到 1,是明显的 Memory Bound。
从上可以看出,通过 Continuous Batching 可以很好的将 Memory Bound 问题转变为 Compute Bound,但 Q、K 和 V 的 Attention 计算的算术强度却始终小于 1。Sequence Length 越长,这里的计算量就越不可忽略,因为其属于系统的短板处。
2.4 效果
Transformer论文末尾给出了多头注意力机制中两个头的attention可视化结果,如下所示。图中,线条越粗表示attention的权重越大,可以看出,两个头关注的地方不一样,绿色图说明该头更关注全局信息,红色图说明该头更关注局部信息。
论文“What Does BERT Look At? An Analysis of BERT’s Attention”也给出了不同注意力头的示例。线条的粗细表示注意力权重的强度(一些注意力权重太低,以至于看不见)。
2.5 融合方式
vanilla Transformer中,对于不同的注意力采取的整合方式是直接拼接。论文"Multi-Head Attention: Collaborate Instead of Concatenate“提出了其它整合方式。该论文发现所有注意力头之间捕捉的信息肯定是存在冗余的,头与头之间存在较多的通用信息。拼接后的 \(
来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!
页:
[1]