12.04课堂作业:多头注意力机制

2024年12月4日 Transformer

作业题目

问题1:多头注意力的定义和实现

问题2:多头注意力的引入原因和头功能控制

问题3:交叉注意力与自注意力的区别

问题1:多头注意力的定义和实现

1.1 多头注意力的头是如何定义的

多头注意力(Multi-Head Attention)将单一的注意力机制分割为多个"头",每个头学习不同的表示子空间。

输入定义

给定输入序列 \(X \in \mathbb{R}^{n \times d_{model}}\),其中:

  • \(n\) 是序列长度
  • \(d_{model}\) 是模型维度

数学定义

多头注意力的数学定义为:

$$\text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_1, \text{head}_2, ..., \text{head}_h)W^O$$

其中第 \(i\) 个头的计算为:

$$\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$

这里:

  • \(h\):头的数量(通常为8或12)
  • \(W_i^Q \in \mathbb{R}^{d_{model} \times d_k}\):第 \(i\) 个头的查询权重矩阵
  • \(W_i^K \in \mathbb{R}^{d_{model} \times d_k}\):第 \(i\) 个头的键权重矩阵
  • \(W_i^V \in \mathbb{R}^{d_{model} \times d_v}\):第 \(i\) 个头的值权重矩阵
  • \(W^O \in \mathbb{R}^{h \times d_v \times d_{model}}\):输出投影矩阵

每个头的维度满足:

$$d_k = d_v = \frac{d_{model}}{h}$$

1.2 具体实现步骤

1

线性投影

$$Q_i = XW_i^Q, \quad K_i = XW_i^K, \quad V_i = XW_i^V$$
2

并行注意力计算

$$\text{head}_i = \text{softmax}\left(\frac{Q_iK_i^T}{\sqrt{d_k}}\right)V_i$$
3

拼接与投影

$$\text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O$$

1.3 代码实现

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # 线性投影层
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
    
    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)
        
        # 线性投影并分割成多个头
        # (batch_size, seq_len, d_model) -> (batch_size, num_heads, seq_len, d_k)
        Q = self.W_q(Q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(K).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(V).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        
        # 计算注意力分数
        scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.d_k, dtype=torch.float32))
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        # 应用softmax
        attn_weights = F.softmax(scores, dim=-1)
        
        # 加权求和
        attn_output = torch.matmul(attn_weights, V)
        
        # 拼接多个头
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        
        # 最终线性投影
        output = self.W_o(attn_output)
        
        return output, attn_weights

1.4 两种多头注意力的变体

变体1:Multi-Query Attention(MQA)

所有头共享同一组 K 和 V,仅 Q 是多头:

$$\text{head}_i = \text{Attention}(XW_i^Q, XW^K, XW^V)$$

优势:减少内存消耗和计算量,特别是在推理时显著提升效率。

变体2:Grouped Query Attention(GQA)

将 \(h\) 个头分为 \(g\) 组,每组共享一组 K 和 V:

$$\text{head}_i = \text{Attention}(XW_i^Q, XW^{\lfloor i/h \rfloor \cdot K}, XW^{\lfloor i/h \rfloor \cdot V})$$

其中 \(\lfloor i/h \rfloor\) 表示头 \(i\) 所属的组索引。

问题2:多头注意力的引入原因和头功能控制

2.1 为什么要引入多头注意力

2.1.1 表示子空间多样化

单一注意力机制需要在一个空间中建模所有类型的依赖关系,而多头注意力允许:

  • 不同头学习不同模式:如句法依赖、语义关联、长距离依赖等
  • 并行计算多个视角:类似于CNN中不同卷积核学习不同特征
  • 增强表达能力:组合多个头的表示可以得到更丰富的特征

2.1.2 数学解释

假设输入空间的基向量为 \(\{b_1, b_2, ..., b_d\}\),多头注意力通过不同的投影矩阵:

$$W_i^Q, W_i^K, W_i^V \text{ 定义不同的子空间 } S_i = \text{span}(W_i^Q, W_i^K, W_i^V)$$

这样每个头可以在其专门的子空间中学习最优的注意力模式。

2.1.3 实验验证

研究显示不同头确实学习到不同的语言学现象:

  • 头1-2:主要关注句法结构
  • 头3-4:处理语义关联
  • 头5-6:捕捉长距离依赖
  • 头7-8:处理位置信息

2.2 如何验证每个头的功能

2.2.1 可视化分析

1. 注意力权重可视化
$$A_i = \text{softmax}\left(\frac{Q_iK_i^T}{\sqrt{d_k}}\right)$$

观察 \(A_i\) 的模式,可以发现每个头关注的信息类型。

