云计算百科
云计算领域专业知识百科平台

解析 Xavier 初始化:原理、推导与实践

解析 Xavier 初始化:原理、推导与实践

在深度神经网络训练中,初始化策略是决定模型能否成功收敛的关键因素之一。本文将解析由 Xavier Glorot 提出的 Xavier 初始化(又称 Glorot 初始化)。


一、神经网络初始化的挑战

在深度神经网络中,信号需要经过多层传递:

  • 前向传播:输入数据逐层变换得到预测结果
  • 反向传播:梯度从输出层回传更新权重

若初始化不当,会导致两大问题:

  • 梯度消失:权重初始化过小 → 激活值方差逐层衰减 → 梯度趋近于零# 错误示例:权重过小
    W = 0.01 * np.random.randn(fan_in, fan_out)

  • 梯度爆炸:权重初始化过大 → 激活值方差逐层放大 → 梯度数值溢出# 错误示例:权重过大
    W = 1.0 * np.random.randn(fan_in, fan_out)

  • Xavier 初始化的核心目标:保持各层激活值方差稳定,确保前向传播的信号强度和反向传播的梯度强度在初始化时不衰减也不爆炸。


    二、核心思想:方差一致性原则

    Xavier 初始化基于以下重要观察:
    理想情况下,任意层的输入信号方差应等于其输出信号方差:
    Var(a(l−1))≈Var(a(l))
    \\text{Var}(a^{(l-1)}) \\approx \\text{Var}(a^{(l)})
    Var(a(l1))Var(a(l))

    其中:

    • a(l−1)a^{(l-1)}a(l1):第 lll 层的输入(即前一层输出)
    • a(l)a^{(l)}a(l):第 lll 层的输出

    通过数学推导(详见第三节),这一目标可转化为对权重分布的约束条件。


    三、数学推导:前向传播的方差分析

    考虑第 lll 层的线性变换(忽略偏置):
    zi(l)=∑j=1ninwij(l)aj(l−1)
    z_i^{(l)} = \\sum_{j=1}^{n_{\\text{in}}} w_{ij}^{(l)} a_j^{(l-1)}
    zi(l)=j=1ninwij(l)aj(l1)

    核心假设:
  • 权重 wij(l)∼W(l)w_{ij}^{(l)} \\sim \\mathcal{W}^{(l)}wij(l)W(l) 独立同分布(i.i.d)且 E[wij(l)]=0\\mathbb{E}[w_{ij}^{(l)}] = 0E[wij(l)]=0
  • 输入 aj(l−1)∼A(l−1)a_j^{(l-1)} \\sim \\mathcal{A}^{(l-1)}aj(l1)A(l1) 独立同分布(i.i.d)且 E[aj(l−1)]=0\\mathbb{E}[a_j^{(l-1)}] = 0E[aj(l1)]=0
  • 权重与输入相互独立:wij(l)⊥ak(l−1)w_{ij}^{(l)} \\perp a_k^{(l-1)}wij(l)ak(l1)
  • 方差推导过程:

    根据方差定义:
    Var(zi(l))=E[(zi(l))2]−(E[zi(l)])2
    \\text{Var}(z_i^{(l)}) = \\mathbb{E}\\left[(z_i^{(l)})^2\\right] – \\left(\\mathbb{E}[z_i^{(l)}]\\right)^2
    Var(zi(l))=E[(zi(l))2](E[zi(l)])2

    步骤1:计算期望 E[zi(l)]\\mathbb{E}[z_i^{(l)}]E[zi(l)]
    E[zi(l)]=E[∑j=1ninwij(l)aj(l−1)]=∑j=1ninE[wij(l)]E[aj(l−1)]=0
    \\mathbb{E}[z_i^{(l)}] = \\mathbb{E}\\left[\\sum_{j=1}^{n_{\\text{in}}} w_{ij}^{(l)} a_j^{(l-1)}\\right] = \\sum_{j=1}^{n_{\\text{in}}} \\mathbb{E}[w_{ij}^{(l)}] \\mathbb{E}[a_j^{(l-1)}] = 0
    E[zi(l)]=E[j=1ninwij(l)aj(l1)]=j=1ninE[wij(l)]E[aj(l1)]=0

    步骤2:计算 E[(zi(l))2]\\mathbb{E}\\left[(z_i^{(l)})^2\\right]E[(zi(l))2]
    E[(zi(l))2]=E[(∑j=1ninwij(l)aj(l−1))2]=E[∑j=1nin∑k=1ninwij(l)wik(l)aj(l−1)ak(l−1)]
    \\mathbb{E}\\left[(z_i^{(l)})^2\\right] = \\mathbb{E}\\left[\\left(\\sum_{j=1}^{n_{\\text{in}}} w_{ij}^{(l)} a_j^{(l-1)}\\right)^2\\right] = \\mathbb{E}\\left[\\sum_{j=1}^{n_{\\text{in}}} \\sum_{k=1}^{n_{\\text{in}}} w_{ij}^{(l)} w_{ik}^{(l)} a_j^{(l-1)} a_k^{(l-1)}\\right]
    E[(zi(l))2]=E(j=1ninwij(l)aj(l1))2=E[j=1nink=1ninwij(l)wik(l)aj(l1)ak(l1)]

    步骤3:分解双重求和
    =∑j=1nin∑k=1ninE[wij(l)wik(l)aj(l−1)ak(l−1)]
    = \\sum_{j=1}^{n_{\\text{in}}} \\sum_{k=1}^{n_{\\text{in}}} \\mathbb{E}\\left[w_{ij}^{(l)} w_{ik}^{(l)} a_j^{(l-1)} a_k^{(l-1)}\\right]
    =j=1nink=1ninE[wij(l)wik(l)aj(l1)ak(l1)]

    步骤4:处理交叉项(j≠kj \\neq kj=k
    j≠kj \\neq kj=k 时,由独立性假设:
    E[wij(l)wik(l)aj(l−1)ak(l−1)]=E[wij(l)]E[wik(l)]E[aj(l−1)]E[ak(l−1)]=0
    \\mathbb{E}\\left[w_{ij}^{(l)} w_{ik}^{(l)} a_j^{(l-1)} a_k^{(l-1)}\\right] = \\mathbb{E}[w_{ij}^{(l)}] \\mathbb{E}[w_{ik}^{(l)}] \\mathbb{E}[a_j^{(l-1)}] \\mathbb{E}[a_k^{(l-1)}] = 0
    E[wij(l)wik(l)aj(l1)ak(l1)]=E[wij(l)]E[wik(l)]E[aj(l1)]E[ak(l1)]=0

    步骤5:处理对角项(j=kj = kj=k
    j=kj = kj=k 时:
    E[(wij(l))2(aj(l−1))2]=E[(wij(l))2]E[(aj(l−1))2]
    \\mathbb{E}\\left[(w_{ij}^{(l)})^2 (a_j^{(l-1)})^2\\right] = \\mathbb{E}\\left[(w_{ij}^{(l)})^2\\right] \\mathbb{E}\\left[(a_j^{(l-1)})^2\\right]
    E[(wij(l))2(aj(l1))2]=E[(wij(l))2]E[(aj(l1))2]

    步骤6:利用零均值性质
    E[(wij(l))2]=Var(w(l))+(E[wij(l)])2=Var(w(l))
    \\mathbb{E}\\left[(w_{ij}^{(l)})^2\\right] = \\text{Var}(w^{(l)}) + \\left(\\mathbb{E}[w_{ij}^{(l)}]\\right)^2 = \\text{Var}(w^{(l)})
    E[(wij(l))2]=Var(w(l))+(E[wij(l)])2=Var(w(l))

    E[(aj(l−1))2]=Var(a(l−1))
    \\mathbb{E}\\left[(a_j^{(l-1)})^2\\right] = \\text{Var}(a^{(l-1)})
    E[(aj(l1))2]=Var(a(l1))

    步骤7:合并结果
    E[(zi(l))2]=∑j=1ninVar(w(l))⋅Var(a(l−1))=nin⋅Var(w(l))⋅Var(a(l−1))
    \\mathbb{E}\\left[(z_i^{(l)})^2\\right] = \\sum_{j=1}^{n_{\\text{in}}} \\text{Var}(w^{(l)}) \\cdot \\text{Var}(a^{(l-1)}) = n_{\\text{in}} \\cdot \\text{Var}(w^{(l)}) \\cdot \\text{Var}(a^{(l-1)})
    E[(zi(l))2]=j=1ninVar(w(l))Var(a(l1))=ninVar(w(l))Var(a(l1))

    最终方差表达式:
    Var(zi(l))=nin⋅Var(w(l))⋅Var(a(l−1))
    \\boxed{\\text{Var}(z_i^{(l)}) = n_{\\text{in}} \\cdot \\text{Var}(w^{(l)}) \\cdot \\text{Var}(a^{(l-1)})}
    Var(zi(l))=ninVar(w(l))Var(a(l1))

    目标约束:

    为使前向传播中方差稳定:
    Var(zi(l))≈Var(aj(l−1))
    \\text{Var}(z_i^{(l)}) \\approx \\text{Var}(a_j^{(l-1)})
    Var(zi(l))Var(aj(l1))

    代入上式得:
    nin⋅Var(w(l))⋅Var(a(l−1))≈Var(a(l−1))
    n_{\\text{in}} \\cdot \\text{Var}(w^{(l)}) \\cdot \\text{Var}(a^{(l-1)}) \\approx \\text{Var}(a^{(l-1)})
    ninVar(w(l))Var(a(l1))Var(a(l1))

    ⇒Var(w(l))≈1nin
    \\Rightarrow \\boxed{\\text{Var}(w^{(l)}) \\approx \\frac{1}{n_{\\text{in}}}}
    Var(w(l))nin1


    四、反向传播的梯度方差分析

    考虑损失函数 L\\mathcal{L}L 对第 lll 层输入的梯度:
    δj(l−1)=∂L∂aj(l−1)=∑i=1nout∂L∂zi(l)∂zi(l)∂aj(l−1)=∑i=1noutδi(l)wij(l)
    \\delta_j^{(l-1)} = \\frac{\\partial \\mathcal{L}}{\\partial a_j^{(l-1)}} = \\sum_{i=1}^{n_{\\text{out}}} \\frac{\\partial \\mathcal{L}}{\\partial z_i^{(l)}} \\frac{\\partial z_i^{(l)}}{\\partial a_j^{(l-1)}} = \\sum_{i=1}^{n_{\\text{out}}} \\delta_i^{(l)} w_{ij}^{(l)}
    δj(l1)=aj(l1)L=i=1noutzi(l)Laj(l1)zi(l)=i=1noutδi(l)wij(l)

    核心假设:
  • 梯度 δi(l)∼Δ(l)\\delta_i^{(l)} \\sim \\Delta^{(l)}δi(l)Δ(l) 独立同分布(i.i.d)且 E[δi(l)]=0\\mathbb{E}[\\delta_i^{(l)}] = 0E[δi(l)]=0
  • 权重 wij(l)∼W(l)w_{ij}^{(l)} \\sim \\mathcal{W}^{(l)}wij(l)W(l) 独立同分布(i.i.d)
  • 梯度与权重相互独立:δi(l)⊥wkj(l)\\delta_i^{(l)} \\perp w_{kj}^{(l)}δi(l)wkj(l)
  • 方差推导过程:

    Var(δj(l−1))=E[(δj(l−1))2]−(E[δj(l−1)])2
    \\text{Var}(\\delta_j^{(l-1)}) = \\mathbb{E}\\left[(\\delta_j^{(l-1)})^2\\right] – \\left(\\mathbb{E}[\\delta_j^{(l-1)}]\\right)^2
    Var(δj(l1))=E[(δj(l1))2](E[δj(l1)])2

    步骤1:计算期望 E[δj(l−1)]\\mathbb{E}[\\delta_j^{(l-1)}]E[δj(l1)]
    E[δj(l−1)]=E[∑i=1noutδi(l)wij(l)]=∑i=1noutE[δi(l)]E[wij(l)]=0
    \\mathbb{E}[\\delta_j^{(l-1)}] = \\mathbb{E}\\left[\\sum_{i=1}^{n_{\\text{out}}} \\delta_i^{(l)} w_{ij}^{(l)}\\right] = \\sum_{i=1}^{n_{\\text{out}}} \\mathbb{E}[\\delta_i^{(l)}] \\mathbb{E}[w_{ij}^{(l)}] = 0
    E[δj(l1)]=E[i=1noutδi(l)wij(l)]=i=1noutE[δi(l)]E[wij(l)]=0

    步骤2:计算 E[(δj(l−1))2]\\mathbb{E}\\left[(\\delta_j^{(l-1)})^2\\right]E[(δj(l1))2]
    E[(δj(l−1))2]=E[(∑i=1noutδi(l)wij(l))2]=E[∑i=1nout∑k=1noutδi(l)δk(l)wij(l)wkj(l)]
    \\mathbb{E}\\left[(\\delta_j^{(l-1)})^2\\right] = \\mathbb{E}\\left[\\left(\\sum_{i=1}^{n_{\\text{out}}} \\delta_i^{(l)} w_{ij}^{(l)}\\right)^2\\right] = \\mathbb{E}\\left[\\sum_{i=1}^{n_{\\text{out}}} \\sum_{k=1}^{n_{\\text{out}}} \\delta_i^{(l)} \\delta_k^{(l)} w_{ij}^{(l)} w_{kj}^{(l)}\\right]
    E[(δj(l1))2]=E(i=1noutδi(l)wij(l))2=E[i=1noutk=1noutδi(l)δk(l)wij(l)wkj(l)]

    步骤3:分解双重求和
    =∑i=1nout∑k=1noutE[δi(l)δk(l)wij(l)wkj(l)]
    = \\sum_{i=1}^{n_{\\text{out}}} \\sum_{k=1}^{n_{\\text{out}}} \\mathbb{E}\\left[\\delta_i^{(l)} \\delta_k^{(l)} w_{ij}^{(l)} w_{kj}^{(l)}\\right]
    =i=1noutk=1noutE[δi(l)δk(l)wij(l)wkj(l)]

    步骤4:处理交叉项(i≠ki \\neq ki=k
    i≠ki \\neq ki=k 时:
    E[δi(l)δk(l)wij(l)wkj(l)]=E[δi(l)]E[δk(l)]E[wij(l)]E[wkj(l)]=0
    \\mathbb{E}\\left[\\delta_i^{(l)} \\delta_k^{(l)} w_{ij}^{(l)} w_{kj}^{(l)}\\right] = \\mathbb{E}[\\delta_i^{(l)}] \\mathbb{E}[\\delta_k^{(l)}] \\mathbb{E}[w_{ij}^{(l)}] \\mathbb{E}[w_{kj}^{(l)}] = 0
    E[δi(l)δk(l)wij(l)wkj(l)]=E[δi(l)]E[δk(l)]E[wij(l)]E[wkj(l)]=0

    步骤5:处理对角项(i=ki = ki=k
    i=ki = ki=k 时:
    E[(δi(l))2(wij(l))2]=E[(δi(l))2]E[(wij(l))2]
    \\mathbb{E}\\left[(\\delta_i^{(l)})^2 (w_{ij}^{(l)})^2\\right] = \\mathbb{E}\\left[(\\delta_i^{(l)})^2\\right] \\mathbb{E}\\left[(w_{ij}^{(l)})^2\\right]
    E[(δi(l))2(wij(l))2]=E[(δi(l))2]E[(wij(l))2]

    步骤6:利用零均值性质
    E[(δi(l))2]=Var(δ(l))
    \\mathbb{E}\\left[(\\delta_i^{(l)})^2\\right] = \\text{Var}(\\delta^{(l)})
    E[(δi(l))2]=Var(δ(l))

    E[(wij(l))2]=Var(w(l))
    \\mathbb{E}\\left[(w_{ij}^{(l)})^2\\right] = \\text{Var}(w^{(l)})
    E[(wij(l))2]=Var(w(l))

    步骤7:合并结果
    E[(δj(l−1))2]=∑i=1noutVar(δ(l))⋅Var(w(l))=nout⋅Var(δ(l))⋅Var(w(l))
    \\mathbb{E}\\left[(\\delta_j^{(l-1)})^2\\right] = \\sum_{i=1}^{n_{\\text{out}}} \\text{Var}(\\delta^{(l)}) \\cdot \\text{Var}(w^{(l)}) = n_{\\text{out}} \\cdot \\text{Var}(\\delta^{(l)}) \\cdot \\text{Var}(w^{(l)})
    E[(δj(l1))2]=i=1noutVar(δ(l))Var(w(l))=noutVar(δ(l))Var(w(l))

    最终方差表达式:
    Var(δj(l−1))=nout⋅Var(w(l))⋅Var(δ(l))
    \\boxed{\\text{Var}(\\delta_j^{(l-1)}) = n_{\\text{out}} \\cdot \\text{Var}(w^{(l)}) \\cdot \\text{Var}(\\delta^{(l)})}
    Var(δj(l1))=noutVar(w(l))Var(δ(l))

    目标约束:

    为使反向传播中梯度方差稳定:
    Var(δj(l−1))≈Var(δi(l))
    \\text{Var}(\\delta_j^{(l-1)}) \\approx \\text{Var}(\\delta_i^{(l)})
    Var(δj(l1))Var(δi(l))

    代入上式得:
    nout⋅Var(w(l))⋅Var(δ(l))≈Var(δ(l))
    n_{\\text{out}} \\cdot \\text{Var}(w^{(l)}) \\cdot \\text{Var}(\\delta^{(l)}) \\approx \\text{Var}(\\delta^{(l)})
    noutVar(w(l))Var(δ(l))Var(δ(l))

    ⇒Var(w(l))≈1nout
    \\Rightarrow \\boxed{\\text{Var}(w^{(l)}) \\approx \\frac{1}{n_{\\text{out}}}}
    Var(w(l))nout1


    五、折中方案:双向调和

    综合前向传播约束和反向传播约束:

    • 前向传播要求:Var(w)≈1nin\\text{Var}(w) \\approx \\dfrac{1}{n_{\\text{in}}}Var(w)nin1
    • 反向传播要求:Var(w)≈1nout\\text{Var}(w) \\approx \\dfrac{1}{n_{\\text{out}}}Var(w)nout1

    Xavier 采用调和平均数平衡两者:
    1Var(w)=12(11nin+11nout)=nin+nout2
    \\frac{1}{\\text{Var}(w)} = \\frac{1}{2}\\left(\\frac{1}{\\frac{1}{n_{\\text{in}}}} + \\frac{1}{\\frac{1}{n_{\\text{out}}}}\\right) = \\frac{n_{\\text{in}} + n_{\\text{out}}}{2}
    Var(w)1=21(nin11+nout11)=2nin+nout

    ⇒Var(w)=2nin+nout
    \\Rightarrow \\boxed{\\text{Var}(w) = \\frac{2}{n_{\\text{in}} + n_{\\text{out}}}}
    Var(w)=nin+nout2

    该方案确保信号在前向传播和梯度在反向传播都能保持稳定,为深层网络训练提供良好起点。

    数学直觉:调和平均数倾向于较小值,这防止了梯度爆炸(方差过大)同时缓解梯度消失(方差过小)


    六、两种具体实现方式

    1. Xavier 均匀分布

    limit=6nin+nout
    \\text{limit} = \\sqrt{\\frac{6}{n_{\\text{in}} + n_{\\text{out}}}}
    limit=nin+nout6

    W∼Uniform(−limit,limit)
    W \\sim \\text{Uniform}(-\\text{limit}, \\text{limit})
    WUniform(limit,limit)

    推导依据:均匀分布 Uniform(−a,a)\\text{Uniform}(-a,a)Uniform(a,a) 的方差为 a2/3a^2/3a2/3

    2. Xavier 正态分布

    σ=2nin+nout
    \\sigma = \\sqrt{\\frac{2}{n_{\\text{in}} + n_{\\text{out}}}}
    σ=nin+nout2

    W∼N(0,σ2)
    W \\sim \\mathcal{N}(0, \\sigma^2)
    WN(0,σ2)


    七、关键细节与注意事项

  • 激活函数适配性:

    • ✅ 适用于线性激活或零中心非线性函数(如 tanh, sigmoid)a = torch.tanh(z) # 在零点附近满足 E[a]=0
    • ❌ 不适用于 ReLU 族(输出均值 > 0),此类网络应使用 He 初始化# ReLU 网络的初始化
      std = math.sqrt(2 / n_in) # He 初始化

  • 网络层适配:

    • 全连接层:直接使用 ninn_{\\text{in}}ninnoutn_{\\text{out}}nout
    • 卷积层:
      nin=kernel_width×kernel_height×in_channelsn_{\\text{in}} = \\text{kernel\\_width} \\times \\text{kernel\\_height} \\times \\text{in\\_channels}nin=kernel_width×kernel_height×in_channels
      nout=kernel_width×kernel_height×out_channelsn_{\\text{out}} = \\text{kernel\\_width} \\times \\text{kernel\\_height} \\times \\text{out\\_channels}nout=kernel_width×kernel_height×out_channels
  • 偏置初始化:通常设为 0 或极小常数(如 0.01)


  • 八、主流框架实现

    PyTorch 示例

    # 均匀分布版本
    linear_layer = nn.Linear(in_features, out_features)
    nn.init.xavier_uniform_(linear_layer.weight, gain=1.0)

    # 正态分布版本
    conv_layer = nn.Conv2d(in_channels, out_channels, kernel_size)
    nn.init.xavier_normal_(conv_layer.weight, gain=1.0)

    gain 参数:用于补偿激活函数导数(如 sigmoid 建议 gain=1,tanh 建议 gain=5/3)

    TensorFlow 示例

    dense = tf.keras.layers.Dense(
    units=64,
    kernel_initializer='glorot_uniform' # Xavier 初始化
    )


    九、总结:为什么 Xavier 初始化有效?

    机制效果
    方差一致性原则 前向传播信号强度保持稳定
    双向约束调和 反向传播梯度强度保持稳定
    零均值分布 避免交叉项干扰方差计算
    按层自适应缩放 适应不同网络结构的尺寸差异

    实践意义:
    Xavier 初始化使深层网络在训练初期就能保持信号和梯度的稳定性,显著缓解梯度消失/爆炸问题,为优化器提供良好的起点。尽管现代网络多使用 ReLU 族激活函数(需 He 初始化),但 Xavier 仍是理解权重初始化理论的基石,其思想被后续多种优化算法继承发展。

    “初始化的本质是为优化过程提供一个平滑的起点” —— Xavier Glorot, AISTATS 2010

    赞(0)
    未经允许不得转载:网硕互联帮助中心 » 解析 Xavier 初始化:原理、推导与实践
    分享到: 更多 (0)

    评论 抢沙发

    评论前必须登录!