{"id":34017,"date":"2025-04-28T22:09:38","date_gmt":"2025-04-28T14:09:38","guid":{"rendered":"https:\/\/www.wsisp.com\/helps\/34017.html"},"modified":"2025-04-28T22:09:38","modified_gmt":"2025-04-28T14:09:38","slug":"ai-%e5%a4%a7%e6%a8%a1%e5%9e%8b%e4%b9%8b-transformer-%e6%9e%b6%e6%9e%84%e6%b7%b1%e5%85%a5%e5%89%96%e6%9e%90","status":"publish","type":"post","link":"https:\/\/www.wsisp.com\/helps\/34017.html","title":{"rendered":"AI \u5927\u6a21\u578b\u4e4b Transformer \u67b6\u6784\u6df1\u5165\u5256\u6790"},"content":{"rendered":"<h2>AI \u5927\u6a21\u578b\u4e4b Transformer \u67b6\u6784\u6df1\u5165\u5256\u6790<\/h2>\n<h3>\u672c\u4eba\u6398\u91d1\u53f7&#xff0c;\u6b22\u8fce\u70b9\u51fb\u5173\u6ce8&#xff1a;\u6398\u91d1\u53f7\u5730\u5740<\/h3>\n<h3>\u672c\u4eba\u516c\u4f17\u53f7&#xff0c;\u6b22\u8fce\u70b9\u51fb\u5173\u6ce8&#xff1a;\u516c\u4f17\u53f7\u5730\u5740<\/h3>\n<h3>\u4e00\u3001\u5f15\u8a00<\/h3>\n<p>\u5728\u4eba\u5de5\u667a\u80fd\u7684\u53d1\u5c55\u5386\u7a0b\u4e2d&#xff0c;Transformer \u67b6\u6784\u65e0\u7591\u662f\u4e00\u5ea7\u5177\u6709\u91cc\u7a0b\u7891\u610f\u4e49\u7684\u4e30\u7891\u3002\u81ea\u4ece 2017 \u5e74 Google \u56e2\u961f\u5728\u8bba\u6587\u300aAttention Is All You Need\u300b\u4e2d\u9996\u6b21\u63d0\u51fa Transformer \u67b6\u6784\u4ee5\u6765&#xff0c;\u5b83\u4fbf\u8fc5\u901f\u5728\u81ea\u7136\u8bed\u8a00\u5904\u7406&#xff08;NLP&#xff09;\u9886\u57df\u5f15\u53d1\u4e86\u9769\u547d\u6027\u7684\u53d8\u9769&#xff0c;\u5e76\u9010\u6e10\u62d3\u5c55\u5230\u8ba1\u7b97\u673a\u89c6\u89c9\u3001\u8bed\u97f3\u8bc6\u522b\u7b49\u591a\u4e2a\u9886\u57df\u3002Transformer \u67b6\u6784\u4ee5\u5176\u5353\u8d8a\u7684\u5e76\u884c\u8ba1\u7b97\u80fd\u529b\u3001\u5f3a\u5927\u7684\u957f\u5e8f\u5217\u5904\u7406\u80fd\u529b\u4ee5\u53ca\u51fa\u8272\u7684\u5efa\u6a21\u6548\u679c&#xff0c;\u6210\u4e3a\u4e86\u4f17\u591a\u5148\u8fdb\u5927\u6a21\u578b\u7684\u6838\u5fc3\u57fa\u7840&#xff0c;\u5982 GPT \u7cfb\u5217\u3001BERT \u7b49\u3002<\/p>\n<p>\u672c\u6587\u5c06\u6df1\u5165\u5256\u6790 Transformer 
\u67b6\u6784&#xff0c;\u4ece\u5176\u6838\u5fc3\u539f\u7406\u3001\u7ec4\u4ef6\u6784\u6210\u5230\u6e90\u7801\u5b9e\u73b0&#xff0c;\u8fdb\u884c\u5168\u65b9\u4f4d\u3001\u7ec6\u81f4\u5165\u5fae\u7684\u5206\u6790\u3002\u901a\u8fc7\u5bf9\u6e90\u7801\u7684\u9010\u884c\u89e3\u8bfb&#xff0c;\u5e2e\u52a9\u8bfb\u8005\u6df1\u5165\u7406\u89e3 Transformer \u67b6\u6784\u7684\u5de5\u4f5c\u673a\u5236&#xff0c;\u4e3a\u8fdb\u4e00\u6b65\u7814\u7a76\u548c\u5e94\u7528\u57fa\u4e8e Transformer \u7684\u5927\u6a21\u578b\u5960\u5b9a\u575a\u5b9e\u7684\u57fa\u7840\u3002<\/p>\n<h3>\u4e8c\u3001Transformer \u67b6\u6784\u6982\u8ff0<\/h3>\n<h4>2.1 \u4f20\u7edf\u5e8f\u5217\u5904\u7406\u6a21\u578b\u7684\u5c40\u9650\u6027<\/h4>\n<p>\u5728 Transformer \u67b6\u6784\u51fa\u73b0\u4e4b\u524d&#xff0c;\u5faa\u73af\u795e\u7ecf\u7f51\u7edc&#xff08;RNN&#xff09;\u53ca\u5176\u53d8\u4f53&#xff0c;\u5982\u957f\u77ed\u671f\u8bb0\u5fc6\u7f51\u7edc&#xff08;LSTM&#xff09;\u548c\u95e8\u63a7\u5faa\u73af\u5355\u5143&#xff08;GRU&#xff09;&#xff0c;\u662f\u5904\u7406\u5e8f\u5217\u6570\u636e\u7684\u4e3b\u6d41\u6a21\u578b\u3002\u7136\u800c&#xff0c;\u8fd9\u4e9b\u6a21\u578b\u5b58\u5728\u4e00\u4e9b\u56fa\u6709\u7684\u5c40\u9650\u6027&#xff1a;<\/p>\n<ul>\n<li>\u987a\u5e8f\u8ba1\u7b97\u95ee\u9898&#xff1a;RNN \u53ca\u5176\u53d8\u4f53\u5728\u5904\u7406\u5e8f\u5217\u6570\u636e\u65f6&#xff0c;\u9700\u8981\u6309\u987a\u5e8f\u4f9d\u6b21\u5904\u7406\u6bcf\u4e2a\u65f6\u95f4\u6b65\u7684\u8f93\u5165&#xff0c;\u8fd9\u4f7f\u5f97\u5b83\u4eec\u96be\u4ee5\u8fdb\u884c\u5e76\u884c\u8ba1\u7b97&#xff0c;\u4ece\u800c\u9650\u5236\u4e86\u6a21\u578b\u7684\u8bad\u7ec3\u901f\u5ea6\u548c\u5904\u7406\u957f\u5e8f\u5217\u7684\u80fd\u529b\u3002<\/li>\n<li>\u957f\u8ddd\u79bb\u4f9d\u8d56\u95ee\u9898&#xff1a;\u5728\u5904\u7406\u957f\u5e8f\u5217\u65f6&#xff0c;RNN 
\u53ca\u5176\u53d8\u4f53\u5bb9\u6613\u51fa\u73b0\u68af\u5ea6\u6d88\u5931\u6216\u68af\u5ea6\u7206\u70b8\u7684\u95ee\u9898&#xff0c;\u5bfc\u81f4\u6a21\u578b\u96be\u4ee5\u6355\u6349\u5e8f\u5217\u4e2d\u7684\u957f\u8ddd\u79bb\u4f9d\u8d56\u5173\u7cfb\u3002<\/li>\n<\/ul>\n<h4>2.2 Transformer \u67b6\u6784\u7684\u63d0\u51fa<\/h4>\n<p>\u4e3a\u4e86\u89e3\u51b3\u4f20\u7edf\u5e8f\u5217\u5904\u7406\u6a21\u578b\u7684\u5c40\u9650\u6027&#xff0c;Google \u56e2\u961f\u63d0\u51fa\u4e86 Transformer \u67b6\u6784\u3002Transformer \u67b6\u6784\u6452\u5f03\u4e86\u4f20\u7edf\u7684\u5faa\u73af\u7ed3\u6784&#xff0c;\u5b8c\u5168\u57fa\u4e8e\u6ce8\u610f\u529b\u673a\u5236&#xff08;Attention Mechanism&#xff09;\u6784\u5efa&#xff0c;\u4ece\u800c\u5b9e\u73b0\u4e86\u5e76\u884c\u8ba1\u7b97&#xff0c;\u5927\u5927\u63d0\u9ad8\u4e86\u6a21\u578b\u7684\u8bad\u7ec3\u6548\u7387\u548c\u5904\u7406\u957f\u5e8f\u5217\u7684\u80fd\u529b\u3002<\/p>\n<h4>2.3 Transformer \u67b6\u6784\u7684\u4e3b\u8981\u7279\u70b9<\/h4>\n<ul>\n<li>\u5e76\u884c\u8ba1\u7b97&#xff1a;Transformer \u67b6\u6784\u901a\u8fc7\u81ea\u6ce8\u610f\u529b\u673a\u5236&#xff08;Self-Attention Mechanism&#xff09;\u53ef\u4ee5\u540c\u65f6\u5904\u7406\u5e8f\u5217\u4e2d\u7684\u6240\u6709\u5143\u7d20&#xff0c;\u4ece\u800c\u5b9e\u73b0\u4e86\u5e76\u884c\u8ba1\u7b97&#xff0c;\u63d0\u9ad8\u4e86\u6a21\u578b\u7684\u8bad\u7ec3\u901f\u5ea6\u3002<\/li>\n<li>\u957f\u8ddd\u79bb\u4f9d\u8d56\u5efa\u6a21&#xff1a;\u81ea\u6ce8\u610f\u529b\u673a\u5236\u53ef\u4ee5\u76f4\u63a5\u6355\u6349\u5e8f\u5217\u4e2d\u4efb\u610f\u4e24\u4e2a\u5143\u7d20\u4e4b\u95f4\u7684\u4f9d\u8d56\u5173\u7cfb&#xff0c;\u6709\u6548\u89e3\u51b3\u4e86\u957f\u8ddd\u79bb\u4f9d\u8d56\u95ee\u9898\u3002<\/li>\n<li>\u7075\u6d3b\u6027&#xff1a;Transformer \u67b6\u6784\u53ef\u4ee5\u7075\u6d3b\u5730\u5e94\u7528\u4e8e\u5404\u79cd\u5e8f\u5217\u5904\u7406\u4efb\u52a1&#xff0c;\u5982\u673a\u5668\u7ffb\u8bd1\u3001\u6587\u672c\u751f\u6210\u3001\u95ee\u7b54\u7cfb\u7edf\u7b49\u3002<\/li>\n<\/ul>\n<h3>\u4e09\u3001Transformer 
\u67b6\u6784\u7684\u6838\u5fc3\u539f\u7406<\/h3>\n<h4>3.1 \u6ce8\u610f\u529b\u673a\u5236&#xff08;Attention Mechanism&#xff09;<\/h4>\n<h5>3.1.1 \u6ce8\u610f\u529b\u673a\u5236\u7684\u57fa\u672c\u6982\u5ff5<\/h5>\n<p>\u6ce8\u610f\u529b\u673a\u5236\u662f\u4e00\u79cd\u6a21\u62df\u4eba\u7c7b\u6ce8\u610f\u529b\u7684\u673a\u5236&#xff0c;\u5b83\u53ef\u4ee5\u8ba9\u6a21\u578b\u5728\u5904\u7406\u5e8f\u5217\u6570\u636e\u65f6&#xff0c;\u81ea\u52a8\u5730\u5173\u6ce8\u5e8f\u5217\u4e2d\u7684\u91cd\u8981\u90e8\u5206\u3002\u5728\u81ea\u7136\u8bed\u8a00\u5904\u7406\u4e2d&#xff0c;\u6ce8\u610f\u529b\u673a\u5236\u53ef\u4ee5\u5e2e\u52a9\u6a21\u578b\u66f4\u597d\u5730\u6355\u6349\u4e0a\u4e0b\u6587\u4fe1\u606f&#xff0c;\u4ece\u800c\u63d0\u9ad8\u6a21\u578b\u7684\u6027\u80fd\u3002<\/p>\n<h5>3.1.2 \u7f29\u653e\u70b9\u79ef\u6ce8\u610f\u529b&#xff08;Scaled Dot-Product Attention&#xff09;<\/h5>\n<p>\u7f29\u653e\u70b9\u79ef\u6ce8\u610f\u529b\u662f Transformer \u67b6\u6784\u4e2d\u4f7f\u7528\u7684\u4e00\u79cd\u6ce8\u610f\u529b\u673a\u5236&#xff0c;\u5176\u8ba1\u7b97\u516c\u5f0f\u5982\u4e0b&#xff1a; (\\\\text{Attention}(Q, K, V) &#061; \\\\text{softmax}(\\\\frac{QK^T}{\\\\sqrt{d_k}})V) \u5176\u4e2d&#xff0c;Q \u662f\u67e5\u8be2\u77e9\u9635&#xff08;Query Matrix&#xff09;&#xff0c;K \u662f\u952e\u77e9\u9635&#xff08;Key Matrix&#xff09;&#xff0c;V \u662f\u503c\u77e9\u9635&#xff08;Value Matrix&#xff09;&#xff0c;(d_k) \u662f\u952e\u5411\u91cf\u7684\u7ef4\u5ea6\u3002<\/p>\n<p>\u4ee5\u4e0b\u662f\u7f29\u653e\u70b9\u79ef\u6ce8\u610f\u529b\u7684 Python \u4ee3\u7801\u5b9e\u73b0&#xff1a;<\/p>\n<p>python<\/p>\n<p><span class=\"token keyword\">import<\/span> torch<br \/>\n<span class=\"token keyword\">import<\/span> torch<span class=\"token punctuation\">.<\/span>nn <span class=\"token keyword\">as<\/span> nn<br \/>\n<span class=\"token keyword\">import<\/span> torch<span class=\"token punctuation\">.<\/span>nn<span class=\"token punctuation\">.<\/span>functional <span class=\"token keyword\">as<\/span> F<\/p>\n<p><span class=\"token comment\"># 
\u5b9a\u4e49\u7f29\u653e\u70b9\u79ef\u6ce8\u610f\u529b\u7c7b<\/span><br \/>\n<span class=\"token keyword\">class<\/span> <span class=\"token class-name\">ScaledDotProductAttention<\/span><span class=\"token punctuation\">(<\/span>nn<span class=\"token punctuation\">.<\/span>Module<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n    <span class=\"token keyword\">def<\/span> <span class=\"token function\">__init__<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> d_k<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n        <span class=\"token builtin\">super<\/span><span class=\"token punctuation\">(<\/span>ScaledDotProductAttention<span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>__init__<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u952e\u5411\u91cf\u7684\u7ef4\u5ea6<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>d_k <span class=\"token operator\">&#061;<\/span> d_k  <\/p>\n<p>    <span class=\"token keyword\">def<\/span> <span class=\"token function\">forward<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> q<span class=\"token punctuation\">,<\/span> k<span class=\"token punctuation\">,<\/span> v<span class=\"token punctuation\">,<\/span> mask<span class=\"token operator\">&#061;<\/span><span class=\"token boolean\">None<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n        <span class=\"token comment\"># \u8ba1\u7b97 Q \u548c K \u7684\u8f6c\u7f6e\u7684\u70b9\u79ef<\/span><br \/>\n        attn_scores <span class=\"token operator\">&#061;<\/span> torch<span class=\"token punctuation\">.<\/span>matmul<span class=\"token 
punctuation\">(<\/span>q<span class=\"token punctuation\">,<\/span> k<span class=\"token punctuation\">.<\/span>transpose<span class=\"token punctuation\">(<\/span><span class=\"token operator\">-<\/span><span class=\"token number\">2<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token operator\">-<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u7f29\u653e\u70b9\u79ef<\/span><br \/>\n        attn_scores <span class=\"token operator\">&#061;<\/span> attn_scores <span class=\"token operator\">\/<\/span> torch<span class=\"token punctuation\">.<\/span>sqrt<span class=\"token punctuation\">(<\/span>torch<span class=\"token punctuation\">.<\/span>tensor<span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">.<\/span>d_k<span class=\"token punctuation\">,<\/span> dtype<span class=\"token operator\">&#061;<\/span>torch<span class=\"token punctuation\">.<\/span>float32<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">)<\/span>  <\/p>\n<p>        <span class=\"token keyword\">if<\/span> mask <span class=\"token keyword\">is<\/span> <span class=\"token keyword\">not<\/span> <span class=\"token boolean\">None<\/span><span class=\"token punctuation\">:<\/span><br \/>\n            <span class=\"token comment\"># \u5982\u679c\u6709\u63a9\u7801&#xff0c;\u5c06\u63a9\u7801\u4f4d\u7f6e\u7684\u6ce8\u610f\u529b\u5206\u6570\u8bbe\u4e3a\u8d1f\u65e0\u7a77<\/span><br \/>\n            attn_scores <span class=\"token operator\">&#061;<\/span> attn_scores<span class=\"token punctuation\">.<\/span>masked_fill<span class=\"token punctuation\">(<\/span>mask <span class=\"token operator\">&#061;&#061;<\/span> <span class=\"token number\">0<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token operator\">-<\/span><span class=\"token number\">1e9<\/span><span 
class=\"token punctuation\">)<\/span>  <\/p>\n<p>        <span class=\"token comment\"># \u5bf9\u6ce8\u610f\u529b\u5206\u6570\u8fdb\u884c softmax \u64cd\u4f5c&#xff0c;\u5f97\u5230\u6ce8\u610f\u529b\u6743\u91cd<\/span><br \/>\n        attn_weights <span class=\"token operator\">&#061;<\/span> F<span class=\"token punctuation\">.<\/span>softmax<span class=\"token punctuation\">(<\/span>attn_scores<span class=\"token punctuation\">,<\/span> dim<span class=\"token operator\">&#061;<\/span><span class=\"token operator\">-<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u8ba1\u7b97\u6ce8\u610f\u529b\u8f93\u51fa<\/span><br \/>\n        output <span class=\"token operator\">&#061;<\/span> torch<span class=\"token punctuation\">.<\/span>matmul<span class=\"token punctuation\">(<\/span>attn_weights<span class=\"token punctuation\">,<\/span> v<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token keyword\">return<\/span> output<span class=\"token punctuation\">,<\/span> attn_weights<\/p>\n<h5>3.1.3 \u591a\u5934\u6ce8\u610f\u529b&#xff08;Multi-Head Attention&#xff09;<\/h5>\n<p>\u591a\u5934\u6ce8\u610f\u529b\u662f\u7f29\u653e\u70b9\u79ef\u6ce8\u610f\u529b\u7684\u6269\u5c55&#xff0c;\u5b83\u901a\u8fc7\u5c06\u67e5\u8be2\u3001\u952e\u548c\u503c\u5206\u522b\u6295\u5f71\u5230\u591a\u4e2a\u4f4e\u7ef4\u5b50\u7a7a\u95f4\u4e2d&#xff0c;\u7136\u540e\u5728\u6bcf\u4e2a\u5b50\u7a7a\u95f4\u4e2d\u72ec\u7acb\u5730\u8ba1\u7b97\u6ce8\u610f\u529b&#xff0c;\u6700\u540e\u5c06\u6240\u6709\u5b50\u7a7a\u95f4\u7684\u6ce8\u610f\u529b\u8f93\u51fa\u62fc\u63a5\u8d77\u6765\u5e76\u8fdb\u884c\u7ebf\u6027\u53d8\u6362&#xff0c;\u5f97\u5230\u6700\u7ec8\u7684\u6ce8\u610f\u529b\u8f93\u51fa\u3002<\/p>\n<p>\u591a\u5934\u6ce8\u610f\u529b\u7684\u8ba1\u7b97\u516c\u5f0f\u5982\u4e0b&#xff1a; (\\\\text{MultiHead}(Q, K, V) &#061; \\\\text{Concat}(\\\\text{head}_1, \\\\ldots, \\\\text{head}_h)W^O) 
\u5176\u4e2d&#xff0c;(\\\\text{head}_i &#061; \\\\text{Attention}(QW_i^Q, KW_i^K, VW_i^V))&#xff0c;(W_i^Q)\u3001(W_i^K)\u3001(W_i^V) \u662f\u6295\u5f71\u77e9\u9635&#xff0c;(W^O) \u662f\u8f93\u51fa\u77e9\u9635\u3002<\/p>\n<p>\u4ee5\u4e0b\u662f\u591a\u5934\u6ce8\u610f\u529b\u7684 Python \u4ee3\u7801\u5b9e\u73b0&#xff1a;<\/p>\n<p>python<\/p>\n<p><span class=\"token comment\"># \u5b9a\u4e49\u591a\u5934\u6ce8\u610f\u529b\u7c7b<\/span><br \/>\n<span class=\"token keyword\">class<\/span> <span class=\"token class-name\">MultiHeadAttention<\/span><span class=\"token punctuation\">(<\/span>nn<span class=\"token punctuation\">.<\/span>Module<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n    <span class=\"token keyword\">def<\/span> <span class=\"token function\">__init__<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> d_model<span class=\"token punctuation\">,<\/span> num_heads<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n        <span class=\"token builtin\">super<\/span><span class=\"token punctuation\">(<\/span>MultiHeadAttention<span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>__init__<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u6a21\u578b\u7684\u7ef4\u5ea6<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>d_model <span class=\"token operator\">&#061;<\/span> d_model<br \/>\n        <span class=\"token comment\"># \u6ce8\u610f\u529b\u5934\u7684\u6570\u91cf<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>num_heads <span class=\"token operator\">&#061;<\/span> num_heads<br \/>\n        <span class=\"token comment\"># \u6bcf\u4e2a\u5934\u7684\u7ef4\u5ea6<\/span><br \/>\n        self<span 
class=\"token punctuation\">.<\/span>d_k <span class=\"token operator\">&#061;<\/span> d_model <span class=\"token operator\">\/\/<\/span> num_heads  <\/p>\n<p>        <span class=\"token comment\"># \u5b9a\u4e49\u67e5\u8be2\u3001\u952e\u548c\u503c\u7684\u7ebf\u6027\u53d8\u6362\u5c42<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>W_q <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>Linear<span class=\"token punctuation\">(<\/span>d_model<span class=\"token punctuation\">,<\/span> d_model<span class=\"token punctuation\">)<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>W_k <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>Linear<span class=\"token punctuation\">(<\/span>d_model<span class=\"token punctuation\">,<\/span> d_model<span class=\"token punctuation\">)<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>W_v <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>Linear<span class=\"token punctuation\">(<\/span>d_model<span class=\"token punctuation\">,<\/span> d_model<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u5b9a\u4e49\u8f93\u51fa\u7684\u7ebf\u6027\u53d8\u6362\u5c42<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>W_o <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>Linear<span class=\"token punctuation\">(<\/span>d_model<span class=\"token punctuation\">,<\/span> d_model<span class=\"token punctuation\">)<\/span>  <\/p>\n<p>        <span class=\"token comment\"># \u5b9a\u4e49\u7f29\u653e\u70b9\u79ef\u6ce8\u610f\u529b\u5c42<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>attention <span class=\"token operator\">&#061;<\/span> ScaledDotProductAttention<span class=\"token punctuation\">(<\/span>self<span class=\"token 
punctuation\">.<\/span>d_k<span class=\"token punctuation\">)<\/span><\/p>\n<p>    <span class=\"token keyword\">def<\/span> <span class=\"token function\">forward<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> q<span class=\"token punctuation\">,<\/span> k<span class=\"token punctuation\">,<\/span> v<span class=\"token punctuation\">,<\/span> mask<span class=\"token operator\">&#061;<\/span><span class=\"token boolean\">None<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n        batch_size <span class=\"token operator\">&#061;<\/span> q<span class=\"token punctuation\">.<\/span>size<span class=\"token punctuation\">(<\/span><span class=\"token number\">0<\/span><span class=\"token punctuation\">)<\/span><\/p>\n<p>        <span class=\"token comment\"># \u5bf9\u67e5\u8be2\u3001\u952e\u548c\u503c\u8fdb\u884c\u7ebf\u6027\u53d8\u6362<\/span><br \/>\n        Q <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>W_q<span class=\"token punctuation\">(<\/span>q<span class=\"token punctuation\">)<\/span><br \/>\n        K <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>W_k<span class=\"token punctuation\">(<\/span>k<span class=\"token punctuation\">)<\/span><br \/>\n        V <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>W_v<span class=\"token punctuation\">(<\/span>v<span class=\"token punctuation\">)<\/span><\/p>\n<p>        <span class=\"token comment\"># \u5c06\u67e5\u8be2\u3001\u952e\u548c\u503c\u5206\u5272\u6210\u591a\u4e2a\u5934<\/span><br \/>\n        Q <span class=\"token operator\">&#061;<\/span> Q<span class=\"token punctuation\">.<\/span>view<span class=\"token punctuation\">(<\/span>batch_size<span class=\"token punctuation\">,<\/span> <span class=\"token operator\">-<\/span><span class=\"token number\">1<\/span><span 
class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">.<\/span>num_heads<span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">.<\/span>d_k<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>transpose<span class=\"token punctuation\">(<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">2<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        K <span class=\"token operator\">&#061;<\/span> K<span class=\"token punctuation\">.<\/span>view<span class=\"token punctuation\">(<\/span>batch_size<span class=\"token punctuation\">,<\/span> <span class=\"token operator\">-<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">.<\/span>num_heads<span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">.<\/span>d_k<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>transpose<span class=\"token punctuation\">(<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">2<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        V <span class=\"token operator\">&#061;<\/span> V<span class=\"token punctuation\">.<\/span>view<span class=\"token punctuation\">(<\/span>batch_size<span class=\"token punctuation\">,<\/span> <span class=\"token operator\">-<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">.<\/span>num_heads<span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">.<\/span>d_k<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>transpose<span class=\"token punctuation\">(<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> <span 
class=\"token number\">2<\/span><span class=\"token punctuation\">)<\/span><\/p>\n<p>        <span class=\"token keyword\">if<\/span> mask <span class=\"token keyword\">is<\/span> <span class=\"token keyword\">not<\/span> <span class=\"token boolean\">None<\/span><span class=\"token punctuation\">:<\/span><br \/>\n            <span class=\"token comment\"># \u5982\u679c\u6709\u63a9\u7801&#xff0c;\u5c06\u63a9\u7801\u6269\u5c55\u5230\u6bcf\u4e2a\u5934<\/span><br \/>\n            mask <span class=\"token operator\">&#061;<\/span> mask<span class=\"token punctuation\">.<\/span>unsqueeze<span class=\"token punctuation\">(<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span><\/p>\n<p>        <span class=\"token comment\"># \u8ba1\u7b97\u591a\u5934\u6ce8\u610f\u529b\u8f93\u51fa<\/span><br \/>\n        output<span class=\"token punctuation\">,<\/span> attn_weights <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>attention<span class=\"token punctuation\">(<\/span>Q<span class=\"token punctuation\">,<\/span> K<span class=\"token punctuation\">,<\/span> V<span class=\"token punctuation\">,<\/span> mask<span class=\"token punctuation\">)<\/span><\/p>\n<p>        <span class=\"token comment\"># \u5c06\u591a\u5934\u6ce8\u610f\u529b\u8f93\u51fa\u62fc\u63a5\u8d77\u6765<\/span><br \/>\n        output <span class=\"token operator\">&#061;<\/span> output<span class=\"token punctuation\">.<\/span>transpose<span class=\"token punctuation\">(<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">2<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>contiguous<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>view<span class=\"token punctuation\">(<\/span>batch_size<span class=\"token punctuation\">,<\/span> <span class=\"token 
operator\">-<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">.<\/span>d_model<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u5bf9\u62fc\u63a5\u540e\u7684\u8f93\u51fa\u8fdb\u884c\u7ebf\u6027\u53d8\u6362<\/span><br \/>\n        output <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>W_o<span class=\"token punctuation\">(<\/span>output<span class=\"token punctuation\">)<\/span><\/p>\n<p>        <span class=\"token keyword\">return<\/span> output<span class=\"token punctuation\">,<\/span> attn_weights<\/p>\n<h4>3.2 \u4f4d\u7f6e\u7f16\u7801&#xff08;Positional Encoding&#xff09;<\/h4>\n<p>\u7531\u4e8e Transformer \u67b6\u6784\u6452\u5f03\u4e86\u4f20\u7edf\u7684\u5faa\u73af\u7ed3\u6784&#xff0c;\u5b83\u65e0\u6cd5\u81ea\u52a8\u6355\u6349\u5e8f\u5217\u4e2d\u5143\u7d20\u7684\u4f4d\u7f6e\u4fe1\u606f\u3002\u4e3a\u4e86\u89e3\u51b3\u8fd9\u4e2a\u95ee\u9898&#xff0c;Transformer \u67b6\u6784\u5f15\u5165\u4e86\u4f4d\u7f6e\u7f16\u7801&#xff08;Positional Encoding&#xff09;&#xff0c;\u5c06\u4f4d\u7f6e\u4fe1\u606f\u6dfb\u52a0\u5230\u8f93\u5165\u5e8f\u5217\u7684\u8bcd\u5411\u91cf\u4e2d\u3002<\/p>\n<p>\u4f4d\u7f6e\u7f16\u7801\u7684\u8ba1\u7b97\u516c\u5f0f\u5982\u4e0b&#xff1a; (PE_{(pos, 2i)} &#061; \\\\sin(\\\\frac{pos}{10000^{\\\\frac{2i}{d_{model}}}})) (PE_{(pos, 2i &#043; 1)} &#061; \\\\cos(\\\\frac{pos}{10000^{\\\\frac{2i}{d_{model}}}})) \u5176\u4e2d&#xff0c;pos \u662f\u5143\u7d20\u7684\u4f4d\u7f6e&#xff0c;i \u662f\u7ef4\u5ea6\u7d22\u5f15&#xff0c;(d_{model}) \u662f\u6a21\u578b\u7684\u7ef4\u5ea6\u3002<\/p>\n<p>\u4ee5\u4e0b\u662f\u4f4d\u7f6e\u7f16\u7801\u7684 Python \u4ee3\u7801\u5b9e\u73b0&#xff1a;<\/p>\n<p>python<\/p>\n<p><span class=\"token keyword\">import<\/span> torch<br \/>\n<span class=\"token keyword\">import<\/span> torch<span class=\"token punctuation\">.<\/span>nn <span class=\"token keyword\">as<\/span> 
nn<\/p>\n<p><span class=\"token comment\"># \u5b9a\u4e49\u4f4d\u7f6e\u7f16\u7801\u7c7b<\/span><br \/>\n<span class=\"token keyword\">class<\/span> <span class=\"token class-name\">PositionalEncoding<\/span><span class=\"token punctuation\">(<\/span>nn<span class=\"token punctuation\">.<\/span>Module<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n    <span class=\"token keyword\">def<\/span> <span class=\"token function\">__init__<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> d_model<span class=\"token punctuation\">,<\/span> max_len<span class=\"token operator\">&#061;<\/span><span class=\"token number\">5000<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n        <span class=\"token builtin\">super<\/span><span class=\"token punctuation\">(<\/span>PositionalEncoding<span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>__init__<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u6a21\u578b\u7684\u7ef4\u5ea6<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>d_model <span class=\"token operator\">&#061;<\/span> d_model  <\/p>\n<p>        <span class=\"token comment\"># \u521b\u5efa\u4f4d\u7f6e\u7f16\u7801\u77e9\u9635<\/span><br \/>\n        pe <span class=\"token operator\">&#061;<\/span> torch<span class=\"token punctuation\">.<\/span>zeros<span class=\"token punctuation\">(<\/span>max_len<span class=\"token punctuation\">,<\/span> d_model<span class=\"token punctuation\">)<\/span><br \/>\n        position <span class=\"token operator\">&#061;<\/span> torch<span class=\"token punctuation\">.<\/span>arange<span class=\"token punctuation\">(<\/span><span class=\"token number\">0<\/span><span class=\"token punctuation\">,<\/span> 
max_len<span class=\"token punctuation\">,<\/span> dtype<span class=\"token operator\">&#061;<\/span>torch<span class=\"token punctuation\">.<\/span><span class=\"token builtin\">float<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>unsqueeze<span class=\"token punctuation\">(<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        div_term <span class=\"token operator\">&#061;<\/span> torch<span class=\"token punctuation\">.<\/span>exp<span class=\"token punctuation\">(<\/span>torch<span class=\"token punctuation\">.<\/span>arange<span class=\"token punctuation\">(<\/span><span class=\"token number\">0<\/span><span class=\"token punctuation\">,<\/span> d_model<span class=\"token punctuation\">,<\/span> <span class=\"token number\">2<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span><span class=\"token builtin\">float<\/span><span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span> <span class=\"token operator\">*<\/span> <span class=\"token punctuation\">(<\/span><span class=\"token operator\">-<\/span>torch<span class=\"token punctuation\">.<\/span>log<span class=\"token punctuation\">(<\/span>torch<span class=\"token punctuation\">.<\/span>tensor<span class=\"token punctuation\">(<\/span><span class=\"token number\">10000.0<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">)<\/span> <span class=\"token operator\">\/<\/span> d_model<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        pe<span class=\"token punctuation\">[<\/span><span class=\"token punctuation\">:<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">0<\/span><span class=\"token punctuation\">:<\/span><span class=\"token punctuation\">:<\/span><span class=\"token number\">2<\/span><span class=\"token 
punctuation\">]<\/span> <span class=\"token operator\">&#061;<\/span> torch<span class=\"token punctuation\">.<\/span>sin<span class=\"token punctuation\">(<\/span>position <span class=\"token operator\">*<\/span> div_term<span class=\"token punctuation\">)<\/span><br \/>\n        pe<span class=\"token punctuation\">[<\/span><span class=\"token punctuation\">:<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">1<\/span><span class=\"token punctuation\">:<\/span><span class=\"token punctuation\">:<\/span><span class=\"token number\">2<\/span><span class=\"token punctuation\">]<\/span> <span class=\"token operator\">&#061;<\/span> torch<span class=\"token punctuation\">.<\/span>cos<span class=\"token punctuation\">(<\/span>position <span class=\"token operator\">*<\/span> div_term<span class=\"token punctuation\">)<\/span><br \/>\n        pe <span class=\"token operator\">&#061;<\/span> pe<span class=\"token punctuation\">.<\/span>unsqueeze<span class=\"token punctuation\">(<\/span><span class=\"token number\">0<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u5c06\u4f4d\u7f6e\u7f16\u7801\u77e9\u9635\u6ce8\u518c\u4e3a\u7f13\u51b2\u533a&#xff0c;\u4e0d\u53c2\u4e0e\u6a21\u578b\u8bad\u7ec3<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>register_buffer<span class=\"token punctuation\">(<\/span><span class=\"token string\">&#039;pe&#039;<\/span><span class=\"token punctuation\">,<\/span> pe<span class=\"token punctuation\">)<\/span>  <\/p>\n<p>    <span class=\"token keyword\">def<\/span> <span class=\"token function\">forward<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> x<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n        <span class=\"token comment\"># \u5c06\u4f4d\u7f6e\u7f16\u7801\u6dfb\u52a0\u5230\u8f93\u5165\u5e8f\u5217\u7684\u8bcd\u5411\u91cf\u4e2d<\/span><br 
\/>\n        x <span class=\"token operator\">&#061;<\/span> x <span class=\"token operator\">&#043;<\/span> self<span class=\"token punctuation\">.<\/span>pe<span class=\"token punctuation\">[<\/span><span class=\"token punctuation\">:<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token punctuation\">:<\/span>x<span class=\"token punctuation\">.<\/span>size<span class=\"token punctuation\">(<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">]<\/span><br \/>\n        <span class=\"token keyword\">return<\/span> x<\/p>\n<h4>3.3 \u524d\u9988\u795e\u7ecf\u7f51\u7edc&#xff08;Feed-Forward Network&#xff09;<\/h4>\n<p>\u524d\u9988\u795e\u7ecf\u7f51\u7edc\u662f Transformer \u67b6\u6784\u4e2d\u7684\u53e6\u4e00\u4e2a\u91cd\u8981\u7ec4\u4ef6&#xff0c;\u5b83\u7531\u4e24\u4e2a\u7ebf\u6027\u5c42\u548c\u4e00\u4e2a\u6fc0\u6d3b\u51fd\u6570\u7ec4\u6210\u3002\u524d\u9988\u795e\u7ecf\u7f51\u7edc\u7684\u8ba1\u7b97\u516c\u5f0f\u5982\u4e0b&#xff1a; (\\\\text{FFN}(x) &#061; \\\\text{max}(0, xW_1 &#043; b_1)W_2 &#043; b_2) \u5176\u4e2d&#xff0c;(W_1)\u3001(W_2) \u662f\u6743\u91cd\u77e9\u9635&#xff0c;(b_1)\u3001(b_2) \u662f\u504f\u7f6e\u5411\u91cf\u3002<\/p>\n<p>\u4ee5\u4e0b\u662f\u524d\u9988\u795e\u7ecf\u7f51\u7edc\u7684 Python \u4ee3\u7801\u5b9e\u73b0&#xff1a;<\/p>\n<p>python<\/p>\n<p><span class=\"token keyword\">import<\/span> torch<br \/>\n<span class=\"token keyword\">import<\/span> torch<span class=\"token punctuation\">.<\/span>nn <span class=\"token keyword\">as<\/span> nn<\/p>\n<p><span class=\"token comment\"># \u5b9a\u4e49\u524d\u9988\u795e\u7ecf\u7f51\u7edc\u7c7b<\/span><br \/>\n<span class=\"token keyword\">class<\/span> <span class=\"token class-name\">PositionwiseFeedForward<\/span><span class=\"token punctuation\">(<\/span>nn<span class=\"token punctuation\">.<\/span>Module<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n    <span class=\"token 
keyword\">def<\/span> <span class=\"token function\">__init__<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> d_model<span class=\"token punctuation\">,<\/span> d_ff<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n        <span class=\"token builtin\">super<\/span><span class=\"token punctuation\">(<\/span>PositionwiseFeedForward<span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>__init__<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u5b9a\u4e49\u7b2c\u4e00\u4e2a\u7ebf\u6027\u5c42<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>fc1 <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>Linear<span class=\"token punctuation\">(<\/span>d_model<span class=\"token punctuation\">,<\/span> d_ff<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u5b9a\u4e49\u7b2c\u4e8c\u4e2a\u7ebf\u6027\u5c42<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>fc2 <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>Linear<span class=\"token punctuation\">(<\/span>d_ff<span class=\"token punctuation\">,<\/span> d_model<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u5b9a\u4e49\u6fc0\u6d3b\u51fd\u6570<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>relu <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>ReLU<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><\/p>\n<p>    <span class=\"token keyword\">def<\/span> <span class=\"token function\">forward<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token 
punctuation\">,<\/span> x<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n        <span class=\"token comment\"># \u7b2c\u4e00\u4e2a\u7ebf\u6027\u5c42<\/span><br \/>\n        x <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>fc1<span class=\"token punctuation\">(<\/span>x<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u6fc0\u6d3b\u51fd\u6570<\/span><br \/>\n        x <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>relu<span class=\"token punctuation\">(<\/span>x<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u7b2c\u4e8c\u4e2a\u7ebf\u6027\u5c42<\/span><br \/>\n        x <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>fc2<span class=\"token punctuation\">(<\/span>x<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token keyword\">return<\/span> x<\/p>\n<h3>\u56db\u3001Transformer \u67b6\u6784\u7684\u7ec4\u4ef6\u6784\u6210<\/h3>\n<h4>4.1 \u7f16\u7801\u5668&#xff08;Encoder&#xff09;<\/h4>\n<h5>4.1.1 \u7f16\u7801\u5668\u7684\u7ed3\u6784<\/h5>\n<p>\u7f16\u7801\u5668\u662f Transformer \u67b6\u6784\u7684\u4e00\u90e8\u5206&#xff0c;\u5b83\u7531\u591a\u4e2a\u76f8\u540c\u7684\u7f16\u7801\u5668\u5c42&#xff08;Encoder Layer&#xff09;\u5806\u53e0\u800c\u6210\u3002\u6bcf\u4e2a\u7f16\u7801\u5668\u5c42\u5305\u542b\u4e24\u4e2a\u5b50\u5c42&#xff1a;\u591a\u5934\u6ce8\u610f\u529b\u5b50\u5c42\u548c\u524d\u9988\u795e\u7ecf\u7f51\u7edc\u5b50\u5c42\u3002<\/p>\n<h5>4.1.2 \u7f16\u7801\u5668\u5c42\u7684\u5b9e\u73b0<\/h5>\n<p>\u4ee5\u4e0b\u662f\u7f16\u7801\u5668\u5c42\u7684 Python \u4ee3\u7801\u5b9e\u73b0&#xff1a;<\/p>\n<p>python<\/p>\n<p><span class=\"token keyword\">import<\/span> torch<br \/>\n<span class=\"token keyword\">import<\/span> torch<span class=\"token punctuation\">.<\/span>nn <span class=\"token 
keyword\">as<\/span> nn<\/p>\n<p><span class=\"token comment\"># \u5b9a\u4e49\u7f16\u7801\u5668\u5c42\u7c7b<\/span><br \/>\n<span class=\"token keyword\">class<\/span> <span class=\"token class-name\">EncoderLayer<\/span><span class=\"token punctuation\">(<\/span>nn<span class=\"token punctuation\">.<\/span>Module<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n    <span class=\"token keyword\">def<\/span> <span class=\"token function\">__init__<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> d_model<span class=\"token punctuation\">,<\/span> num_heads<span class=\"token punctuation\">,<\/span> d_ff<span class=\"token punctuation\">,<\/span> dropout<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n        <span class=\"token builtin\">super<\/span><span class=\"token punctuation\">(<\/span>EncoderLayer<span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>__init__<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u5b9a\u4e49\u591a\u5934\u6ce8\u610f\u529b\u5c42<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>self_attn <span class=\"token operator\">&#061;<\/span> MultiHeadAttention<span class=\"token punctuation\">(<\/span>d_model<span class=\"token punctuation\">,<\/span> num_heads<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u5b9a\u4e49\u524d\u9988\u795e\u7ecf\u7f51\u7edc\u5c42<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>feed_forward <span class=\"token operator\">&#061;<\/span> PositionwiseFeedForward<span class=\"token punctuation\">(<\/span>d_model<span class=\"token punctuation\">,<\/span> d_ff<span class=\"token punctuation\">)<\/span><br \/>\n        <span 
class=\"token comment\"># \u5b9a\u4e49\u7b2c\u4e00\u4e2a\u5c42\u5f52\u4e00\u5316\u5c42<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>norm1 <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>LayerNorm<span class=\"token punctuation\">(<\/span>d_model<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u5b9a\u4e49\u7b2c\u4e8c\u4e2a\u5c42\u5f52\u4e00\u5316\u5c42<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>norm2 <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>LayerNorm<span class=\"token punctuation\">(<\/span>d_model<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u5b9a\u4e49 dropout \u5c42<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>dropout <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>Dropout<span class=\"token punctuation\">(<\/span>dropout<span class=\"token punctuation\">)<\/span><\/p>\n<p>    <span class=\"token keyword\">def<\/span> <span class=\"token function\">forward<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> x<span class=\"token punctuation\">,<\/span> mask<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n        <span class=\"token comment\"># \u591a\u5934\u6ce8\u610f\u529b\u5b50\u5c42<\/span><br \/>\n        attn_output<span class=\"token punctuation\">,<\/span> _ <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>self_attn<span class=\"token punctuation\">(<\/span>x<span class=\"token punctuation\">,<\/span> x<span class=\"token punctuation\">,<\/span> x<span class=\"token punctuation\">,<\/span> mask<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># 
\u6b8b\u5dee\u8fde\u63a5\u548c\u5c42\u5f52\u4e00\u5316<\/span><br \/>\n        x <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>norm1<span class=\"token punctuation\">(<\/span>x <span class=\"token operator\">&#043;<\/span> self<span class=\"token punctuation\">.<\/span>dropout<span class=\"token punctuation\">(<\/span>attn_output<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">)<\/span><\/p>\n<p>        <span class=\"token comment\"># \u524d\u9988\u795e\u7ecf\u7f51\u7edc\u5b50\u5c42<\/span><br \/>\n        ff_output <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>feed_forward<span class=\"token punctuation\">(<\/span>x<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u6b8b\u5dee\u8fde\u63a5\u548c\u5c42\u5f52\u4e00\u5316<\/span><br \/>\n        x <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>norm2<span class=\"token punctuation\">(<\/span>x <span class=\"token operator\">&#043;<\/span> self<span class=\"token punctuation\">.<\/span>dropout<span class=\"token punctuation\">(<\/span>ff_output<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">)<\/span><\/p>\n<p>        <span class=\"token keyword\">return<\/span> x<\/p>\n<h5>4.1.3 \u7f16\u7801\u5668\u7684\u5b9e\u73b0<\/h5>\n<p>\u4ee5\u4e0b\u662f\u7f16\u7801\u5668\u7684 Python \u4ee3\u7801\u5b9e\u73b0&#xff1a;<\/p>\n<p>python<\/p>\n<p><span class=\"token keyword\">import<\/span> torch<br \/>\n<span class=\"token keyword\">import<\/span> torch<span class=\"token punctuation\">.<\/span>nn <span class=\"token keyword\">as<\/span> nn<\/p>\n<p><span class=\"token comment\"># \u5b9a\u4e49\u7f16\u7801\u5668\u7c7b<\/span><br \/>\n<span class=\"token keyword\">class<\/span> <span class=\"token class-name\">Encoder<\/span><span class=\"token punctuation\">(<\/span>nn<span class=\"token 
punctuation\">.<\/span>Module<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n    <span class=\"token keyword\">def<\/span> <span class=\"token function\">__init__<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> num_layers<span class=\"token punctuation\">,<\/span> d_model<span class=\"token punctuation\">,<\/span> num_heads<span class=\"token punctuation\">,<\/span> d_ff<span class=\"token punctuation\">,<\/span> input_vocab_size<span class=\"token punctuation\">,<\/span> maximum_position_encoding<span class=\"token punctuation\">,<\/span> dropout<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n        <span class=\"token builtin\">super<\/span><span class=\"token punctuation\">(<\/span>Encoder<span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>__init__<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u6a21\u578b\u7684\u7ef4\u5ea6<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>d_model <span class=\"token operator\">&#061;<\/span> d_model<br \/>\n        <span class=\"token comment\"># \u7f16\u7801\u5668\u5c42\u7684\u6570\u91cf<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>num_layers <span class=\"token operator\">&#061;<\/span> num_layers<\/p>\n<p>        <span class=\"token comment\"># \u5b9a\u4e49\u8bcd\u5d4c\u5165\u5c42<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>embedding <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>Embedding<span class=\"token punctuation\">(<\/span>input_vocab_size<span class=\"token punctuation\">,<\/span> d_model<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># 
\u5b9a\u4e49\u4f4d\u7f6e\u7f16\u7801\u5c42<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>pos_encoding <span class=\"token operator\">&#061;<\/span> PositionalEncoding<span class=\"token punctuation\">(<\/span>d_model<span class=\"token punctuation\">,<\/span> maximum_position_encoding<span class=\"token punctuation\">)<\/span><\/p>\n<p>        <span class=\"token comment\"># \u5b9a\u4e49\u7f16\u7801\u5668\u5c42\u5217\u8868<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>enc_layers <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>ModuleList<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">[<\/span>EncoderLayer<span class=\"token punctuation\">(<\/span>d_model<span class=\"token punctuation\">,<\/span> num_heads<span class=\"token punctuation\">,<\/span> d_ff<span class=\"token punctuation\">,<\/span> dropout<span class=\"token punctuation\">)<\/span> <span class=\"token keyword\">for<\/span> _ <span class=\"token keyword\">in<\/span> <span class=\"token builtin\">range<\/span><span class=\"token punctuation\">(<\/span>num_layers<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u5b9a\u4e49 dropout \u5c42<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>dropout <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>Dropout<span class=\"token punctuation\">(<\/span>dropout<span class=\"token punctuation\">)<\/span><\/p>\n<p>    <span class=\"token keyword\">def<\/span> <span class=\"token function\">forward<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> x<span class=\"token punctuation\">,<\/span> mask<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n        <span 
class=\"token comment\"># \u8bcd\u5d4c\u5165<\/span><br \/>\n        x <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>embedding<span class=\"token punctuation\">(<\/span>x<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u7f29\u653e\u8bcd\u5d4c\u5165<\/span><br \/>\n        x <span class=\"token operator\">*&#061;<\/span> torch<span class=\"token punctuation\">.<\/span>sqrt<span class=\"token punctuation\">(<\/span>torch<span class=\"token punctuation\">.<\/span>tensor<span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">.<\/span>d_model<span class=\"token punctuation\">,<\/span> dtype<span class=\"token operator\">&#061;<\/span>torch<span class=\"token punctuation\">.<\/span>float32<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u6dfb\u52a0\u4f4d\u7f6e\u7f16\u7801<\/span><br \/>\n        x <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>pos_encoding<span class=\"token punctuation\">(<\/span>x<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u5e94\u7528 dropout<\/span><br \/>\n        x <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>dropout<span class=\"token punctuation\">(<\/span>x<span class=\"token punctuation\">)<\/span><\/p>\n<p>        <span class=\"token comment\"># \u4f9d\u6b21\u901a\u8fc7\u6bcf\u4e2a\u7f16\u7801\u5668\u5c42<\/span><br \/>\n        <span class=\"token keyword\">for<\/span> i <span class=\"token keyword\">in<\/span> <span class=\"token builtin\">range<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">.<\/span>num_layers<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n            x <span class=\"token operator\">&#061;<\/span> 
self<span class=\"token punctuation\">.<\/span>enc_layers<span class=\"token punctuation\">[<\/span>i<span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">(<\/span>x<span class=\"token punctuation\">,<\/span> mask<span class=\"token punctuation\">)<\/span><\/p>\n<p>        <span class=\"token keyword\">return<\/span> x<\/p>\n<h4>4.2 \u89e3\u7801\u5668&#xff08;Decoder&#xff09;<\/h4>\n<h5>4.2.1 \u89e3\u7801\u5668\u7684\u7ed3\u6784<\/h5>\n<p>\u89e3\u7801\u5668\u662f Transformer \u67b6\u6784\u7684\u53e6\u4e00\u90e8\u5206&#xff0c;\u5b83\u4e5f\u7531\u591a\u4e2a\u76f8\u540c\u7684\u89e3\u7801\u5668\u5c42&#xff08;Decoder Layer&#xff09;\u5806\u53e0\u800c\u6210\u3002\u6bcf\u4e2a\u89e3\u7801\u5668\u5c42\u5305\u542b\u4e09\u4e2a\u5b50\u5c42&#xff1a;\u591a\u5934\u81ea\u6ce8\u610f\u529b\u5b50\u5c42\u3001\u7f16\u7801\u5668 &#8211; \u89e3\u7801\u5668\u6ce8\u610f\u529b\u5b50\u5c42\u548c\u524d\u9988\u795e\u7ecf\u7f51\u7edc\u5b50\u5c42\u3002<\/p>\n<h5>4.2.2 \u89e3\u7801\u5668\u5c42\u7684\u5b9e\u73b0<\/h5>\n<p>\u4ee5\u4e0b\u662f\u89e3\u7801\u5668\u5c42\u7684 Python \u4ee3\u7801\u5b9e\u73b0&#xff1a;<\/p>\n<p>python<\/p>\n<p><span class=\"token keyword\">import<\/span> torch<br \/>\n<span class=\"token keyword\">import<\/span> torch<span class=\"token punctuation\">.<\/span>nn <span class=\"token keyword\">as<\/span> nn<\/p>\n<p><span class=\"token comment\"># \u5b9a\u4e49\u89e3\u7801\u5668\u5c42\u7c7b<\/span><br \/>\n<span class=\"token keyword\">class<\/span> <span class=\"token class-name\">DecoderLayer<\/span><span class=\"token punctuation\">(<\/span>nn<span class=\"token punctuation\">.<\/span>Module<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n    <span class=\"token keyword\">def<\/span> <span class=\"token function\">__init__<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> d_model<span class=\"token punctuation\">,<\/span> num_heads<span class=\"token 
punctuation\">,<\/span> d_ff<span class=\"token punctuation\">,<\/span> dropout<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n        <span class=\"token builtin\">super<\/span><span class=\"token punctuation\">(<\/span>DecoderLayer<span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>__init__<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u5b9a\u4e49\u591a\u5934\u81ea\u6ce8\u610f\u529b\u5c42<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>self_attn <span class=\"token operator\">&#061;<\/span> MultiHeadAttention<span class=\"token punctuation\">(<\/span>d_model<span class=\"token punctuation\">,<\/span> num_heads<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u5b9a\u4e49\u7f16\u7801\u5668 &#8211; \u89e3\u7801\u5668\u6ce8\u610f\u529b\u5c42<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>enc_dec_attn <span class=\"token operator\">&#061;<\/span> MultiHeadAttention<span class=\"token punctuation\">(<\/span>d_model<span class=\"token punctuation\">,<\/span> num_heads<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u5b9a\u4e49\u524d\u9988\u795e\u7ecf\u7f51\u7edc\u5c42<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>feed_forward <span class=\"token operator\">&#061;<\/span> PositionwiseFeedForward<span class=\"token punctuation\">(<\/span>d_model<span class=\"token punctuation\">,<\/span> d_ff<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u5b9a\u4e49\u7b2c\u4e00\u4e2a\u5c42\u5f52\u4e00\u5316\u5c42<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>norm1 <span class=\"token operator\">&#061;<\/span> nn<span class=\"token 
punctuation\">.<\/span>LayerNorm<span class=\"token punctuation\">(<\/span>d_model<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u5b9a\u4e49\u7b2c\u4e8c\u4e2a\u5c42\u5f52\u4e00\u5316\u5c42<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>norm2 <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>LayerNorm<span class=\"token punctuation\">(<\/span>d_model<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u5b9a\u4e49\u7b2c\u4e09\u4e2a\u5c42\u5f52\u4e00\u5316\u5c42<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>norm3 <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>LayerNorm<span class=\"token punctuation\">(<\/span>d_model<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u5b9a\u4e49 dropout \u5c42<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>dropout <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>Dropout<span class=\"token punctuation\">(<\/span>dropout<span class=\"token punctuation\">)<\/span><\/p>\n<p>    <span class=\"token keyword\">def<\/span> <span class=\"token function\">forward<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> x<span class=\"token punctuation\">,<\/span> enc_output<span class=\"token punctuation\">,<\/span> src_mask<span class=\"token punctuation\">,<\/span> tgt_mask<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n        <span class=\"token comment\"># \u591a\u5934\u81ea\u6ce8\u610f\u529b\u5b50\u5c42<\/span><br \/>\n        attn_output1<span class=\"token punctuation\">,<\/span> _ <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>self_attn<span class=\"token 
punctuation\">(<\/span>x<span class=\"token punctuation\">,<\/span> x<span class=\"token punctuation\">,<\/span> x<span class=\"token punctuation\">,<\/span> tgt_mask<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u6b8b\u5dee\u8fde\u63a5\u548c\u5c42\u5f52\u4e00\u5316<\/span><br \/>\n        x <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>norm1<span class=\"token punctuation\">(<\/span>x <span class=\"token operator\">&#043;<\/span> self<span class=\"token punctuation\">.<\/span>dropout<span class=\"token punctuation\">(<\/span>attn_output1<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">)<\/span><\/p>\n<p>        <span class=\"token comment\"># \u7f16\u7801\u5668 &#8211; \u89e3\u7801\u5668\u6ce8\u610f\u529b\u5b50\u5c42<\/span><br \/>\n        attn_output2<span class=\"token punctuation\">,<\/span> _ <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>enc_dec_attn<span class=\"token punctuation\">(<\/span>x<span class=\"token punctuation\">,<\/span> enc_output<span class=\"token punctuation\">,<\/span> enc_output<span class=\"token punctuation\">,<\/span> src_mask<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u6b8b\u5dee\u8fde\u63a5\u548c\u5c42\u5f52\u4e00\u5316<\/span><br \/>\n        x <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>norm2<span class=\"token punctuation\">(<\/span>x <span class=\"token operator\">&#043;<\/span> self<span class=\"token punctuation\">.<\/span>dropout<span class=\"token punctuation\">(<\/span>attn_output2<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">)<\/span><\/p>\n<p>        <span class=\"token comment\"># \u524d\u9988\u795e\u7ecf\u7f51\u7edc\u5b50\u5c42<\/span><br \/>\n        ff_output <span class=\"token operator\">&#061;<\/span> self<span class=\"token 
punctuation\">.<\/span>feed_forward<span class=\"token punctuation\">(<\/span>x<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u6b8b\u5dee\u8fde\u63a5\u548c\u5c42\u5f52\u4e00\u5316<\/span><br \/>\n        x <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>norm3<span class=\"token punctuation\">(<\/span>x <span class=\"token operator\">&#043;<\/span> self<span class=\"token punctuation\">.<\/span>dropout<span class=\"token punctuation\">(<\/span>ff_output<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">)<\/span><\/p>\n<p>        <span class=\"token keyword\">return<\/span> x<\/p>\n<h5>4.2.3 \u89e3\u7801\u5668\u7684\u5b9e\u73b0<\/h5>\n<p>\u4ee5\u4e0b\u662f\u89e3\u7801\u5668\u7684 Python \u4ee3\u7801\u5b9e\u73b0&#xff1a;<\/p>\n<p>python<\/p>\n<p><span class=\"token keyword\">import<\/span> torch<br \/>\n<span class=\"token keyword\">import<\/span> torch<span class=\"token punctuation\">.<\/span>nn <span class=\"token keyword\">as<\/span> nn<\/p>\n<p><span class=\"token comment\"># \u5b9a\u4e49\u89e3\u7801\u5668\u7c7b<\/span><br \/>\n<span class=\"token keyword\">class<\/span> <span class=\"token class-name\">Decoder<\/span><span class=\"token punctuation\">(<\/span>nn<span class=\"token punctuation\">.<\/span>Module<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n    <span class=\"token keyword\">def<\/span> <span class=\"token function\">__init__<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> num_layers<span class=\"token punctuation\">,<\/span> d_model<span class=\"token punctuation\">,<\/span> num_heads<span class=\"token punctuation\">,<\/span> d_ff<span class=\"token punctuation\">,<\/span> target_vocab_size<span class=\"token punctuation\">,<\/span> maximum_position_encoding<span class=\"token punctuation\">,<\/span> dropout<span class=\"token 
punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n        <span class=\"token builtin\">super<\/span><span class=\"token punctuation\">(<\/span>Decoder<span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>__init__<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u6a21\u578b\u7684\u7ef4\u5ea6<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>d_model <span class=\"token operator\">&#061;<\/span> d_model<br \/>\n        <span class=\"token comment\"># \u89e3\u7801\u5668\u5c42\u7684\u6570\u91cf<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>num_layers <span class=\"token operator\">&#061;<\/span> num_layers<\/p>\n<p>        <span class=\"token comment\"># \u5b9a\u4e49\u8bcd\u5d4c\u5165\u5c42<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>embedding <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>Embedding<span class=\"token punctuation\">(<\/span>target_vocab_size<span class=\"token punctuation\">,<\/span> d_model<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u5b9a\u4e49\u4f4d\u7f6e\u7f16\u7801\u5c42<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>pos_encoding <span class=\"token operator\">&#061;<\/span> PositionalEncoding<span class=\"token punctuation\">(<\/span>d_model<span class=\"token punctuation\">,<\/span> maximum_position_encoding<span class=\"token punctuation\">)<\/span><\/p>\n<p>        <span class=\"token comment\"># \u5b9a\u4e49\u89e3\u7801\u5668\u5c42\u5217\u8868<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>dec_layers <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>ModuleList<span class=\"token 
punctuation\">(<\/span><span class=\"token punctuation\">[<\/span>DecoderLayer<span class=\"token punctuation\">(<\/span>d_model<span class=\"token punctuation\">,<\/span> num_heads<span class=\"token punctuation\">,<\/span> d_ff<span class=\"token punctuation\">,<\/span> dropout<span class=\"token punctuation\">)<\/span> <span class=\"token keyword\">for<\/span> _ <span class=\"token keyword\">in<\/span> <span class=\"token builtin\">range<\/span><span class=\"token punctuation\">(<\/span>num_layers<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u5b9a\u4e49 dropout \u5c42<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>dropout <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>Dropout<span class=\"token punctuation\">(<\/span>dropout<span class=\"token punctuation\">)<\/span><\/p>\n<p>    <span class=\"token keyword\">def<\/span> <span class=\"token function\">forward<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> x<span class=\"token punctuation\">,<\/span> enc_output<span class=\"token punctuation\">,<\/span> src_mask<span class=\"token punctuation\">,<\/span> tgt_mask<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n        <span class=\"token comment\"># \u8bcd\u5d4c\u5165<\/span><br \/>\n        x <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>embedding<span class=\"token punctuation\">(<\/span>x<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u7f29\u653e\u8bcd\u5d4c\u5165<\/span><br \/>\n        x <span class=\"token operator\">*&#061;<\/span> torch<span class=\"token punctuation\">.<\/span>sqrt<span class=\"token punctuation\">(<\/span>torch<span class=\"token 
punctuation\">.<\/span>tensor<span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">.<\/span>d_model<span class=\"token punctuation\">,<\/span> dtype<span class=\"token operator\">&#061;<\/span>torch<span class=\"token punctuation\">.<\/span>float32<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u6dfb\u52a0\u4f4d\u7f6e\u7f16\u7801<\/span><br \/>\n        x <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>pos_encoding<span class=\"token punctuation\">(<\/span>x<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u5e94\u7528 dropout<\/span><br \/>\n        x <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>dropout<span class=\"token punctuation\">(<\/span>x<span class=\"token punctuation\">)<\/span><\/p>\n<p>        <span class=\"token comment\"># \u4f9d\u6b21\u901a\u8fc7\u6bcf\u4e2a\u89e3\u7801\u5668\u5c42<\/span><br \/>\n        <span class=\"token keyword\">for<\/span> i <span class=\"token keyword\">in<\/span> <span class=\"token builtin\">range<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">.<\/span>num_layers<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n            x <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>dec_layers<span class=\"token punctuation\">[<\/span>i<span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">(<\/span>x<span class=\"token punctuation\">,<\/span> enc_output<span class=\"token punctuation\">,<\/span> src_mask<span class=\"token punctuation\">,<\/span> tgt_mask<span class=\"token punctuation\">)<\/span><\/p>\n<p>        <span class=\"token keyword\">return<\/span> x<\/p>\n<h4>4.3 \u5168\u8fde\u63a5\u5c42&#xff08;Final Linear 
Layer&#xff09;<\/h4>\n<p>\u5168\u8fde\u63a5\u5c42\u662f Transformer \u67b6\u6784\u7684\u6700\u540e\u4e00\u5c42&#xff0c;\u5b83\u5c06\u89e3\u7801\u5668\u7684\u8f93\u51fa\u6620\u5c04\u5230\u76ee\u6807\u8bcd\u6c47\u8868\u7684\u5927\u5c0f&#xff0c;\u7528\u4e8e\u9884\u6d4b\u4e0b\u4e00\u4e2a\u8bcd\u7684\u6982\u7387\u5206\u5e03\u3002<\/p>\n<p>\u4ee5\u4e0b\u662f\u5168\u8fde\u63a5\u5c42\u7684 Python \u4ee3\u7801\u5b9e\u73b0&#xff1a;<\/p>\n<p>python<\/p>\n<p><span class=\"token keyword\">import<\/span> torch<br \/>\n<span class=\"token keyword\">import<\/span> torch<span class=\"token punctuation\">.<\/span>nn <span class=\"token keyword\">as<\/span> nn<\/p>\n<p><span class=\"token comment\"># \u5b9a\u4e49\u5168\u8fde\u63a5\u5c42\u7c7b<\/span><br \/>\n<span class=\"token keyword\">class<\/span> <span class=\"token class-name\">FinalLinearLayer<\/span><span class=\"token punctuation\">(<\/span>nn<span class=\"token punctuation\">.<\/span>Module<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n    <span class=\"token keyword\">def<\/span> <span class=\"token function\">__init__<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> d_model<span class=\"token punctuation\">,<\/span> target_vocab_size<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n        <span class=\"token builtin\">super<\/span><span class=\"token punctuation\">(<\/span>FinalLinearLayer<span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>__init__<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u5b9a\u4e49\u7ebf\u6027\u5c42<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>linear <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>Linear<span class=\"token 
punctuation\">(<\/span>d_model<span class=\"token punctuation\">,<\/span> target_vocab_size<span class=\"token punctuation\">)<\/span><\/p>\n<p>    <span class=\"token keyword\">def<\/span> <span class=\"token function\">forward<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> x<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n        <span class=\"token comment\"># \u7ebf\u6027\u53d8\u6362<\/span><br \/>\n        x <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>linear<span class=\"token punctuation\">(<\/span>x<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token keyword\">return<\/span> x<\/p>\n<h4>4.4 Transformer \u6a21\u578b\u7684\u6574\u4f53\u5b9e\u73b0<\/h4>\n<p>\u4ee5\u4e0b\u662f Transformer \u6a21\u578b\u7684\u6574\u4f53 Python \u4ee3\u7801\u5b9e\u73b0&#xff1a;<\/p>\n<p>python<\/p>\n<p><span class=\"token keyword\">import<\/span> torch<br \/>\n<span class=\"token keyword\">import<\/span> torch<span class=\"token punctuation\">.<\/span>nn <span class=\"token keyword\">as<\/span> nn<\/p>\n<p><span class=\"token comment\"># \u5b9a\u4e49 Transformer \u6a21\u578b\u7c7b<\/span><br \/>\n<span class=\"token keyword\">class<\/span> <span class=\"token class-name\">Transformer<\/span><span class=\"token punctuation\">(<\/span>nn<span class=\"token punctuation\">.<\/span>Module<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n    <span class=\"token keyword\">def<\/span> <span class=\"token function\">__init__<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> num_layers<span class=\"token punctuation\">,<\/span> d_model<span class=\"token punctuation\">,<\/span> num_heads<span class=\"token punctuation\">,<\/span> d_ff<span class=\"token punctuation\">,<\/span> input_vocab_size<span class=\"token 
punctuation\">,<\/span><br \/>\n                 target_vocab_size<span class=\"token punctuation\">,<\/span> pe_input<span class=\"token punctuation\">,<\/span> pe_target<span class=\"token punctuation\">,<\/span> dropout<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n        <span class=\"token builtin\">super<\/span><span class=\"token punctuation\">(<\/span>Transformer<span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>__init__<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u5b9a\u4e49\u7f16\u7801\u5668<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>encoder <span class=\"token operator\">&#061;<\/span> Encoder<span class=\"token punctuation\">(<\/span>num_layers<span class=\"token punctuation\">,<\/span> d_model<span class=\"token punctuation\">,<\/span> num_heads<span class=\"token punctuation\">,<\/span> d_ff<span class=\"token punctuation\">,<\/span> input_vocab_size<span class=\"token punctuation\">,<\/span> pe_input<span class=\"token punctuation\">,<\/span> dropout<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u5b9a\u4e49\u89e3\u7801\u5668<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>decoder <span class=\"token operator\">&#061;<\/span> Decoder<span class=\"token punctuation\">(<\/span>num_layers<span class=\"token punctuation\">,<\/span> d_model<span class=\"token punctuation\">,<\/span> num_heads<span class=\"token punctuation\">,<\/span> d_ff<span class=\"token punctuation\">,<\/span> target_vocab_size<span class=\"token punctuation\">,<\/span> pe_target<span class=\"token punctuation\">,<\/span> dropout<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u5b9a\u4e49\u5168\u8fde\u63a5\u5c42<\/span><br 
\/>\n        self<span class=\"token punctuation\">.<\/span>final_layer <span class=\"token operator\">&#061;<\/span> FinalLinearLayer<span class=\"token punctuation\">(<\/span>d_model<span class=\"token punctuation\">,<\/span> target_vocab_size<span class=\"token punctuation\">)<\/span><\/p>\n<p>    <span class=\"token keyword\">def<\/span> <span class=\"token function\">forward<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> src<span class=\"token punctuation\">,<\/span> tgt<span class=\"token punctuation\">,<\/span> src_mask<span class=\"token punctuation\">,<\/span> tgt_mask<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n        <span class=\"token comment\"># \u7f16\u7801\u5668\u524d\u5411\u4f20\u64ad<\/span><br \/>\n        enc_output <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>encoder<span class=\"token punctuation\">(<\/span>src<span class=\"token punctuation\">,<\/span> src_mask<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u89e3\u7801\u5668\u524d\u5411\u4f20\u64ad<\/span><br \/>\n        dec_output <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>decoder<span class=\"token punctuation\">(<\/span>tgt<span class=\"token punctuation\">,<\/span> enc_output<span class=\"token punctuation\">,<\/span> src_mask<span class=\"token punctuation\">,<\/span> tgt_mask<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u5168\u8fde\u63a5\u5c42\u524d\u5411\u4f20\u64ad<\/span><br \/>\n        final_output <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>final_layer<span class=\"token punctuation\">(<\/span>dec_output<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token keyword\">return<\/span> 
final_output<\/p>\n<h3>\u4e94\u3001Transformer \u67b6\u6784\u7684\u8bad\u7ec3\u4e0e\u4f18\u5316<\/h3>\n<h4>5.1 \u635f\u5931\u51fd\u6570&#xff08;Loss Function&#xff09;<\/h4>\n<p>\u5728\u8bad\u7ec3 Transformer \u6a21\u578b\u65f6&#xff0c;\u901a\u5e38\u4f7f\u7528\u4ea4\u53c9\u71b5\u635f\u5931\u51fd\u6570&#xff08;Cross-Entropy Loss Function&#xff09;\u6765\u8861\u91cf\u6a21\u578b\u9884\u6d4b\u7ed3\u679c\u4e0e\u771f\u5b9e\u6807\u7b7e\u4e4b\u95f4\u7684\u5dee\u5f02\u3002<\/p>\n<p>\u4ee5\u4e0b\u662f\u4f7f\u7528 PyTorch \u5b9e\u73b0\u7684\u4ea4\u53c9\u71b5\u635f\u5931\u51fd\u6570&#xff1a;<\/p>\n<p>python<\/p>\n<p><span class=\"token keyword\">import<\/span> torch<br \/>\n<span class=\"token keyword\">import<\/span> torch<span class=\"token punctuation\">.<\/span>nn <span class=\"token keyword\">as<\/span> nn<\/p>\n<p><span class=\"token comment\"># \u5b9a\u4e49\u4ea4\u53c9\u71b5\u635f\u5931\u51fd\u6570<\/span><br \/>\ncriterion <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>CrossEntropyLoss<span class=\"token punctuation\">(<\/span>ignore_index<span class=\"token operator\">&#061;<\/span><span class=\"token number\">0<\/span><span class=\"token punctuation\">)<\/span><\/p>\n<h4>5.2 \u4f18\u5316\u5668&#xff08;Optimizer&#xff09;<\/h4>\n<p>\u5728\u8bad\u7ec3 Transformer \u6a21\u578b\u65f6&#xff0c;\u901a\u5e38\u4f7f\u7528 Adam \u4f18\u5316\u5668&#xff08;Adam Optimizer&#xff09;\u6765\u66f4\u65b0\u6a21\u578b\u7684\u53c2\u6570\u3002<\/p>\n<p>\u4ee5\u4e0b\u662f\u4f7f\u7528 PyTorch \u5b9e\u73b0\u7684 Adam \u4f18\u5316\u5668&#xff1a;<\/p>\n<p>python<\/p>\n<p><span class=\"token keyword\">import<\/span> torch<span class=\"token punctuation\">.<\/span>optim <span class=\"token keyword\">as<\/span> optim<\/p>\n<p><span class=\"token comment\"># \u5b9a\u4e49\u6a21\u578b<\/span><br \/>\nmodel <span class=\"token operator\">&#061;<\/span> Transformer<span class=\"token punctuation\">(<\/span>num_layers<span class=\"token operator\">&#061;<\/span><span 
class=\"token number\">6<\/span><span class=\"token punctuation\">,<\/span> d_model<span class=\"token operator\">&#061;<\/span><span class=\"token number\">512<\/span><span class=\"token punctuation\">,<\/span> num_heads<span class=\"token operator\">&#061;<\/span><span class=\"token number\">8<\/span><span class=\"token punctuation\">,<\/span> d_ff<span class=\"token operator\">&#061;<\/span><span class=\"token number\">2048<\/span><span class=\"token punctuation\">,<\/span><br \/>\n                    input_vocab_size<span class=\"token operator\">&#061;<\/span><span class=\"token number\">10000<\/span><span class=\"token punctuation\">,<\/span> target_vocab_size<span class=\"token operator\">&#061;<\/span><span class=\"token number\">10000<\/span><span class=\"token punctuation\">,<\/span><br \/>\n                    pe_input<span class=\"token operator\">&#061;<\/span><span class=\"token number\">1000<\/span><span class=\"token punctuation\">,<\/span> pe_target<span class=\"token operator\">&#061;<\/span><span class=\"token number\">1000<\/span><span class=\"token punctuation\">,<\/span> dropout<span class=\"token operator\">&#061;<\/span><span class=\"token number\">0.1<\/span><span class=\"token punctuation\">)<\/span><\/p>\n<p><span class=\"token comment\"># \u5b9a\u4e49\u4f18\u5316\u5668<\/span><br \/>\noptimizer <span class=\"token operator\">&#061;<\/span> optim<span class=\"token punctuation\">.<\/span>Adam<span class=\"token punctuation\">(<\/span>model<span class=\"token punctuation\">.<\/span>parameters<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span> lr<span class=\"token operator\">&#061;<\/span><span class=\"token number\">0.0001<\/span><span class=\"token punctuation\">,<\/span> betas<span class=\"token operator\">&#061;<\/span><span class=\"token punctuation\">(<\/span><span class=\"token number\">0.9<\/span><span class=\"token punctuation\">,<\/span> <span 
class=\"token number\">0.98<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span> eps<span class=\"token operator\">&#061;<\/span><span class=\"token number\">1e<\/span><span class=\"token operator\">-<\/span><span class=\"token number\">9<\/span><span class=\"token punctuation\">)<\/span><\/p>\n<h4>5.3 \u5b66\u4e60\u7387\u8c03\u5ea6\u5668&#xff08;Learning Rate Scheduler&#xff09;<\/h4>\n<p>\u4e3a\u4e86\u5728\u8bad\u7ec3\u8fc7\u7a0b\u4e2d\u52a8\u6001\u8c03\u6574\u5b66\u4e60\u7387&#xff0c;\u901a\u5e38\u4f7f\u7528\u5b66\u4e60\u7387\u8c03\u5ea6\u5668&#xff08;Learning Rate Scheduler&#xff09;\u3002\u5728 Transformer \u6a21\u578b\u4e2d&#xff0c;\u5e38\u7528\u7684\u5b66\u4e60\u7387\u8c03\u5ea6\u5668\u662f\u57fa\u4e8e\u70ed\u8eab&#xff08;Warmup&#xff09;\u7b56\u7565\u7684\u8c03\u5ea6\u5668\u3002<\/p>\n<p>\u4ee5\u4e0b\u662f\u4f7f\u7528 PyTorch \u5b9e\u73b0\u7684\u57fa\u4e8e\u70ed\u8eab\u7b56\u7565\u7684\u5b66\u4e60\u7387\u8c03\u5ea6\u5668&#xff1a;<\/p>\n<p>python<\/p>\n<p><span class=\"token keyword\">import<\/span> torch<span class=\"token punctuation\">.<\/span>optim <span class=\"token keyword\">as<\/span> optim<br \/>\n<span class=\"token keyword\">import<\/span> math<\/p>\n<p><span class=\"token comment\"># \u5b9a\u4e49\u57fa\u4e8e\u70ed\u8eab\u7b56\u7565\u7684\u5b66\u4e60\u7387\u8c03\u5ea6\u5668\u7c7b<\/span><br \/>\n<span class=\"token keyword\">class<\/span> <span class=\"token class-name\">WarmupScheduler<\/span><span class=\"token punctuation\">:<\/span><br \/>\n    <span class=\"token keyword\">def<\/span> <span class=\"token function\">__init__<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> optimizer<span class=\"token punctuation\">,<\/span> d_model<span class=\"token punctuation\">,<\/span> warmup_steps<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n        <span class=\"token comment\"># \u4f18\u5316\u5668<\/span><br \/>\n        
self<span class=\"token punctuation\">.<\/span>optimizer <span class=\"token operator\">&#061;<\/span> optimizer<br \/>\n        <span class=\"token comment\"># \u6a21\u578b\u7684\u7ef4\u5ea6<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>d_model <span class=\"token operator\">&#061;<\/span> d_model<br \/>\n        <span class=\"token comment\"># \u70ed\u8eab\u6b65\u6570<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>warmup_steps <span class=\"token operator\">&#061;<\/span> warmup_steps<br \/>\n        <span class=\"token comment\"># \u5f53\u524d\u6b65\u6570<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>step_num <span class=\"token operator\">&#061;<\/span> <span class=\"token number\">0<\/span><\/p>\n<p>    <span class=\"token keyword\">def<\/span> <span class=\"token function\">step<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n        <span class=\"token comment\"># \u589e\u52a0\u5f53\u524d\u6b65\u6570<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>step_num <span class=\"token operator\">&#043;&#061;<\/span> <span class=\"token number\">1<\/span><br \/>\n        <span class=\"token comment\"># \u8ba1\u7b97\u5b66\u4e60\u7387<\/span><br \/>\n        lr <span class=\"token operator\">&#061;<\/span> <span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">.<\/span>d_model <span class=\"token operator\">**<\/span> <span class=\"token punctuation\">(<\/span><span class=\"token operator\">-<\/span><span class=\"token number\">0.5<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">)<\/span> <span class=\"token operator\">*<\/span> <span class=\"token builtin\">min<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">.<\/span>step_num <span class=\"token operator\">**<\/span> 
<span class=\"token punctuation\">(<\/span><span class=\"token operator\">-<\/span><span class=\"token number\">0.5<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">.<\/span>step_num <span class=\"token operator\">*<\/span> <span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">.<\/span>warmup_steps <span class=\"token operator\">**<\/span> <span class=\"token punctuation\">(<\/span><span class=\"token operator\">-<\/span><span class=\"token number\">1.5<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u66f4\u65b0\u4f18\u5316\u5668\u7684\u5b66\u4e60\u7387<\/span><br \/>\n        <span class=\"token keyword\">for<\/span> param_group <span class=\"token keyword\">in<\/span> self<span class=\"token punctuation\">.<\/span>optimizer<span class=\"token punctuation\">.<\/span>param_groups<span class=\"token punctuation\">:<\/span><br \/>\n            param_group<span class=\"token punctuation\">[<\/span><span class=\"token string\">&#039;lr&#039;<\/span><span class=\"token punctuation\">]<\/span> <span class=\"token operator\">&#061;<\/span> lr<br \/>\n        <span class=\"token comment\"># \u6267\u884c\u4f18\u5316\u5668\u7684 step \u65b9\u6cd5<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>optimizer<span class=\"token punctuation\">.<\/span>step<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><\/p>\n<p>    <span class=\"token keyword\">def<\/span> <span class=\"token function\">zero_grad<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n        <span class=\"token comment\"># \u6267\u884c\u4f18\u5316\u5668\u7684 zero_grad \u65b9\u6cd5<\/span><br \/>\n        
self<span class=\"token punctuation\">.<\/span>optimizer<span class=\"token punctuation\">.<\/span>zero_grad<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><\/p>\n<h4>5.4 \u8bad\u7ec3\u5faa\u73af&#xff08;Training Loop&#xff09;<\/h4>\n<p>\u4ee5\u4e0b\u662f\u4e00\u4e2a\u7b80\u5355\u7684 Transformer \u6a21\u578b\u8bad\u7ec3\u5faa\u73af\u7684 Python \u4ee3\u7801\u5b9e\u73b0&#xff1a;<\/p>\n<p>python<\/p>\n<p><span class=\"token comment\"># \u5b9a\u4e49\u8bad\u7ec3\u53c2\u6570<\/span><br \/>\nnum_epochs <span class=\"token operator\">&#061;<\/span> <span class=\"token number\">10<\/span><br \/>\nwarmup_steps <span class=\"token operator\">&#061;<\/span> <span class=\"token number\">4000<\/span><\/p>\n<p><span class=\"token comment\"># \u5b9a\u4e49\u5b66\u4e60\u7387\u8c03\u5ea6\u5668<\/span><br \/>\nscheduler <span class=\"token operator\">&#061;<\/span> WarmupScheduler<span class=\"token punctuation\">(<\/span>optimizer<span class=\"token punctuation\">,<\/span> d_model<span class=\"token operator\">&#061;<\/span><span class=\"token number\">512<\/span><span class=\"token punctuation\">,<\/span> warmup_steps<span class=\"token operator\">&#061;<\/span>warmup_steps<span class=\"token punctuation\">)<\/span><\/p>\n<p><span class=\"token comment\"># \u8bad\u7ec3\u5faa\u73af<\/span><br \/>\n<span class=\"token keyword\">for<\/span> epoch <span class=\"token keyword\">in<\/span> <span class=\"token builtin\">range<\/span><span class=\"token punctuation\">(<\/span>num_epochs<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n    total_loss <span class=\"token operator\">&#061;<\/span> <span class=\"token number\">0<\/span><br \/>\n    <span class=\"token keyword\">for<\/span> src<span class=\"token punctuation\">,<\/span> tgt <span class=\"token keyword\">in<\/span> dataloader<span class=\"token punctuation\">:<\/span><br \/>\n        <span class=\"token comment\"># 
\u751f\u6210\u6e90\u5e8f\u5217\u548c\u76ee\u6807\u5e8f\u5217\u7684\u63a9\u7801<\/span><br \/>\n        src_mask <span class=\"token operator\">&#061;<\/span> create_src_mask<span class=\"token punctuation\">(<\/span>src<span class=\"token punctuation\">)<\/span><br \/>\n        tgt_mask <span class=\"token operator\">&#061;<\/span> create_tgt_mask<span class=\"token punctuation\">(<\/span>tgt<span class=\"token punctuation\">)<\/span><\/p>\n<p>        <span class=\"token comment\"># \u524d\u5411\u4f20\u64ad<\/span><br \/>\n        output <span class=\"token operator\">&#061;<\/span> model<span class=\"token punctuation\">(<\/span>src<span class=\"token punctuation\">,<\/span> tgt<span class=\"token punctuation\">[<\/span><span class=\"token punctuation\">:<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token punctuation\">:<\/span><span class=\"token operator\">-<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">,<\/span> src_mask<span class=\"token punctuation\">,<\/span> tgt_mask<span class=\"token punctuation\">[<\/span><span class=\"token punctuation\">:<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token punctuation\">:<\/span><span class=\"token operator\">-<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token punctuation\">:<\/span><span class=\"token operator\">-<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u8ba1\u7b97\u635f\u5931<\/span><br \/>\n        loss <span class=\"token operator\">&#061;<\/span> criterion<span class=\"token punctuation\">(<\/span>output<span class=\"token punctuation\">.<\/span>contiguous<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><span class=\"token 
punctuation\">.<\/span>view<span class=\"token punctuation\">(<\/span><span class=\"token operator\">-<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> output<span class=\"token punctuation\">.<\/span>size<span class=\"token punctuation\">(<\/span><span class=\"token operator\">-<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span> tgt<span class=\"token punctuation\">[<\/span><span class=\"token punctuation\">:<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">1<\/span><span class=\"token punctuation\">:<\/span><span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">.<\/span>contiguous<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>view<span class=\"token punctuation\">(<\/span><span class=\"token operator\">-<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">)<\/span><\/p>\n<p>        <span class=\"token comment\"># \u53cd\u5411\u4f20\u64ad<\/span><br \/>\n        scheduler<span class=\"token punctuation\">.<\/span>zero_grad<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        loss<span class=\"token punctuation\">.<\/span>backward<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># \u66f4\u65b0\u53c2\u6570<\/span><br \/>\n        scheduler<span class=\"token punctuation\">.<\/span>step<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><\/p>\n<p>        total_loss <span class=\"token operator\">&#043;&#061;<\/span> loss<span class=\"token punctuation\">.<\/span>item<span class=\"token punctuation\">(<\/span><span 
class=\"token punctuation\">)<\/span><\/p>\n<p>    <span class=\"token keyword\">print<\/span><span class=\"token punctuation\">(<\/span><span class=\"token string-interpolation\"><span class=\"token string\">f&#039;Epoch <\/span><span class=\"token interpolation\"><span class=\"token punctuation\">{<\/span>epoch <span class=\"token operator\">&#043;<\/span> <span class=\"token number\">1<\/span><span class=\"token punctuation\">}<\/span><\/span><span class=\"token string\">, Loss: <\/span><span class=\"token interpolation\"><span class=\"token punctuation\">{<\/span>total_loss <span class=\"token operator\">\/<\/span> <span class=\"token builtin\">len<\/span><span class=\"token punctuation\">(<\/span>dataloader<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">}<\/span><\/span><span class=\"token string\">&#039;<\/span><\/span><span class=\"token punctuation\">)<\/span><\/p>\n<h3>\u516d\u3001Transformer \u67b6\u6784\u7684\u5e94\u7528\u6848\u4f8b<\/h3>\n<h4>6.1 \u673a\u5668\u7ffb\u8bd1&#xff08;Machine Translation&#xff09;<\/h4>\n<p>Transformer \u67b6\u6784\u5728\u673a\u5668\u7ffb\u8bd1\u4efb\u52a1\u4e2d\u53d6\u5f97\u4e86\u5de8\u5927\u7684\u6210\u529f\u3002\u901a\u8fc7\u5c06\u6e90\u8bed\u8a00\u5e8f\u5217\u8f93\u5165\u5230\u7f16\u7801\u5668\u4e2d&#xff0c;\u89e3\u7801\u5668\u6839\u636e\u7f16\u7801\u5668\u7684\u8f93\u51fa\u751f\u6210\u76ee\u6807\u8bed\u8a00\u5e8f\u5217\u3002<\/p>\n<p>\u4ee5\u4e0b\u662f\u4e00\u4e2a\u7b80\u5355\u7684\u673a\u5668\u7ffb\u8bd1\u793a\u4f8b\u4ee3\u7801&#xff1a;<\/p>\n<p>python<\/p>\n<p><span class=\"token comment\"># \u52a0\u8f7d\u9884\u8bad\u7ec3\u7684 Transformer \u6a21\u578b<\/span><br \/>\nmodel <span class=\"token operator\">&#061;<\/span> Transformer<span class=\"token punctuation\">(<\/span>num_layers<span class=\"token operator\">&#061;<\/span><span class=\"token number\">6<\/span><span class=\"token punctuation\">,<\/span> d_model<span class=\"token operator\">&#061;<\/span><span class=\"token 
number\">512<\/span><span class=\"token punctuation\">,<\/span> num_heads<span class=\"token operator\">&#061;<\/span><span class=\"token number\">8<\/span><span class=\"token punctuation\">,<\/span> d_ff<span class=\"token operator\">&#061;<\/span><span class=\"token number\">2048<\/span><span class=\"token punctuation\">,<\/span><br \/>\n                    input_vocab_size<span class=\"token operator\">&#061;<\/span><span class=\"token number\">10000<\/span><span class=\"token punctuation\">,<\/span> target_vocab_size<span class=\"token operator\">&#061;<\/span><span class=\"token number\">10000<\/span><span class=\"token punctuation\">,<\/span><br \/>\n                    pe_input<span class=\"token operator\">&#061;<\/span><span class=\"token number\">1000<\/span><span class=\"token punctuation\">,<\/span> pe_target<span class=\"token operator\">&#061;<\/span><span class=\"token number\">1000<\/span><span class=\"token punctuation\">,<\/span> dropout<span class=\"token operator\">&#061;<\/span><span class=\"token number\">0.1<\/span><span class=\"token punctuation\">)<\/span><br \/>\nmodel<span class=\"token punctuation\">.<\/span>load_state_dict<span class=\"token punctuation\">(<\/span>torch<span class=\"token punctuation\">.<\/span>load<span class=\"token punctuation\">(<\/span><span class=\"token string\">&#039;transformer_model.pth&#039;<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">)<\/span><br \/>\nmodel<span class=\"token punctuation\">.<\/span><span class=\"token builtin\">eval<\/span><span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><\/p>\n<p><span class=\"token comment\"># \u8f93\u5165\u6e90\u8bed\u8a00\u5e8f\u5217<\/span><br \/>\nsrc <span class=\"token operator\">&#061;<\/span> torch<span class=\"token punctuation\">.<\/span>tensor<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">[<\/span><span class=\"token punctuation\">[<\/span><span 
class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">2<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">3<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">4<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">5<\/span><span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">)<\/span><br \/>\n<span class=\"token comment\"># \u751f\u6210\u6e90\u5e8f\u5217\u7684\u63a9\u7801<\/span><br \/>\nsrc_mask <span class=\"token operator\">&#061;<\/span> create_src_mask<span class=\"token punctuation\">(<\/span>src<span class=\"token punctuation\">)<\/span><\/p>\n<p><span class=\"token comment\"># \u521d\u59cb\u5316\u76ee\u6807\u5e8f\u5217<\/span><br \/>\ntgt <span class=\"token operator\">&#061;<\/span> torch<span class=\"token punctuation\">.<\/span>tensor<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">[<\/span><span class=\"token punctuation\">[<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">)<\/span><\/p>\n<p><span class=\"token comment\"># \u751f\u6210\u7ffb\u8bd1\u7ed3\u679c<\/span><br \/>\n<span class=\"token keyword\">for<\/span> i <span class=\"token keyword\">in<\/span> <span class=\"token builtin\">range<\/span><span class=\"token punctuation\">(<\/span><span class=\"token number\">10<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n    <span class=\"token comment\"># \u751f\u6210\u76ee\u6807\u5e8f\u5217\u7684\u63a9\u7801<\/span><br \/>\n    tgt_mask <span class=\"token operator\">&#061;<\/span> create_tgt_mask<span class=\"token punctuation\">(<\/span>tgt<span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token comment\"># 
\u524d\u5411\u4f20\u64ad<\/span><br \/>\n    output <span class=\"token operator\">&#061;<\/span> model<span class=\"token punctuation\">(<\/span>src<span class=\"token punctuation\">,<\/span> tgt<span class=\"token punctuation\">,<\/span> src_mask<span class=\"token punctuation\">,<\/span> tgt_mask<span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token comment\"># \u83b7\u53d6\u9884\u6d4b\u7684\u4e0b\u4e00\u4e2a\u8bcd<\/span><br \/>\n    next_word <span class=\"token operator\">&#061;<\/span> torch<span class=\"token punctuation\">.<\/span>argmax<span class=\"token punctuation\">(<\/span>output<span class=\"token punctuation\">[<\/span><span class=\"token punctuation\">:<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token operator\">-<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token punctuation\">:<\/span><span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">,<\/span> dim<span class=\"token operator\">&#061;<\/span><span class=\"token operator\">-<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>unsqueeze<span class=\"token punctuation\">(<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token comment\"># \u5c06\u9884\u6d4b\u7684\u4e0b\u4e00\u4e2a\u8bcd\u6dfb\u52a0\u5230\u76ee\u6807\u5e8f\u5217\u4e2d<\/span><br \/>\n    tgt <span class=\"token operator\">&#061;<\/span> torch<span class=\"token punctuation\">.<\/span>cat<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">[<\/span>tgt<span class=\"token punctuation\">,<\/span> next_word<span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">,<\/span> dim<span class=\"token operator\">&#061;<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span><\/p>\n<p><span 
class=\"token keyword\">print<\/span><span class=\"token punctuation\">(<\/span>tgt<span class=\"token punctuation\">)<\/span><\/p>\n<h4>6.2 \u6587\u672c\u751f\u6210&#xff08;Text Generation&#xff09;<\/h4>\n<p>Transformer \u67b6\u6784\u4e5f\u5e7f\u6cdb\u5e94\u7528\u4e8e\u6587\u672c\u751f\u6210\u4efb\u52a1&#xff0c;\u5982\u6545\u4e8b\u751f\u6210\u3001\u8bd7\u6b4c\u751f\u6210\u7b49\u3002\u901a\u8fc7\u4e0d\u65ad\u5730\u9884\u6d4b\u4e0b\u4e00\u4e2a\u8bcd&#xff0c;\u751f\u6210\u5b8c\u6574\u7684\u6587\u672c\u5e8f\u5217\u3002<\/p>\n<p>\u4ee5\u4e0b\u662f\u4e00\u4e2a\u7b80\u5355\u7684\u6587\u672c\u751f\u6210\u793a\u4f8b\u4ee3\u7801&#xff1a;<\/p>\n<p>python<\/p>\n<p><span class=\"token comment\"># \u52a0\u8f7d\u9884\u8bad\u7ec3\u7684 Transformer \u6a21\u578b<\/span><br \/>\nmodel <span class=\"token operator\">&#061;<\/span> Transformer<span class=\"token punctuation\">(<\/span>num_layers<span class=\"token operator\">&#061;<\/span><span class=\"token number\">6<\/span><span class=\"token punctuation\">,<\/span> d_model<span class=\"token operator\">&#061;<\/span><span class=\"token number\">512<\/span><span class=\"token punctuation\">,<\/span> num_heads<span class=\"token operator\">&#061;<\/span><span class=\"token number\">8<\/span><span class=\"token punctuation\">,<\/span> d_ff<span class=\"token operator\">&#061;<\/span><span class=\"token number\">2048<\/span><span class=\"token punctuation\">,<\/span><br \/>\n                    input_vocab_size<span class=\"token operator\">&#061;<\/span><span class=\"token number\">10000<\/span><span class=\"token punctuation\">,<\/span> target_vocab_size<span class=\"token operator\">&#061;<\/span><span class=\"token number\">10000<\/span><span class=\"token punctuation\">,<\/span><br \/>\n                    pe_input<span class=\"token operator\">&#061;<\/span><span class=\"token number\">1000<\/span><span class=\"token punctuation\">,<\/span> pe_target<span class=\"token operator\">&#061;<\/span><span class=\"token 
number\">1000<\/span><span class=\"token punctuation\">,<\/span> dropout<span class=\"token operator\">&#061;<\/span><span class=\"token number\">0.1<\/span><span class=\"token punctuation\">)<\/span><br \/>\nmodel<span class=\"token punctuation\">.<\/span>load_state_dict<span class=\"token punctuation\">(<\/span>torch<span class=\"token punctuation\">.<\/span>load<span class=\"token punctuation\">(<\/span><span class=\"token string\">&#039;transformer_model.pth&#039;<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">)<\/span><br \/>\nmodel<span class=\"token punctuation\">.<\/span><span class=\"token builtin\">eval<\/span><span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><\/p>\n<p><span class=\"token comment\"># \u8f93\u5165\u8d77\u59cb\u6587\u672c<\/span><br \/>\nstart_text <span class=\"token operator\">&#061;<\/span> torch<span class=\"token punctuation\">.<\/span>tensor<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">[<\/span><span class=\"token punctuation\">[<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">2<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">3<\/span><span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">)<\/span><br \/>\n<span class=\"token comment\"># \u751f\u6210\u8d77\u59cb\u6587\u672c\u7684\u63a9\u7801<\/span><br \/>\nsrc_mask <span class=\"token operator\">&#061;<\/span> create_src_mask<span class=\"token punctuation\">(<\/span>start_text<span class=\"token punctuation\">)<\/span><\/p>\n<p><span class=\"token comment\"># \u521d\u59cb\u5316\u76ee\u6807\u5e8f\u5217<\/span><br \/>\ntgt <span class=\"token operator\">&#061;<\/span> start_text<\/p>\n<p><span class=\"token comment\"># \u751f\u6210\u6587\u672c\u5e8f\u5217<\/span><br \/>\n<span class=\"token keyword\">for<\/span> i 
<span class=\"token keyword\">in<\/span> <span class=\"token builtin\">range<\/span><span class=\"token punctuation\">(<\/span><span class=\"token number\">20<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n    <span class=\"token comment\"># \u751f\u6210\u76ee\u6807\u5e8f\u5217\u7684\u63a9\u7801<\/span><br \/>\n    tgt_mask <span class=\"token operator\">&#061;<\/span> create_tgt_mask<span class=\"token punctuation\">(<\/span>tgt<span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token comment\"># \u524d\u5411\u4f20\u64ad<\/span><br \/>\n    output <span class=\"token operator\">&#061;<\/span> model<span class=\"token punctuation\">(<\/span>start_text<span class=\"token punctuation\">,<\/span> tgt<span class=\"token punctuation\">,<\/span> src_mask<span class=\"token punctuation\">,<\/span> tgt_mask<span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token comment\"># \u83b7\u53d6\u9884\u6d4b\u7684\u4e0b\u4e00\u4e2a\u8bcd<\/span><br \/>\n    next_word <span class=\"token operator\">&#061;<\/span> torch<span class=\"token punctuation\">.<\/span>argmax<span class=\"token punctuation\">(<\/span>output<span class=\"token punctuation\">[<\/span><span class=\"token punctuation\">:<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token operator\">&#045;<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token punctuation\">:<\/span><span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">,<\/span> dim<span class=\"token operator\">&#061;<\/span><span class=\"token operator\">&#045;<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>unsqueeze<span class=\"token punctuation\">(<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token comment\"># 
\u5c06\u9884\u6d4b\u7684\u4e0b\u4e00\u4e2a\u8bcd\u6dfb\u52a0\u5230\u76ee\u6807\u5e8f\u5217\u4e2d<\/span><br \/>\n    tgt <span class=\"token operator\">&#061;<\/span> torch<span class=\"token punctuation\">.<\/span>cat<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">[<\/span>tgt<span class=\"token punctuation\">,<\/span> next_word<span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">,<\/span> dim<span class=\"token operator\">&#061;<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span><\/p>\n<p><span class=\"token keyword\">print<\/span><span class=\"token punctuation\">(<\/span>tgt<span class=\"token punctuation\">)<\/span><\/p>\n<h3>\u4e03\u3001\u603b\u7ed3\u4e0e\u5c55\u671b<\/h3>\n<h4>7.1 \u603b\u7ed3<\/h4>\n<p>\u672c\u6587\u6df1\u5165\u5256\u6790\u4e86 Transformer \u67b6\u6784&#xff0c;\u4ece\u5176\u6838\u5fc3\u539f\u7406\u3001\u7ec4\u4ef6\u6784\u6210\u5230\u6e90\u7801\u5b9e\u73b0\u8fdb\u884c\u4e86\u5168\u9762\u7684\u5206\u6790\u3002Transformer \u67b6\u6784\u4ee5\u5176\u5e76\u884c\u8ba1\u7b97\u80fd\u529b\u3001\u957f\u8ddd\u79bb\u4f9d\u8d56\u5efa\u6a21\u80fd\u529b\u548c\u7075\u6d3b\u6027&#xff0c;\u6210\u4e3a\u4e86\u73b0\u4ee3 AI \u5927\u6a21\u578b\u7684\u6838\u5fc3\u57fa\u7840\u3002\u901a\u8fc7\u81ea\u6ce8\u610f\u529b\u673a\u5236\u548c\u4f4d\u7f6e\u7f16\u7801&#xff0c;Transformer 
\u67b6\u6784\u80fd\u591f\u6709\u6548\u5730\u6355\u6349\u5e8f\u5217\u4e2d\u7684\u4e0a\u4e0b\u6587\u4fe1\u606f&#xff0c;\u4ece\u800c\u5728\u81ea\u7136\u8bed\u8a00\u5904\u7406\u3001\u8ba1\u7b97\u673a\u89c6\u89c9\u7b49\u591a\u4e2a\u9886\u57df\u53d6\u5f97\u4e86\u4f18\u5f02\u7684\u6027\u80fd\u3002<\/p>\n<p>\u5728\u6e90\u7801\u5b9e\u73b0\u65b9\u9762&#xff0c;\u6211\u4eec\u8be6\u7ec6\u4ecb\u7ecd\u4e86\u7f29\u653e\u70b9\u79ef\u6ce8\u610f\u529b\u3001\u591a\u5934\u6ce8\u610f\u529b\u3001\u4f4d\u7f6e\u7f16\u7801\u3001\u524d\u9988\u795e\u7ecf\u7f51\u7edc\u3001\u7f16\u7801\u5668\u3001\u89e3\u7801\u5668\u3001\u5168\u8fde\u63a5\u5c42\u7b49\u7ec4\u4ef6\u7684\u5b9e\u73b0\u7ec6\u8282&#xff0c;\u5e76\u7ed9\u51fa\u4e86\u5b8c\u6574\u7684 Transformer \u6a21\u578b\u7684\u4ee3\u7801\u5b9e\u73b0\u3002\u540c\u65f6&#xff0c;\u6211\u4eec\u8fd8\u4ecb\u7ecd\u4e86 Transformer \u6a21\u578b\u7684\u8bad\u7ec3\u4e0e\u4f18\u5316\u65b9\u6cd5&#xff0c;\u5305\u62ec\u635f\u5931\u51fd\u6570\u3001\u4f18\u5316\u5668\u3001\u5b66\u4e60\u7387\u8c03\u5ea6\u5668\u548c\u8bad\u7ec3\u5faa\u73af\u7b49\u3002<\/p>\n<h4>7.2 \u5c55\u671b<\/h4>\n<p>\u5c3d\u7ba1 Transformer \u67b6\u6784\u5df2\u7ecf\u53d6\u5f97\u4e86\u5de8\u5927\u7684\u6210\u529f&#xff0c;\u4f46\u4ecd\u7136\u5b58\u5728\u4e00\u4e9b\u6311\u6218\u548c\u6539\u8fdb\u7684\u7a7a\u95f4&#xff1a;<\/p>\n<ul>\n<li>\n<p>\u8ba1\u7b97\u8d44\u6e90\u9700\u6c42&#xff1a;Transformer \u67b6\u6784\u7684\u8ba1\u7b97\u590d\u6742\u5ea6\u8f83\u9ad8&#xff0c;\u9700\u8981\u5927\u91cf\u7684\u8ba1\u7b97\u8d44\u6e90\u548c\u5185\u5b58\u3002\u672a\u6765\u7684\u7814\u7a76\u53ef\u4ee5\u63a2\u7d22\u5982\u4f55\u4f18\u5316 Transformer \u67b6\u6784\u7684\u8ba1\u7b97\u6548\u7387&#xff0c;\u51cf\u5c11\u8ba1\u7b97\u8d44\u6e90\u7684\u9700\u6c42\u3002<\/p>\n<\/li>\n<li>\n<p>\u53ef\u89e3\u91ca\u6027&#xff1a;Transformer \u67b6\u6784\u662f\u4e00\u79cd\u9ed1\u76d2\u6a21\u578b&#xff0c;\u5176\u51b3\u7b56\u8fc7\u7a0b\u96be\u4ee5\u89e3\u91ca\u3002\u672a\u6765\u7684\u7814\u7a76\u53ef\u4ee5\u81f4\u529b\u4e8e\u63d0\u9ad8 
Transformer \u6a21\u578b\u7684\u53ef\u89e3\u91ca\u6027&#xff0c;\u4f7f\u5176\u66f4\u52a0\u900f\u660e\u548c\u53ef\u4fe1\u3002<\/p>\n<\/li>\n<li>\n<p>\u957f\u5e8f\u5217\u5904\u7406\u80fd\u529b&#xff1a;\u867d\u7136 Transformer \u67b6\u6784\u5728\u5904\u7406\u957f\u5e8f\u5217\u65b9\u9762\u5177\u6709\u4e00\u5b9a\u7684\u4f18\u52bf&#xff0c;\u4f46\u5728\u5904\u7406\u6781\u957f\u5e8f\u5217\u65f6\u4ecd\u7136\u5b58\u5728\u6311\u6218\u3002\u672a\u6765\u7684\u7814\u7a76\u53ef\u4ee5\u63a2\u7d22\u5982\u4f55\u8fdb\u4e00\u6b65\u63d0\u9ad8 Transformer \u67b6\u6784\u7684\u957f\u5e8f\u5217\u5904\u7406\u80fd\u529b\u3002<\/p>\n<\/li>\n<\/ul>\n<p>\u968f\u7740\u4eba\u5de5\u667a\u80fd\u6280\u672f\u7684\u4e0d\u65ad\u53d1\u5c55&#xff0c;Transformer \u67b6\u6784\u6709\u671b\u5728\u66f4\u591a\u7684\u9886\u57df\u5f97\u5230\u5e94\u7528\u548c\u62d3\u5c55\u3002\u4f8b\u5982&#xff0c;\u5728\u533b\u7597\u9886\u57df&#xff0c;Transformer \u67b6\u6784\u53ef\u4ee5\u7528\u4e8e\u533b\u5b66\u56fe\u50cf\u5206\u6790\u3001\u75be\u75c5\u9884\u6d4b\u7b49&#xff1b;\u5728\u91d1\u878d\u9886\u57df&#xff0c;Transformer \u67b6\u6784\u53ef\u4ee5\u7528\u4e8e\u98ce\u9669\u8bc4\u4f30\u3001\u80a1\u7968\u9884\u6d4b\u7b49\u3002\u76f8\u4fe1\u5728\u672a\u6765&#xff0c;Transformer \u67b6\u6784\u5c06\u4e3a\u4eba\u5de5\u667a\u80fd\u7684\u53d1\u5c55\u5e26\u6765\u66f4\u591a\u7684\u7a81\u7834\u548c\u521b\u65b0\u3002<\/p>\n<p>\u4ee5\u4e0a\u5185\u5bb9\u8be6\u7ec6\u4ecb\u7ecd\u4e86 Transformer \u67b6\u6784\u7684\u539f\u7406\u3001\u5b9e\u73b0\u548c\u5e94\u7528&#xff0c;\u5e0c\u671b\u80fd\u591f\u5e2e\u52a9\u8bfb\u8005\u6df1\u5165\u7406\u89e3 Transformer \u67b6\u6784&#xff0c;\u5e76\u4e3a\u8fdb\u4e00\u6b65\u7684\u7814\u7a76\u548c\u5e94\u7528\u63d0\u4f9b\u53c2\u8003\u3002\u5728\u5b9e\u9645\u5e94\u7528\u4e2d&#xff0c;\u8bfb\u8005\u53ef\u4ee5\u6839\u636e\u5177\u4f53\u9700\u6c42\u5bf9\u4ee3\u7801\u8fdb\u884c\u8c03\u6574\u548c\u4f18\u5316&#xff0c;\u4ee5\u5b9e\u73b0\u66f4\u597d\u7684\u6027\u80fd\u3002\u540c\u65f6&#xff0c;\u8bfb\u8005\u4e5f\u53ef\u4ee5\u5173\u6ce8 
Transformer \u67b6\u6784\u7684\u6700\u65b0\u7814\u7a76\u8fdb\u5c55&#xff0c;\u4e0d\u65ad\u63a2\u7d22\u5176\u5728\u4e0d\u540c\u9886\u57df\u7684\u5e94\u7528\u6f5c\u529b\u3002<\/p>\n","protected":false},"excerpt":{"rendered":"<p>\u6587\u7ae0\u6d4f\u89c8\u9605\u8bfb1.6k\u6b21\uff0c\u70b9\u8d5e75\u6b21\uff0c\u6536\u85cf37\u6b21\u3002\u5728 Transformer \u67b6\u6784\u51fa\u73b0\u4e4b\u524d\uff0c\u5faa\u73af\u795e\u7ecf\u7f51\u7edc\uff08RNN\uff09\u53ca\u5176\u53d8\u4f53\uff0c\u5982\u957f\u77ed\u671f\u8bb0\u5fc6\u7f51\u7edc\uff08LSTM\uff09\u548c\u95e8\u63a7\u5faa\u73af\u5355\u5143\uff08GRU\uff09\uff0c\u662f\u5904\u7406\u5e8f\u5217\u6570\u636e\u7684\u4e3b\u6d41\u6a21\u578b\u3002\u987a\u5e8f\u8ba1\u7b97\u95ee\u9898\uff1aRNN \u53ca\u5176\u53d8\u4f53\u5728\u5904\u7406\u5e8f\u5217\u6570\u636e\u65f6\uff0c\u9700\u8981\u6309\u987a\u5e8f\u4f9d\u6b21\u5904\u7406\u6bcf\u4e2a\u65f6\u95f4\u6b65\u7684\u8f93\u5165\uff0c\u8fd9\u4f7f\u5f97\u5b83\u4eec\u96be\u4ee5\u8fdb\u884c\u5e76\u884c\u8ba1\u7b97\uff0c\u4ece\u800c\u9650\u5236\u4e86\u6a21\u578b\u7684\u8bad\u7ec3\u901f\u5ea6\u548c\u5904\u7406\u957f\u5e8f\u5217\u7684\u80fd\u529b\u3002\u957f\u8ddd\u79bb\u4f9d\u8d56\u95ee\u9898\uff1a\u5728\u5904\u7406\u957f\u5e8f\u5217\u65f6\uff0cRNN \u53ca\u5176\u53d8\u4f53\u5bb9\u6613\u51fa\u73b0\u68af\u5ea6\u6d88\u5931\u6216\u68af\u5ea6\u7206\u70b8\u7684\u95ee\u9898\uff0c\u5bfc\u81f4\u6a21\u578b\u96be\u4ee5\u6355\u6349\u5e8f\u5217\u4e2d\u7684\u957f\u8ddd\u79bb\u4f9d\u8d56\u5173\u7cfb\u3002_ai\u7684transformer\u67b6\u6784\u63d0\u51fa<\/p>\n","protected":false},"author":2,"featured_media":0,"comment_status":"open","ping_status":"open","sticky":false,"template":"","format":"standard","meta":{"footnotes":""},"categories":[1],"tags":[587,2354,2873,841,50,207,188,86],"topic":[],"class_list":["post-34017","post","type-post","status-publish","format-standard","hentry","category-server","tag-587","tag-ai-","tag-rnn","tag-transformer","tag-50","tag-207","tag-188","tag-86"],"yoast_head":"<!-- This site is optimized with 
the Yoast SEO plugin v20.3 - https:\/\/yoast.com\/wordpress\/plugins\/seo\/ -->\n<title>AI \u5927\u6a21\u578b\u4e4b Transformer \u67b6\u6784\u6df1\u5165\u5256\u6790 - \u7f51\u7855\u4e92\u8054\u5e2e\u52a9\u4e2d\u5fc3<\/title>\n<meta name=\"robots\" content=\"index, follow, max-snippet:-1, max-image-preview:large, max-video-preview:-1\" \/>\n<link rel=\"canonical\" href=\"https:\/\/www.wsisp.com\/helps\/34017.html\" \/>\n<meta property=\"og:locale\" content=\"zh_CN\" \/>\n<meta property=\"og:type\" content=\"article\" \/>\n<meta property=\"og:title\" content=\"AI \u5927\u6a21\u578b\u4e4b Transformer \u67b6\u6784\u6df1\u5165\u5256\u6790 - \u7f51\u7855\u4e92\u8054\u5e2e\u52a9\u4e2d\u5fc3\" \/>\n<meta property=\"og:description\" content=\"\u6587\u7ae0\u6d4f\u89c8\u9605\u8bfb1.6k\u6b21\uff0c\u70b9\u8d5e75\u6b21\uff0c\u6536\u85cf37\u6b21\u3002\u5728 Transformer \u67b6\u6784\u51fa\u73b0\u4e4b\u524d\uff0c\u5faa\u73af\u795e\u7ecf\u7f51\u7edc\uff08RNN\uff09\u53ca\u5176\u53d8\u4f53\uff0c\u5982\u957f\u77ed\u671f\u8bb0\u5fc6\u7f51\u7edc\uff08LSTM\uff09\u548c\u95e8\u63a7\u5faa\u73af\u5355\u5143\uff08GRU\uff09\uff0c\u662f\u5904\u7406\u5e8f\u5217\u6570\u636e\u7684\u4e3b\u6d41\u6a21\u578b\u3002\u987a\u5e8f\u8ba1\u7b97\u95ee\u9898\uff1aRNN \u53ca\u5176\u53d8\u4f53\u5728\u5904\u7406\u5e8f\u5217\u6570\u636e\u65f6\uff0c\u9700\u8981\u6309\u987a\u5e8f\u4f9d\u6b21\u5904\u7406\u6bcf\u4e2a\u65f6\u95f4\u6b65\u7684\u8f93\u5165\uff0c\u8fd9\u4f7f\u5f97\u5b83\u4eec\u96be\u4ee5\u8fdb\u884c\u5e76\u884c\u8ba1\u7b97\uff0c\u4ece\u800c\u9650\u5236\u4e86\u6a21\u578b\u7684\u8bad\u7ec3\u901f\u5ea6\u548c\u5904\u7406\u957f\u5e8f\u5217\u7684\u80fd\u529b\u3002\u957f\u8ddd\u79bb\u4f9d\u8d56\u95ee\u9898\uff1a\u5728\u5904\u7406\u957f\u5e8f\u5217\u65f6\uff0cRNN 
\u53ca\u5176\u53d8\u4f53\u5bb9\u6613\u51fa\u73b0\u68af\u5ea6\u6d88\u5931\u6216\u68af\u5ea6\u7206\u70b8\u7684\u95ee\u9898\uff0c\u5bfc\u81f4\u6a21\u578b\u96be\u4ee5\u6355\u6349\u5e8f\u5217\u4e2d\u7684\u957f\u8ddd\u79bb\u4f9d\u8d56\u5173\u7cfb\u3002_ai\u7684transformer\u67b6\u6784\u63d0\u51fa\" \/>\n<meta property=\"og:url\" content=\"https:\/\/www.wsisp.com\/helps\/34017.html\" \/>\n<meta property=\"og:site_name\" content=\"\u7f51\u7855\u4e92\u8054\u5e2e\u52a9\u4e2d\u5fc3\" \/>\n<meta property=\"article:published_time\" content=\"2025-04-28T14:09:38+00:00\" \/>\n<meta name=\"author\" content=\"admin\" \/>\n<meta name=\"twitter:card\" content=\"summary_large_image\" \/>\n<meta name=\"twitter:label1\" content=\"\u4f5c\u8005\" \/>\n\t<meta name=\"twitter:data1\" content=\"admin\" \/>\n\t<meta name=\"twitter:label2\" content=\"\u9884\u8ba1\u9605\u8bfb\u65f6\u95f4\" \/>\n\t<meta name=\"twitter:data2\" content=\"10 \u5206\" \/>\n<script type=\"application\/ld+json\" class=\"yoast-schema-graph\">{\"@context\":\"https:\/\/schema.org\",\"@graph\":[{\"@type\":\"WebPage\",\"@id\":\"https:\/\/www.wsisp.com\/helps\/34017.html\",\"url\":\"https:\/\/www.wsisp.com\/helps\/34017.html\",\"name\":\"AI \u5927\u6a21\u578b\u4e4b Transformer \u67b6\u6784\u6df1\u5165\u5256\u6790 - 
\u7f51\u7855\u4e92\u8054\u5e2e\u52a9\u4e2d\u5fc3\",\"isPartOf\":{\"@id\":\"https:\/\/www.wsisp.com\/helps\/#website\"},\"datePublished\":\"2025-04-28T14:09:38+00:00\",\"dateModified\":\"2025-04-28T14:09:38+00:00\",\"author\":{\"@id\":\"https:\/\/www.wsisp.com\/helps\/#\/schema\/person\/358e386c577a3ab51c4493330a20ad41\"},\"breadcrumb\":{\"@id\":\"https:\/\/www.wsisp.com\/helps\/34017.html#breadcrumb\"},\"inLanguage\":\"zh-Hans\",\"potentialAction\":[{\"@type\":\"ReadAction\",\"target\":[\"https:\/\/www.wsisp.com\/helps\/34017.html\"]}]},{\"@type\":\"BreadcrumbList\",\"@id\":\"https:\/\/www.wsisp.com\/helps\/34017.html#breadcrumb\",\"itemListElement\":[{\"@type\":\"ListItem\",\"position\":1,\"name\":\"\u9996\u9875\",\"item\":\"https:\/\/www.wsisp.com\/helps\"},{\"@type\":\"ListItem\",\"position\":2,\"name\":\"AI \u5927\u6a21\u578b\u4e4b Transformer \u67b6\u6784\u6df1\u5165\u5256\u6790\"}]},{\"@type\":\"WebSite\",\"@id\":\"https:\/\/www.wsisp.com\/helps\/#website\",\"url\":\"https:\/\/www.wsisp.com\/helps\/\",\"name\":\"\u7f51\u7855\u4e92\u8054\u5e2e\u52a9\u4e2d\u5fc3\",\"description\":\"\u9999\u6e2f\u670d\u52a1\u5668_\u9999\u6e2f\u4e91\u670d\u52a1\u5668\u8d44\u8baf_\u670d\u52a1\u5668\u5e2e\u52a9\u6587\u6863_\u670d\u52a1\u5668\u6559\u7a0b\",\"potentialAction\":[{\"@type\":\"SearchAction\",\"target\":{\"@type\":\"EntryPoint\",\"urlTemplate\":\"https:\/\/www.wsisp.com\/helps\/?s={search_term_string}\"},\"query-input\":\"required 
name=search_term_string\"}],\"inLanguage\":\"zh-Hans\"},{\"@type\":\"Person\",\"@id\":\"https:\/\/www.wsisp.com\/helps\/#\/schema\/person\/358e386c577a3ab51c4493330a20ad41\",\"name\":\"admin\",\"image\":{\"@type\":\"ImageObject\",\"inLanguage\":\"zh-Hans\",\"@id\":\"https:\/\/www.wsisp.com\/helps\/#\/schema\/person\/image\/\",\"url\":\"https:\/\/gravatar.wp-china-yes.net\/avatar\/?s=96&d=mystery\",\"contentUrl\":\"https:\/\/gravatar.wp-china-yes.net\/avatar\/?s=96&d=mystery\",\"caption\":\"admin\"},\"sameAs\":[\"http:\/\/wp.wsisp.com\"],\"url\":\"https:\/\/www.wsisp.com\/helps\/author\/admin\"}]}<\/script>\n<!-- \/ Yoast SEO plugin. -->","yoast_head_json":{"title":"AI \u5927\u6a21\u578b\u4e4b Transformer \u67b6\u6784\u6df1\u5165\u5256\u6790 - \u7f51\u7855\u4e92\u8054\u5e2e\u52a9\u4e2d\u5fc3","robots":{"index":"index","follow":"follow","max-snippet":"max-snippet:-1","max-image-preview":"max-image-preview:large","max-video-preview":"max-video-preview:-1"},"canonical":"https:\/\/www.wsisp.com\/helps\/34017.html","og_locale":"zh_CN","og_type":"article","og_title":"AI \u5927\u6a21\u578b\u4e4b Transformer \u67b6\u6784\u6df1\u5165\u5256\u6790 - \u7f51\u7855\u4e92\u8054\u5e2e\u52a9\u4e2d\u5fc3","og_description":"\u6587\u7ae0\u6d4f\u89c8\u9605\u8bfb1.6k\u6b21\uff0c\u70b9\u8d5e75\u6b21\uff0c\u6536\u85cf37\u6b21\u3002\u5728 Transformer \u67b6\u6784\u51fa\u73b0\u4e4b\u524d\uff0c\u5faa\u73af\u795e\u7ecf\u7f51\u7edc\uff08RNN\uff09\u53ca\u5176\u53d8\u4f53\uff0c\u5982\u957f\u77ed\u671f\u8bb0\u5fc6\u7f51\u7edc\uff08LSTM\uff09\u548c\u95e8\u63a7\u5faa\u73af\u5355\u5143\uff08GRU\uff09\uff0c\u662f\u5904\u7406\u5e8f\u5217\u6570\u636e\u7684\u4e3b\u6d41\u6a21\u578b\u3002\u987a\u5e8f\u8ba1\u7b97\u95ee\u9898\uff1aRNN 
\u53ca\u5176\u53d8\u4f53\u5728\u5904\u7406\u5e8f\u5217\u6570\u636e\u65f6\uff0c\u9700\u8981\u6309\u987a\u5e8f\u4f9d\u6b21\u5904\u7406\u6bcf\u4e2a\u65f6\u95f4\u6b65\u7684\u8f93\u5165\uff0c\u8fd9\u4f7f\u5f97\u5b83\u4eec\u96be\u4ee5\u8fdb\u884c\u5e76\u884c\u8ba1\u7b97\uff0c\u4ece\u800c\u9650\u5236\u4e86\u6a21\u578b\u7684\u8bad\u7ec3\u901f\u5ea6\u548c\u5904\u7406\u957f\u5e8f\u5217\u7684\u80fd\u529b\u3002\u957f\u8ddd\u79bb\u4f9d\u8d56\u95ee\u9898\uff1a\u5728\u5904\u7406\u957f\u5e8f\u5217\u65f6\uff0cRNN \u53ca\u5176\u53d8\u4f53\u5bb9\u6613\u51fa\u73b0\u68af\u5ea6\u6d88\u5931\u6216\u68af\u5ea6\u7206\u70b8\u7684\u95ee\u9898\uff0c\u5bfc\u81f4\u6a21\u578b\u96be\u4ee5\u6355\u6349\u5e8f\u5217\u4e2d\u7684\u957f\u8ddd\u79bb\u4f9d\u8d56\u5173\u7cfb\u3002_ai\u7684transformer\u67b6\u6784\u63d0\u51fa","og_url":"https:\/\/www.wsisp.com\/helps\/34017.html","og_site_name":"\u7f51\u7855\u4e92\u8054\u5e2e\u52a9\u4e2d\u5fc3","article_published_time":"2025-04-28T14:09:38+00:00","author":"admin","twitter_card":"summary_large_image","twitter_misc":{"\u4f5c\u8005":"admin","\u9884\u8ba1\u9605\u8bfb\u65f6\u95f4":"10 \u5206"},"schema":{"@context":"https:\/\/schema.org","@graph":[{"@type":"WebPage","@id":"https:\/\/www.wsisp.com\/helps\/34017.html","url":"https:\/\/www.wsisp.com\/helps\/34017.html","name":"AI \u5927\u6a21\u578b\u4e4b Transformer \u67b6\u6784\u6df1\u5165\u5256\u6790 - 
\u7f51\u7855\u4e92\u8054\u5e2e\u52a9\u4e2d\u5fc3","isPartOf":{"@id":"https:\/\/www.wsisp.com\/helps\/#website"},"datePublished":"2025-04-28T14:09:38+00:00","dateModified":"2025-04-28T14:09:38+00:00","author":{"@id":"https:\/\/www.wsisp.com\/helps\/#\/schema\/person\/358e386c577a3ab51c4493330a20ad41"},"breadcrumb":{"@id":"https:\/\/www.wsisp.com\/helps\/34017.html#breadcrumb"},"inLanguage":"zh-Hans","potentialAction":[{"@type":"ReadAction","target":["https:\/\/www.wsisp.com\/helps\/34017.html"]}]},{"@type":"BreadcrumbList","@id":"https:\/\/www.wsisp.com\/helps\/34017.html#breadcrumb","itemListElement":[{"@type":"ListItem","position":1,"name":"\u9996\u9875","item":"https:\/\/www.wsisp.com\/helps"},{"@type":"ListItem","position":2,"name":"AI \u5927\u6a21\u578b\u4e4b Transformer \u67b6\u6784\u6df1\u5165\u5256\u6790"}]},{"@type":"WebSite","@id":"https:\/\/www.wsisp.com\/helps\/#website","url":"https:\/\/www.wsisp.com\/helps\/","name":"\u7f51\u7855\u4e92\u8054\u5e2e\u52a9\u4e2d\u5fc3","description":"\u9999\u6e2f\u670d\u52a1\u5668_\u9999\u6e2f\u4e91\u670d\u52a1\u5668\u8d44\u8baf_\u670d\u52a1\u5668\u5e2e\u52a9\u6587\u6863_\u670d\u52a1\u5668\u6559\u7a0b","potentialAction":[{"@type":"SearchAction","target":{"@type":"EntryPoint","urlTemplate":"https:\/\/www.wsisp.com\/helps\/?s={search_term_string}"},"query-input":"required 
name=search_term_string"}],"inLanguage":"zh-Hans"},{"@type":"Person","@id":"https:\/\/www.wsisp.com\/helps\/#\/schema\/person\/358e386c577a3ab51c4493330a20ad41","name":"admin","image":{"@type":"ImageObject","inLanguage":"zh-Hans","@id":"https:\/\/www.wsisp.com\/helps\/#\/schema\/person\/image\/","url":"https:\/\/gravatar.wp-china-yes.net\/avatar\/?s=96&d=mystery","contentUrl":"https:\/\/gravatar.wp-china-yes.net\/avatar\/?s=96&d=mystery","caption":"admin"},"sameAs":["http:\/\/wp.wsisp.com"],"url":"https:\/\/www.wsisp.com\/helps\/author\/admin"}]}},"_links":{"self":[{"href":"https:\/\/www.wsisp.com\/helps\/wp-json\/wp\/v2\/posts\/34017","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/www.wsisp.com\/helps\/wp-json\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/www.wsisp.com\/helps\/wp-json\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/www.wsisp.com\/helps\/wp-json\/wp\/v2\/users\/2"}],"replies":[{"embeddable":true,"href":"https:\/\/www.wsisp.com\/helps\/wp-json\/wp\/v2\/comments?post=34017"}],"version-history":[{"count":0,"href":"https:\/\/www.wsisp.com\/helps\/wp-json\/wp\/v2\/posts\/34017\/revisions"}],"wp:attachment":[{"href":"https:\/\/www.wsisp.com\/helps\/wp-json\/wp\/v2\/media?parent=34017"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/www.wsisp.com\/helps\/wp-json\/wp\/v2\/categories?post=34017"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/www.wsisp.com\/helps\/wp-json\/wp\/v2\/tags?post=34017"},{"taxonomy":"topic","embeddable":true,"href":"https:\/\/www.wsisp.com\/helps\/wp-json\/wp\/v2\/topic?post=34017"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}