云计算百科
云计算领域专业知识百科平台

PyTorch 实战:从 0 开始搭建 Transformer

  • 导入必要的库
  • python

    import math
    import torch
    import torch.nn as nn
    from LabmL_helpers.module import Module
    from labml_n.utils import clone_module_List
    from typing import Optional, List
    from torch.utils.data import DataLoader, TensorDataset
    from torch import optim
    import torch.nn.functional as F

  • Transformer 模型概述 Transformer 是一种序列到序列的模型,通过自注意力机制并行处理整个序列,能同时考虑序列中的所有元素,并学习上下文之间的关系。其架构包括编码器和解码器部分,每部分都由多个相同的层组成,这些层包含自注意力机制、前馈神经网络,以及归一化和 Dropout 步骤。
  • 核心公式
    • 自注意力计算:Attention(Q,K,V)=softmax(dk​​QKT​)V,其中,Q、K、V分别是查询(Query)、键(Key)和值(Value)矩阵,dk​是键的维度。
    • 多头注意力:将输入分割为多个头,分别计算注意力,然后将结果拼接起来。
    • 位置编码:由于 Transformer 不使用循环结构,因此引入位置编码来保留序列中的位置信息。
  • 自注意力机制
    • 核心原理:计算句子在编码过程中每个位置上的注意力权重,然后以权重和的方式来计算整个句子的隐含向量表示。公式中,首先将 query 与 key 的转置做点积,然后将结果除以dk​​ ,再进行 softmax 计算,最后将结果与 value 做矩阵乘法得到 output。除以dk​​是为了防止QKT过大导致 softmax 计算溢出,且可使QKT结果满足均值为 0,方差 1 的分布。QKT计算本质上是余弦相似度,可表示两个向量在方向上的相似度。
    • 实现
  • python

    import numpy as np
    from math import sqrt
    import torch
    from torch import nn

    class Self_Attention(nn.Module):
    # input : batch_size * seq_len * input_dim
    # q : batch_size * input_dim * dim_k
    # k : batch_size * input_dim * dim_k
    # v : batch_size * input_dim * dim_v
    def __init__(self, input_dim, dim_k, dim_v):
    super(Self_Attention, self).__init__()
    self.q = nn.Linear(input_dim, dim_k)
    self.k = nn.Linear(input_dim, dim_k)
    self.v = nn.Linear(input_dim, dim_v)
    self._norm_fact = 1 / sqrt(dim_k)

    def forward(self, x):
    Q = self.q(x) # Q: batch_size * seq_len * dim_k
    K = self.k(x) # K: batch_size * seq_len * dim_k
    V = self.v(x) # V: batch_size * seq_len * dim_v
    # Q * K.T() # batch_size * seq_len * seq_len
    atten = nn.Softmax(
    dim=-1)(torch.bmm(Q, K.permute(0, 2, 1))) * self._norm_fact
    # Q * K.T() * V # batch_size * seq_len * dim_v
    output = torch.bmm(atten, V)
    return output

    X = torch.randn(4, 3, 2)
    print(X)
    self_atten = Self_Attention(2, 4, 5) # input_dim:2, k_dim:4, v_dim:5
    res = self_atten(X)
    print(res.shape) # [4,3,5]

  • 多头注意力机制 不同于只使用一个注意力池化,将输入x拆分为h份,独立计算h组不同的线性投影来得到各自的 QKV,然后并行计算注意力,最后将h个注意力池化拼接起来并通过另一个可学习的线性投影进行变换以产生输出。每个头可能关注输入的不同部分,可表示更复杂的函数。
  • python

    from math import sqrt
    import torch
    import torch.nn as nn

    class Self_Attention_Muti_Head(nn.Module):
    # input : batch_size * seq_len * input_dim
    # q : batch_size * input_dim * dim_k
    # k : batch_size * input_dim * dim_k
    # v : batch_size * input_dim * dim_v
    def __init__(self, input_dim, dim_k, dim_v, nums_head):
    super(Self_Attention_Muti_Head, self).__init__()
    assert dim_k % nums_head == 0
    assert dim_v % nums_head == 0
    self.q = nn.Linear(input_dim, dim_k)
    self.k = nn.Linear(input_dim, dim_k)
    self.v = nn.Linear(input_dim, dim_v)
    self.nums_head = nums_head
    self.dim_k = dim_k
    self.dim_v = dim_v
    self._norm_fact = 1 / sqrt(dim_k)

    def forward(self, x):
    Q = self.q(x).reshape(-1, x.shape[0], x.shape[1], self.dim_k //
    self.nums_head)
    K = self.k(x).reshape(-1, x.shape[0], x.shape[1], self.dim_k //
    self.nums_head)
    V = self.v(x).reshape(-1, x.shape[0], x.shape[1], self.dim_v //
    self.nums_head)
    print(x.shape)
    print(Q.size())
    atten = nn.Softmax(dim=-1)(torch.matmul(Q, K.permute(0, 1, 3, 2))) # Q * K.T() # batch_size * seq_len * seq_len
    output = torch.matmul(atten, V).reshape(x.shape[0], x.shape[1], -1) # Q * K.T() * V # batch_size * seq_len * dim_v
    return output

    x = torch.rand(1, 3, 4)
    print(x)
    atten = Self_Attention_Muti_Head(4, 4, 4, 2)
    y = atten(x)
    print(y.shape)

  • 视觉注意力机制 attention 机制本质是利用相关特征图学习权重分布,再用学出来的权重施加在原特征图上最后进行加权求和。计算机视觉上的注意力机制主要分为三种:空间域、通道域、混合域。
    • 空间域:将图片中的空间域信息做对应的空间变换,提取关键信息,对空间进行掩码的生成并打分,代表是 Spatial attention module。
    • 通道域:给每个通道上的信号增加一个权重,代表该通道与关键信息的相关度,权重越大相关度越高。对通道生成掩码 mask 进行打分,代表是 senet、channel attention module。
    • 混合域:空间域的注意力忽略了通道域中的信息,将每个通道的图片特征同等处理,这种做法会将空间域变换方法局限在原始特征提取阶段。
  • 通道域注意力(SENet) 通过全局池化提取通道权重,然后对特征图进行改变,得到加强后的特征图。
  • python

    class SELayer(nn.Module):
    def __init__(self, channel, reduction=16):
    super(SELayer, self).__init__()
    self.avg_pool = nn.AdaptiveAvgPool2d(1)
    self.fc = nn.Sequential(
    nn.Linear(channel, channel // reduction, bias=False),
    nn.ReLU(inplace=True),
    nn.Linear(channel // reduction, channel, bias=False),
    nn.Sigmoid()
    )

    def forward(self, x):
    b, c, _, _ = x.size()
    y = self.avg_pool(x).view(b, c) # 对应Squeeze操作
    y = self.fc(y).view(b, c, 1, 1) # 对应Excitation操作
    return x * y.expand_as(x)

  • 门控注意力机制(GCT,Gated Channel Transformation) GCT 是一种简单有效的通道间建模关系体系结构,能显著提高卷积网络在视觉任务的泛化能力。论文发现将门控机制放在 Conv 层前面训练效果最好。GCT 包含三个部分:
    • Global Context Embedding:设计了一种全局上下文嵌入模块,用于每个通道的全局上下文信息汇聚,公式为sc​=αc​∥xc​∥2​=αc​{[∑i=1H​∑j=1W​(xci,j​)2]+ϵ}21​。
    • Channel Normalization:对第一步计算的 L2 进行规范化来构建神经元竞争关系,使用跨通道的特征规范化,公式为s^c​=∥s∥2​C​sc​​=[(∑c=1C​sc2​)+ϵ]21​C​sc​​。
    • Gating Adaptation:加入门限机制,公式为x^c​=xc​[1+tanh(γc​s^c​+βc​)] 。
  • python

    class GCT(nn.Module):
    def __init__(self, num_channels, epsilon=1e-5, mode='l2', after_relu=False):
    super(GCT, self).__init__()
    self.alpha = nn.Parameter(torch.ones(1, num_channels, 1, 1))
    self.gamma = nn.Parameter(torch.zeros(1, num_channels, 1, 1))
    self.beta = nn.Parameter(torch.zeros(1, num_channels, 1, 1))
    self.epsilon = epsilon
    self.mode = mode
    self.after_relu = after_relu

    def forward(self, x):
    if self.mode == 'l2':
    embedding = (x.pow(2).sum((2, 3), keepdim=True) +
    self.epsilon).pow(0.5) * self.alpha
    norm = self.gamma / \\
    (embedding.pow(2).mean(dim=1, keepdim=True) +
    self.epsilon).pow(0.5)
    elif self.mode == 'l1':
    if not self.after_relu:
    _x = torch.abs(x)
    else:
    _x = x
    embedding = _x.sum((2, 3), keepdim=True) * self.alpha
    norm = self.gamma / \\
    (torch.abs(embedding).mean(dim=1, keepdim=True) + self.epsilon)
    gate = 1. + torch.tanh(embedding * norm + self.beta)
    return x * gate

    GCT 建议添加在 Conv 层前,一般可以先冻结原来的模型,来训练 GCT,然后解冻再进行微调。

    赞(0)
    未经允许不得转载:网硕互联帮助中心 » PyTorch 实战:从 0 开始搭建 Transformer
    分享到: 更多 (0)

    评论 抢沙发

    评论前必须登录!