2. 特征分析
  • 计算每个头输出的相似度矩阵
  • 分析头之间的信息冗余性
  • 使用t-SNE等降维方法可视化头的表示

2.2.2 消融实验

1. 单头性能测试
$$\text{Performance}_i = f(\text{仅使用head}_i)$$
2. 头组合测试
$$\text{Performance}_{S} = f(\{\text{head}_i | i \in S\})$$
3. 头重要性排序

通过逐步移除头来评估每个头的贡献度。

2.3 训练中控制头的功能

2.3.1 正则化方法

1. 头多样性正则化
$$\mathcal{L}_{div} = \sum_{i \neq j} \text{sim}(\text{head}_i, \text{head}_j)$$

其中 \(\text{sim}(\cdot,\cdot)\) 是相似度函数,最小化该损失可以增加头间的多样性。

2. 头重要性正则化
$$\mathcal{L}_{imp} = \sum_{i} \alpha_i \|A_i - A_{target,i}\|^2$$

其中 \(\alpha_i\) 控制第 \(i\) 个头的重要性。

2.3.2 结构化控制

1. 头专门化架构
  • 为不同任务设计专门的头
  • 使用门控机制控制头的激活
  • 引入层次化的头结构
2. 动态头选择
$$\text{head}_i^{active} = \sigma(g_i(X)) \cdot \text{head}_i$$

其中 \(g_i(X)\) 是学习到的门控函数,\(\sigma\) 是sigmoid函数。

问题3:交叉注意力与自注意力的区别

3.1 交叉注意力的定义

交叉注意力(Cross-Attention)用于处理来自不同序列的信息,其核心是使用一个序列生成Query,另一个序列生成Key和Value。

设有两个输入序列:

  • 源序列:\(X^{src} \in \mathbb{R}^{n \times d_{model}}\)
  • 目标序列:\(X^{tgt} \in \mathbb{R}^{m \times d_{model}}\)

交叉注意力的计算为:

$$\text{CrossAttention}(X^{tgt}, X^{src}, X^{src}) = \text{softmax}\left(\frac{Q^{tgt} (K^{src})^T}{\sqrt{d_k}}\right)V^{src}$$

其中:

  • \(Q^{tgt} = X^{tgt}W^Q\):从目标序列生成Query
  • \(K^{src} = X^{src}W^K\):从源序列生成Key
  • \(V^{src} = X^{src}W^V\):从源序列生成Value

3.2 QKV矩阵的区别

自注意力(Self-Attention)

$$Q = XW^Q, \quad K = XW^K, \quad V = XW^V$$

所有Q、K、V都来自同一个输入序列 \(X\)。

交叉注意力(Cross-Attention)

$$Q = X_{query}W^Q, \quad K = X_{key}W^K, \quad V = X_{value}W^V$$

Q、K、V来自不同序列

  • Query来自目标序列(或解码器序列)
  • Key和Value来自源序列(或编码器序列)

3.3 应用场景区别

自注意力应用

  • Encoder端:理解输入序列内部关系
  • Decoder端:理解已生成序列的上下文
  • 自编码任务:如BERT的MLM任务

交叉注意力应用

  • Encoder-Decoder架构:如Transformer的解码器
  • 机器翻译:源语言到目标语言
  • 问答系统:问题到答案的映射
  • 多模态学习:文本到图像的对应

3.4 计算复杂度比较

自注意力复杂度

  • 时间复杂度:\(O(n^2 \cdot d)\)
  • 空间复杂度:\(O(n^2)\)

交叉注意力复杂度

  • 时间复杂度:\(O(n \cdot m \cdot d)\)
  • 空间复杂度:\(O(n \cdot m)\)

其中 \(n\) 是query序列长度,\(m\) 是key-value序列长度。

3.5 信息流向

自注意力

  • 信息在序列内部流动
  • 每个位置可以访问序列中的所有其他位置
  • 对称的信息交互

交叉注意力

  • 信息从源序列流向目标序列
  • 目标序列的每个位置访问源序列的所有位置
  • 非对称的信息检索

3.6 代码对比

# 自注意力
self_attn_output = self_attention(X, X, X)

# 交叉注意力
cross_attn_output = cross_attention(
    Q=decoder_hidden,  # 来自解码器
    K=encoder_output,  # 来自编码器
    V=encoder_output   # 来自编码器
)

总结

关键要点

多头注意力通过多个子空间学习不同的表示模式

不同的头可以捕捉不同类型的依赖关系

交叉注意力用于连接不同序列,自注意力用于序列内部建模

可以通过正则化和结构化方法控制头的功能