问题1:多头注意力的定义和实现
问题2:多头注意力的引入原因和头功能控制
问题3:交叉注意力与自注意力的区别
多头注意力(Multi-Head Attention)将单一的注意力机制分割为多个"头",每个头学习不同的表示子空间。
给定输入序列 \(X \in \mathbb{R}^{n \times d_{model}}\),其中:
多头注意力的数学定义为:
其中第 \(i\) 个头的计算为:
这里:
每个头的维度满足:
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
super(MultiHeadAttention, self).__init__()
assert d_model % num_heads == 0
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
# 线性投影层
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
def forward(self, Q, K, V, mask=None):
batch_size = Q.size(0)
# 线性投影并分割成多个头
# (batch_size, seq_len, d_model) -> (batch_size, num_heads, seq_len, d_k)
Q = self.W_q(Q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
K = self.W_k(K).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
V = self.W_v(V).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
# 计算注意力分数
scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.d_k, dtype=torch.float32))
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
# 应用softmax
attn_weights = F.softmax(scores, dim=-1)
# 加权求和
attn_output = torch.matmul(attn_weights, V)
# 拼接多个头
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
# 最终线性投影
output = self.W_o(attn_output)
return output, attn_weights
所有头共享同一组 K 和 V,仅 Q 是多头:
优势:减少内存消耗和计算量,特别是在推理时显著提升效率。
将 \(h\) 个头分为 \(g\) 组,每组共享一组 K 和 V:
其中 \(\lfloor i/h \rfloor\) 表示头 \(i\) 所属的组索引。
单一注意力机制需要在一个空间中建模所有类型的依赖关系,而多头注意力允许:
假设输入空间的基向量为 \(\{b_1, b_2, ..., b_d\}\),多头注意力通过不同的投影矩阵:
这样每个头可以在其专门的子空间中学习最优的注意力模式。
研究显示不同头确实学习到不同的语言学现象:
观察 \(A_i\) 的模式,可以发现每个头关注的信息类型。
通过逐步移除头来评估每个头的贡献度。
其中 \(\text{sim}(\cdot,\cdot)\) 是相似度函数,最小化该损失可以增加头间的多样性。
其中 \(\alpha_i\) 控制第 \(i\) 个头的重要性。
其中 \(g_i(X)\) 是学习到的门控函数,\(\sigma\) 是sigmoid函数。
交叉注意力(Cross-Attention)用于处理来自不同序列的信息,其核心是使用一个序列生成Query,另一个序列生成Key和Value。
设有两个输入序列:
交叉注意力的计算为:
其中:
所有Q、K、V都来自同一个输入序列 \(X\)。
Q、K、V来自不同序列:
其中 \(n\) 是query序列长度,\(m\) 是key-value序列长度。
# 自注意力
self_attn_output = self_attention(X, X, X)
# 交叉注意力
cross_attn_output = cross_attention(
Q=decoder_hidden, # 来自解码器
K=encoder_output, # 来自编码器
V=encoder_output # 来自编码器
)
多头注意力通过多个子空间学习不同的表示模式
不同的头可以捕捉不同类型的依赖关系
交叉注意力用于连接不同序列,自注意力用于序列内部建模
可以通过正则化和结构化方法控制头的功能