Transformer 架构详解:自注意力机制与多头注意力

Transformer 架构详解:自注意力机制与多头注意力

🎤 欢迎来到今天的讲座!

大家好,欢迎来到今天的讲座!今天我们要聊的是近年来在自然语言处理(NLP)领域风靡一时的 Transformer 架构。尤其是其中最核心的部分——自注意力机制(Self-Attention)多头注意力(Multi-Head Attention)。如果你对这些概念还不是很熟悉,别担心,我会用尽可能轻松诙谐的语言,结合代码和表格,带你一步步理解它们。

📚 什么是 Transformer?

首先,让我们简单回顾一下 Transformer 是什么。Transformer 是一种基于注意力机制的神经网络架构,最早由 Google 在 2017 年的论文《Attention is All You Need》中提出。它彻底改变了传统的 NLP 模型设计,摒弃了 RNN 和 LSTM 等依赖于序列顺序的模型,转而使用并行化的方式处理输入数据。这使得 Transformer 在处理长文本时更加高效,并且在各种 NLP 任务中取得了前所未有的成功。

🔍 自注意力机制(Self-Attention)

1. 为什么需要自注意力?

在传统的 RNN 或 LSTM 模型中,输入序列是按顺序逐个处理的,这意味着每个时间步只能看到之前的上下文信息。然而,对于一些长句子或复杂语境的任务,这种顺序处理方式可能会导致信息丢失或传播不充分。为了解决这个问题,自注意力机制 应运而生。

自注意力机制的核心思想是:让每个词都关注整个句子中的其他词,从而捕获更丰富的上下文信息。想象一下,你正在读一句话,某个词可能不仅仅依赖于它前面的词,还可能与后面的词有重要的关联。自注意力机制就是为了让模型能够“同时”看到所有词的关系。

2. 自注意力的工作原理

自注意力机制的具体实现可以分为以下几个步骤:

  • Query、Key 和 Value:在自注意力机制中,每个词都会被映射成三个向量:Query (Q)Key (K)Value (V)。这三个向量的作用分别是:

    • Query (Q):表示当前词要“查询”的内容。
    • Key (K):表示其他词的“特征”。
    • Value (V):表示其他词的“值”。

    这些向量是通过线性变换从输入嵌入(Embedding)得到的。假设输入是一个形状为 (seq_len, d_model) 的矩阵,其中 seq_len 是序列长度,d_model 是模型的维度。那么,Query、Key 和 Value 的计算公式如下:

    Q = X @ W_Q
    K = X @ W_K
    V = X @ W_V

    其中,X 是输入嵌入矩阵,W_QW_KW_V 是可学习的权重矩阵。

  • 计算注意力分数:接下来,我们需要计算每个词与其他词之间的相似度。这个相似度通常使用 点积 来衡量,即 QK 的点积。为了防止数值过大,我们还会对点积结果进行 缩放,通常除以 sqrt(d_k),其中 d_kK 的维度。最后,我们使用 softmax 函数将这些分数归一化为概率分布。

    scores = Q @ K.T / math.sqrt(d_k)
    attention_weights = softmax(scores)
  • 加权求和:有了注意力权重后,我们就可以根据这些权重对 V 进行加权求和,得到最终的输出。这个过程可以看作是每个词根据它与其他词的相似度,选择性地“吸收”其他词的信息。

    output = attention_weights @ V

3. 代码示例

为了更好地理解自注意力机制,我们可以通过一个简单的 PyTorch 代码来实现它:

import torch
import torch.nn as nn
import math

class SelfAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(SelfAttention, self).__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # 定义 Query, Key, Value 的线性变换层
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)

    def forward(self, x):
        batch_size, seq_len, d_model = x.size()

        # 计算 Query, Key, Value
        Q = self.W_Q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_K(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_V(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        # 计算注意力分数
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        attention_weights = torch.softmax(scores, dim=-1)

        # 加权求和
        output = torch.matmul(attention_weights, V)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)

        return output

🧠 多头注意力(Multi-Head Attention)

1. 为什么需要多头注意力?

虽然自注意力机制已经非常强大,但它仍然有一个局限性:它只能捕捉到一种类型的“注意力”。换句话说,它只能在一个固定的视角下看待词与词之间的关系。然而,现实中的语言是非常复杂的,不同词之间可能存在多种不同的关系。例如,某个词可能在语法上与另一个词有关联,但在语义上却与第三个词更相关。

为了解决这个问题,多头注意力 被引入。它的核心思想是:使用多个独立的自注意力机制,每个机制负责捕捉不同类型的关系。然后,我们将这些不同“头”的输出拼接在一起,形成一个更丰富的表示。

2. 多头注意力的工作原理

多头注意力的实现非常简单,它实际上就是在自注意力的基础上做了两件事:

  1. 并行计算多个自注意力:我们可以在同一个输入上并行计算多个自注意力机制,每个机制都有自己的 W_QW_KW_V 权重矩阵。这样,每个“头”都可以捕捉到不同的信息。

  2. 拼接并融合:将多个“头”的输出拼接在一起,然后通过一个线性变换将其投影回原始的维度。这个过程可以用以下公式表示:

    multi_head_output = Concat(head_1, head_2, ..., head_h) @ W_O

    其中,head_i 表示第 i 个自注意力机制的输出,W_O 是一个可学习的权重矩阵,用于将拼接后的输出投影回原始维度。

3. 代码示例

我们可以在之前的 SelfAttention 类基础上,扩展出一个多头注意力的实现:

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # 定义多头自注意力
        self.attention = SelfAttention(d_model, num_heads)

        # 定义输出的线性变换层
        self.linear = nn.Linear(d_model, d_model)

    def forward(self, x):
        # 计算多头自注意力
        multi_head_output = self.attention(x)

        # 线性变换
        output = self.linear(multi_head_output)

        return output

📊 总结与表格

为了帮助大家更好地理解自注意力和多头注意力的区别,我们可以用一个表格来对比它们:

特性 自注意力机制 多头注意力机制
目的 捕捉词与词之间的单一类型关系 捕捉词与词之间的多种类型关系
计算方式 单个 Q, K, V 多个独立的 Q, K, V
输出维度 与输入维度相同 与输入维度相同
优点 简单有效,能够捕捉全局依赖关系 更加灵活,能够捕捉多种类型的依赖关系
缺点 只能捕捉一种类型的依赖关系 参数量较大,计算成本较高

🎉 结语

好了,今天的讲座就到这里啦!我们详细介绍了 Transformer 架构中的 自注意力机制多头注意力机制,并通过代码示例展示了它们的实现过程。希望你能通过这次讲座对这两个概念有更深入的理解。如果你有任何问题,欢迎随时提问!

下次见!🌟

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注