云计算百科
云计算领域专业知识百科平台

Self-Attention原理和实现代码(Pytorch实现)

文章目录

  • 一、Self-Attention原理
  • 二、Self-Attention代码实现
    • 1.单头注意力
      • 单头注意力的代码实现
    • 2.多头注意力
      • 多头注意力的计算流程
      • 多头注意力的代码实现
  • 总结

一、Self-Attention原理

Self-Attention(自注意力)机制是Transformer架构中的核心组件,主要用于捕捉序列中不同位置元素之间的依赖关系。其核心思想是通过计算序列中每个元素与其他元素之间的相关性(注意力权重),然后根据这些权重对序列信息进行加权聚合。计算过程主要分为以下几步:

  • 输入表示: 对于输入序列

    X

    X

    X中的每个元素,通过三个不同的权重矩阵

    W

    Q

    W^Q

    WQ

    W

    K

    W^K

    WK

    W

    V

    W^V

    WV,分别生成:

    查询向量(Query):

    Q

    =

    X

    W

    Q

    Q = X \\cdot W^Q

    Q=XWQ 键向量(Key):

    K

    =

    X

    W

    K

    K = X \\cdot W^K

    K=XWK 值向量(Value):

    V

    =

    X

    W

    V

    V = X \\cdot W^V

    V=XWV 其中,Q、K、V 承担不同的角色。 Q (Query): 查询向量,表示当前需要关注的内容。例如在机器翻译中,解码器当前位置的隐状态作为 Query,用于查询源语言的相关信息。 K (Key): 键向量,表示待匹配的索引。Key 与 Value 关联,用于计算与 Query 的相似度。 V (Value): 值向量,存储实际需要提取的信息。相似度计算后,Value 会按权重聚合生成输出。

  • 计算注意力分数: 通过计算

    Q

    Q

    Q

    K

    K

    K的点积,得到元素间的相似度分数:

    Scores

    =

    Q

    K

    \\text{Scores} = Q \\cdot K^{\\top}

    Scores=QK 为了稳定梯度,分数会除以缩放因子

    d

    k

    \\sqrt{d_k}

    dk

    d

    k

    d_k

    dk

    K

    K

    K的维度)。

  • 生成注意力权重: 对分数应用Softmax函数,得到归一化的注意力权重:

    AttentionWeight

    =

    Softmax

    (

    Q

    K

    d

    k

    )

    \\text{AttentionWeight} = \\text{Softmax}\\left(\\frac{Q \\cdot K^{\\top}}{\\sqrt{d_k}}\\right)

    AttentionWeight=Softmax(dk

    QK)

  • 加权聚合输出: 用注意力权重对

    V

    V

    V加权求和,得到最终输出:

    Output

    =

    AttentionWeight

    V

    \\text{Output} = \\text{AttentionWeight} \\cdot V

    Output=AttentionWeightV

  • 综上所述,Self-Attention的完整公式如下:

    Attention(Q,K,V)

    =

    Softmax

    (

    Q

    K

    d

    k

    )

    V

    \\text{Attention(Q,K,V)} = \\text{Softmax}\\left(\\frac{Q \\cdot K^{\\top}}{\\sqrt{d_k}}\\right) \\cdot V

    Attention(Q,K,V)=Softmax(dk

    QK)V 核心作用: Self-Attention 让模型能够动态关注输入序列的不同部分(例如句子中与当前词相关的其他词),从而更好地理解上下文依赖关系。

    二、Self-Attention代码实现

    1.单头注意力

    单头注意力的代码实现

    代码如下(示例):

    class Attention(nn.Module):
    def __init__(self, d_model, head_size, context_length, dropout=0.1):
    '''
    d_model为输入序列的语义维度
    head_size为一个注意力头在语义维度占据的大小,head_size = d_model/num_heads
    context_len为输入序列的长度
    '''

    super().__init__()
    self.head_size = head_size
    self.Wq = nn.Linear(d_model, head_size, bias=False)
    self.Wk = nn.Linear(d_model, head_size, bias=False)
    self.Wv = nn.Linear(d_model, head_size, bias=False)
    self.register_buffer('mask', torch.tril(torch.ones(context_length, context_length)))
    self.dropout = nn.Dropout(dropout)

    def forward(self, x):
    B, T, C = x.shape
    q = self.Wq(x)
    k = self.Wk(x)
    v = self.Wv(x)
    weights = (q @ k.transpose(2, 1)) / math.sqrt(self.head_size)
    weights = weights.masked_fill(self.mask[:T, :T] == 0, float('-inf'))
    weights = F.softmax(weights, dim=1)

    return weights @ v

    2.多头注意力

    多头注意力的核心思想:使用多组(

    h

    h

    h 个头)不同的查询

    Q

    Q

    Q、键

    K

    K

    K、值

    V

    V

    V 投影(线性变换),让模型能够并行地从不同的表示子空间(子空间维度

    d

    k

    d_k

    dk,

    d

    v

    d_v

    dv,

    d

    m

    o

    d

    e

    l

    /

    h

    d_{model}/h

    dmodel/h)学习信息。每个头学习不同的关注模式。

    多头注意力的计算流程

    假设: 输入维度:

    d

    m

    o

    d

    e

    l

    d_{model}

    dmodel(例如 512) 头数:

    h

    h

    h(例如 8) 每个头的维度:

    d

    k

    =

    d

    v

    =

    d

    m

    o

    d

    e

    l

    /

    h

    d_k = d_v = d_{model} / h

    dk=dv=dmodel/h(例如 512 / 8 = 64) 步骤:

  • 线性投影(生成 h 组 Q, K, V): 对原始的查询

    Q

    Q

    Q、键

    K

    K

    K、值

    V

    V

    V(维度均为

    d

    m

    o

    d

    e

    l

    d_{model}

    dmodel)分别应用

    h

    h

    h 组不同的线性变换(权重矩阵)。 得到

    h

    h

    h 组投影后的查询

    Q

    i

    Q_i

    Qi、键

    K

    i

    K_i

    Ki、值

    V

    i

    V_i

    Vi,每组维度为

    d

    k

    d_k

    dk,

    d

    k

    d_k

    dk,

    d

    v

    d_v

    dv(通常

    d

    k

    =

    d

    v

    =

    d

    m

    o

    d

    e

    l

    /

    h

    d_k = d_v = d_{model}/h

    dk=dv=dmodel/h)。

    Q

    i

    =

    Q

    W

    i

    Q

    ,

    K

    i

    =

    K

    W

    i

    K

    ,

    V

    i

    =

    V

    W

    i

    V

    for 

    i

    =

    1

    ,

    .

    .

    .

    ,

    h

    Q_i = Q W_i^Q, \\quad K_i = K W_i^K, \\quad V_i = V W_i^V \\quad \\text{for } i = 1, …, h

    Qi=QWiQ,Ki=KWiK,Vi=VWiVfor i=1,,h 其中

    W

    i

    Q

    ,

    W

    i

    K

    ,

    W

    i

    V

    W_i^Q, W_i^K, W_i^V

    WiQ,WiK,WiV 是可学习的投影矩阵。

  • 并行计算缩放点积注意力: 对每一组投影后的

    Q

    i

    ,

    K

    i

    ,

    V

    i

    Q_i, K_i, V_i

    Qi,Ki,Vi,独立计算缩放点积注意力:

    head

    i

    =

    Attention

    (

    Q

    i

    ,

    K

    i

    ,

    V

    i

    )

    =

    softmax

    (

    Q

    i

    K

    i

    T

    d

    k

    )

    V

    i

    \\text{head}_i = \\text{Attention}(Q_i, K_i, V_i) = \\text{softmax}(\\frac{Q_i K_i^T}{\\sqrt{d_k}}) V_i

    headi=Attention(Qi,Ki,Vi)=softmax(dk

    QiKiT)Vi

  • 拼接多头输出: 将

    h

    h

    h 个注意力头计算的结果

    head

    1

    ,

    head

    2

    ,

    .

    .

    .

    ,

    head

    h

    \\text{head}_1, \\text{head}_2, …, \\text{head}h

    head1,head2,,headh(每个维度为

    d

    v

    d_v

    dv)拼接起来,得到一个维度为

    h

    ×

    d

    v

    =

    d

    m

    o

    d

    e

    l

    h \\times d_v = d{model}

    h×dv=dmodel 的向量。

    MultiHead

    (

    Q

    ,

    K

    ,

    V

    )

    =

    Concat

    (

    head

    1

    ,

    head

    2

    ,

    .

    .

    .

    ,

    head

    h

    )

    \\text{MultiHead}(Q, K, V) = \\text{Concat}(\\text{head}_1, \\text{head}_2, …, \\text{head}_h)

    MultiHead(Q,K,V)=Concat(head1,head2,,headh)

  • 最终线性投影(可选,但常用): 将拼接后的结果通过另一个线性变换

    W

    O

    W^O

    WO 投影到最终的输出维度(通常是

    d

    m

    o

    d

    e

    l

    d_{model}

    dmodel)。

    Output

    =

    MultiHead

    (

    Q

    ,

    K

    ,

    V

    )

    W

    O

    \\text{Output} = \\text{MultiHead}(Q, K, V) W^O

    Output=MultiHead(Q,K,V)WO 其中

    W

    O

    W^O

    WO 是维度为

    d

    m

    o

    d

    e

    l

    ×

    d

    m

    o

    d

    e

    l

    d_{model} \\times d_{model}

    dmodel×dmodel 的可学习权重矩阵。

  • 多头注意力的代码实现

    多头注意力的代码实现方法十分简单,只需使用到上面单头注意力的代码即可。

    代码如下:

    class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, head_size, context_length, dropout=0.1):
    super().__init__()
    self.heads = nn.ModuleList([
    Attention(d_model, head_size, context_length, dropout)
    for _ in range(num_heads)
    ])
    self.projection_layer = nn.Linear(d_model, d_model)
    self.dropout = nn.Dropout(dropout)

    def forward(self, x):
    head_outputs = torch.cat([head(x) for head in self.heads], dim=1)
    return self.dropout(self.projection_layer(head_outputs))


    总结

    以上就是今天要讲的内容,本文仅仅简单介绍了Self-Attention的原理和简单实现,Self-Attention是大模型的核心机制需要认真学习。

    赞(0)
    未经允许不得转载:网硕互联帮助中心 » Self-Attention原理和实现代码(Pytorch实现)
    分享到: 更多 (0)

    评论 抢沙发

    评论前必须登录!