{"id":33808,"date":"2025-04-28T15:49:45","date_gmt":"2025-04-28T07:49:45","guid":{"rendered":"https:\/\/www.wsisp.com\/helps\/33808.html"},"modified":"2025-04-28T15:49:45","modified_gmt":"2025-04-28T07:49:45","slug":"%e6%ba%90%e7%a0%81%e8%a7%a3%e6%9e%90%ef%bc%9a%e4%bb%8e%e9%9b%b6%e8%a7%a3%e8%af%bbsamsegment-anything-model%e5%a4%a7%e6%a8%a1%e5%9e%8b%ef%bc%81","status":"publish","type":"post","link":"https:\/\/www.wsisp.com\/helps\/33808.html","title":{"rendered":"Source Code Analysis: Understanding SAM (Segment Anything Model) from Scratch"},"content":{"rendered":"<p>SAM (Segment Anything Model), as the name suggests, is built to segment anything. Developed by the Meta AI lab at Facebook, it identifies and segments arbitrary objects in an image, guided by text instructions or visual cues. Its release is an important milestone for the computer vision (CV) field.<\/p>\n<p><img decoding=\"async\" src=\"https:\/\/www.wsisp.com\/helps\/wp-content\/uploads\/2025\/04\/20250428074938-680f3312d5fa7.png\" alt=\"\u56fe\u7247\" width=\"700\" \/><\/p>\n<p>Paper: https:\/\/arxiv.org\/abs\/2304.02643 Project: https:\/\/github.com\/facebookresearch\/segment-anything<\/p>\n<h3>SAM Task<\/h3>\n<p>SAM borrows the prompting strategy from NLP: supplying a prompt to the image segmentation task lets the model quickly segment any target. A prompt can be <strong>a set of foreground\/background points, a rough box or mask, free-form text, or any other information that indicates what to segment in the image<\/strong>. As shown in panel (a) of the figure below, the model takes the original image and some prompts as input, and the goal is to output a &#034;valid&#034; segmentation. Valid means that even when the prompt is ambiguous and could refer to several objects, the model still outputs a mask for at least one of them.<\/p>\n<p>This lets SAM adapt to a variety of downstream tasks. For example, given the bounding box of a cat, SAM can output its mask, so it can be paired with instance segmentation tasks.<\/p>\n<p><img decoding=\"async\" 
src=\"https:\/\/www.wsisp.com\/helps\/wp-content\/uploads\/2025\/04\/20250428074939-680f331372ab4.png\" alt=\"\u56fe\u7247\" \/><\/p>\n<h3>SAM Model<\/h3>\n<p>As shown in the figure below, SAM consists of three core components: the Image Encoder, the Prompt Encoder, and the Mask Decoder. The image is encoded by the Image Encoder and the prompts by the Prompt Encoder; the two embeddings are then passed through a lightweight Mask Decoder, which produces the fused features. The encoders reuse existing models, while the decoder is built on a Transformer.<\/p>\n<p><img decoding=\"async\" src=\"https:\/\/www.wsisp.com\/helps\/wp-content\/uploads\/2025\/04\/20250428074939-680f3313b7b1d.png\" alt=\"\u56fe\u7247\" width=\"700\" \/><\/p>\n<h4>Image Encoder<\/h4>\n<p>The Image Encoder maps the image into feature space. The overall process is shown in the figure below.<\/p>\n<p><img decoding=\"async\" src=\"https:\/\/www.wsisp.com\/helps\/wp-content\/uploads\/2025\/04\/20250428074940-680f331407e89.png\" alt=\"\u56fe\u7247\" width=\"700\" \/><\/p>\n<p>As the paper notes, this encoder can in principle be any network architecture. Here it is a fine-tuned ViT taken from Detectron, and it could equally be swapped for a conventional convolutional backbone.<\/p>\n<p>The input image passes through the ViT as follows:<\/p>\n<h5>1. 
Patch Embedding<\/h5>\n<p>The input image goes through a single convolution that splits it into 16&#215;16 patches with a stride of 16, so each spatial dimension of the feature map shrinks by a factor of 16, while the channel count is mapped from 3 to 768. The Patch Embedding step is illustrated below.<\/p>\n<p><img decoding=\"async\" src=\"https:\/\/www.wsisp.com\/helps\/wp-content\/uploads\/2025\/04\/20250428074940-680f33148929b.png\" alt=\"\u56fe\u7247\" width=\"700\" \/><\/p>\n<p>Code implementation:<\/p>\n<p><span class=\"token triple-quoted-string string\">&#039;&#039;&#039;<br \/>\nConvert the input image into a sequence of patch feature vectors<br \/>\n&#039;&#039;&#039;<\/span><br \/>\n<span class=\"token keyword\">class<\/span> <span class=\"token class-name\">PatchEmbed<\/span><span class=\"token punctuation\">(<\/span>nn<span class=\"token punctuation\">.<\/span>Module<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n    <span class=\"token keyword\">def<\/span> <span class=\"token function\">__init__<\/span><span class=\"token punctuation\">(<\/span><br \/>\n        self<span class=\"token punctuation\">,<\/span><br \/>\n        <span class=\"token comment\"># Convolution kernel size<\/span><br \/>\n        <span class=\"token comment\"># (16, 16) here means the image is split into 16x16 patches<\/span><br \/>\n        kernel_size<span class=\"token punctuation\">:<\/span> Tuple<span class=\"token punctuation\">[<\/span><span class=\"token builtin\">int<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token builtin\">int<\/span><span class=\"token punctuation\">]<\/span> <span class=\"token operator\">&#061;<\/span> <span class=\"token punctuation\">(<\/span><span class=\"token 
number\">16<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">16<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span><br \/>\n        <span class=\"token comment\"># Convolution stride, equal to kernel_size, i.e. (16, 16):<\/span><br \/>\n        <span class=\"token comment\"># each step moves 16 pixels, so each spatial dimension shrinks to 1\/16 of the original<\/span><br \/>\n        stride<span class=\"token punctuation\">:<\/span> Tuple<span class=\"token punctuation\">[<\/span><span class=\"token builtin\">int<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token builtin\">int<\/span><span class=\"token punctuation\">]<\/span> <span class=\"token operator\">&#061;<\/span> <span class=\"token punctuation\">(<\/span><span class=\"token number\">16<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">16<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span><br \/>\n        <span class=\"token comment\"># Edge padding; (0, 0) here means no extra padding<\/span><br \/>\n        padding<span class=\"token punctuation\">:<\/span> Tuple<span class=\"token punctuation\">[<\/span><span class=\"token builtin\">int<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token builtin\">int<\/span><span class=\"token punctuation\">]<\/span> <span class=\"token operator\">&#061;<\/span> <span class=\"token punctuation\">(<\/span><span class=\"token number\">0<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">0<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span><br \/>\n        <span class=\"token comment\"># 
Number of input image channels, usually 3 (RGB image)<\/span><br \/>\n        in_chans<span class=\"token punctuation\">:<\/span> <span class=\"token builtin\">int<\/span> <span class=\"token operator\">&#061;<\/span> <span class=\"token number\">3<\/span><span class=\"token punctuation\">,<\/span><br \/>\n        <span class=\"token comment\"># Output feature dimension, i.e. the length of the vector each patch is encoded into; 768 here<\/span><br \/>\n        embed_dim<span class=\"token punctuation\">:<\/span> <span class=\"token builtin\">int<\/span> <span class=\"token operator\">&#061;<\/span> <span class=\"token number\">768<\/span><span class=\"token punctuation\">,<\/span><br \/>\n    <span class=\"token punctuation\">)<\/span> <span class=\"token operator\">-<\/span><span class=\"token operator\">&gt;<\/span> <span class=\"token boolean\">None<\/span><span class=\"token punctuation\">:<\/span><br \/>\n        <span class=\"token triple-quoted-string string\">&#039;&#039;&#039;<br \/>\n        Initialize the attributes of this module instance<br \/>\n        &#039;&#039;&#039;<\/span><br \/>\n        <span class=\"token comment\"># PatchEmbed subclasses nn.Module, the base class for neural network modules; initialize the parent first<\/span><br \/>\n        <span class=\"token builtin\">super<\/span><span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>__init__<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>proj <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>Conv2d<span class=\"token punctuation\">(<\/span><br \/>\n            in_chans<span 
class=\"token punctuation\">,<\/span> embed_dim<span class=\"token punctuation\">,<\/span> kernel_size<span class=\"token operator\">&#061;<\/span>kernel_size<span class=\"token punctuation\">,<\/span> stride<span class=\"token operator\">&#061;<\/span>stride<span class=\"token punctuation\">,<\/span> padding<span class=\"token operator\">&#061;<\/span>padding<br \/>\n        <span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token triple-quoted-string string\">&#039;&#039;&#039;Forward pass:<br \/>\n       Takes an input tensor x of shape (B, C, H, W), where<br \/>\n       - B is the batch size<br \/>\n       - C is the number of input channels<br \/>\n       - H and W are the image height and width<br \/>\n    &#039;&#039;&#039;<\/span><br \/>\n    <span class=\"token keyword\">def<\/span> <span class=\"token function\">forward<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> x<span class=\"token punctuation\">:<\/span> torch<span class=\"token punctuation\">.<\/span>Tensor<span class=\"token punctuation\">)<\/span> <span class=\"token operator\">-<\/span><span class=\"token operator\">&gt;<\/span> torch<span class=\"token punctuation\">.<\/span>Tensor<span class=\"token punctuation\">:<\/span><br \/>\n        <span class=\"token comment\"># Convolution: maps the channel count from in_chans to embed_dim<\/span><br \/>\n        x <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>proj<span class=\"token punctuation\">(<\/span>x<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># Reorder the tensor dimensions from (B, C, H, W) to (B, H, W, C)<\/span><br \/>\n        x <span class=\"token 
operator\">&#061;<\/span> x<span class=\"token punctuation\">.<\/span>permute<span class=\"token punctuation\">(<\/span><span class=\"token number\">0<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">2<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">3<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token keyword\">return<\/span> x<\/p>\n<p>The figure below shows where Patch Embedding sits in the Vision Transformer architecture diagram.<\/p>\n<p><img decoding=\"async\" src=\"https:\/\/www.wsisp.com\/helps\/wp-content\/uploads\/2025\/04\/20250428074940-680f3314e02df.png\" alt=\"\u56fe\u7247\" \/><\/p>\n<h5>2. Positional Embedding<\/h5>\n<p>After Patch Embedding, a positional encoding is added to the output tokens to preserve the spatial information of the image. The positional encoding can be viewed as a map with one row per input token; each row is a vector with the same dimension as the tokens. Because the encoding is combined with the tokens by element-wise addition, the tensor shape stays unchanged.<\/p>\n<p><img decoding=\"async\" src=\"https:\/\/www.wsisp.com\/helps\/wp-content\/uploads\/2025\/04\/20250428074941-680f331522a6b.png\" alt=\"\u56fe\u7247\" \/><\/p>\n<p>The image size is 1024, so there are 1024\/16 &#061; 64 patches along each side, i.e. a 64&#215;64 grid of patches.<\/p>\n<p>Code implementation:<\/p>\n<p><span class=\"token comment\"># Defined in ImageEncoderViT.__init__<\/span><br \/>\n<span class=\"token keyword\">if<\/span> use_abs_pos<span class=\"token punctuation\">:<\/span><br \/>\n    
<span class=\"token comment\"># Initialize the absolute positional embedding using the pretraining image size<\/span><br \/>\n    self<span class=\"token punctuation\">.<\/span>pos_embed <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>Parameter<span class=\"token punctuation\">(<\/span><br \/>\n        torch<span class=\"token punctuation\">.<\/span>zeros<span class=\"token punctuation\">(<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> img_size <span class=\"token operator\">\/\/<\/span> patch_size<span class=\"token punctuation\">,<\/span> img_size <span class=\"token operator\">\/\/<\/span> patch_size<span class=\"token punctuation\">,<\/span> embed_dim<span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token punctuation\">)<\/span><br \/>\n<span class=\"token comment\"># Added to the tokens in ImageEncoderViT.forward<\/span><br \/>\n<span class=\"token keyword\">if<\/span> self<span class=\"token punctuation\">.<\/span>pos_embed <span class=\"token keyword\">is<\/span> <span class=\"token keyword\">not<\/span> <span class=\"token boolean\">None<\/span><span class=\"token punctuation\">:<\/span><br \/>\n    x <span class=\"token operator\">&#061;<\/span> x <span class=\"token operator\">&#043;<\/span> self<span class=\"token punctuation\">.<\/span>pos_embed<\/p>\n<p>The part of the architecture diagram that corresponds to Positional Embedding:<\/p>\n<p><img decoding=\"async\" src=\"https:\/\/www.wsisp.com\/helps\/wp-content\/uploads\/2025\/04\/20250428074941-680f331554f70.png\" alt=\"\u56fe\u7247\" \/><\/p>\n<h5>3. 
Transformer Encoder<\/h5>\n<p>The feature map passes through a stack of Transformer Blocks whose length is set by depth. In the ViT-B configuration of the official repository, which matches the 768-dimensional embedding used here, depth is 12: 8 Blocks use window-partitioned attention (the feature map is split into 14&#215;14 windows and attention is computed locally within each window) to process local information. The other 4 Blocks, at the indexes listed in global_attn_indexes (2, 5, 8 and 11), use global attention; they are interleaved between the windowed Blocks to capture the global context of the image.<\/p>\n<p><span class=\"token comment\"># Defined in ImageEncoderViT.__init__<\/span><br \/>\n<span class=\"token comment\"># -----Transformer Encoder-----<\/span><br \/>\n<span class=\"token comment\"># Initialize a ModuleList to hold the Block instances<\/span><br \/>\nself<span class=\"token punctuation\">.<\/span>blocks <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>ModuleList<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><br \/>\n<span class=\"token comment\"># Create the Blocks in a loop; depth is the number of Transformer Encoder layers<\/span><br \/>\n<span class=\"token keyword\">for<\/span> i <span class=\"token keyword\">in<\/span> <span class=\"token builtin\">range<\/span><span class=\"token punctuation\">(<\/span>depth<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n    <span class=\"token comment\"># Create a single Block<\/span><br \/>\n    block <span class=\"token operator\">&#061;<\/span> Block<span class=\"token punctuation\">(<\/span><br \/>\n        <span class=\"token comment\"># Number of input channels, i.e. the dimension of each encoded patch vector<\/span><br \/>\n        dim<span class=\"token 
operator\">&#061;<\/span>embed_dim<span class=\"token punctuation\">,<\/span><br \/>\n        <span class=\"token comment\"># Number of attention heads in the self-attention layer<\/span><br \/>\n        num_heads<span class=\"token operator\">&#061;<\/span>num_heads<span class=\"token punctuation\">,<\/span><br \/>\n        <span class=\"token comment\"># Ratio of the MLP hidden channels to the input channels<\/span><br \/>\n        mlp_ratio<span class=\"token operator\">&#061;<\/span>mlp_ratio<span class=\"token punctuation\">,<\/span><br \/>\n        <span class=\"token comment\"># Whether the QKV linear layer uses a bias<\/span><br \/>\n        qkv_bias<span class=\"token operator\">&#061;<\/span>qkv_bias<span class=\"token punctuation\">,<\/span><br \/>\n        <span class=\"token comment\"># Normalization layer<\/span><br \/>\n        norm_layer<span class=\"token operator\">&#061;<\/span>norm_layer<span class=\"token punctuation\">,<\/span><br \/>\n        <span class=\"token comment\"># Activation function<\/span><br \/>\n        act_layer<span class=\"token operator\">&#061;<\/span>act_layer<span class=\"token punctuation\">,<\/span><br \/>\n        <span class=\"token comment\"># Whether to use relative positional encoding<\/span><br \/>\n        use_rel_pos<span class=\"token operator\">&#061;<\/span>use_rel_pos<span class=\"token punctuation\">,<\/span><br \/>\n        <span class=\"token comment\"># Whether the relative positional parameters are zero-initialized<\/span><br \/>\n        rel_pos_zero_init<span class=\"token operator\">&#061;<\/span>rel_pos_zero_init<span class=\"token punctuation\">,<\/span><br \/>\n        <span class=\"token comment\"># 
Use the window size if this Block is not a global-attention layer, otherwise 0<\/span><br \/>\n        window_size<span class=\"token operator\">&#061;<\/span>window_size <span class=\"token keyword\">if<\/span> i <span class=\"token keyword\">not<\/span> <span class=\"token keyword\">in<\/span> global_attn_indexes <span class=\"token keyword\">else<\/span> <span class=\"token number\">0<\/span><span class=\"token punctuation\">,<\/span><br \/>\n        <span class=\"token comment\"># Input feature size, computed from the original image size and the patch size<\/span><br \/>\n        input_size<span class=\"token operator\">&#061;<\/span><span class=\"token punctuation\">(<\/span>img_size <span class=\"token operator\">\/\/<\/span> patch_size<span class=\"token punctuation\">,<\/span> img_size <span class=\"token operator\">\/\/<\/span> patch_size<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span><br \/>\n    <span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token comment\"># Append the new Block to the self.blocks list<\/span><br \/>\n    self<span class=\"token punctuation\">.<\/span>blocks<span class=\"token punctuation\">.<\/span>append<span class=\"token punctuation\">(<\/span>block<span class=\"token punctuation\">)<\/span><br \/>\n<span class=\"token comment\"># -----Transformer Encoder-----<\/span><\/p>\n<p>The part of the architecture diagram that corresponds to the Transformer Encoder:<\/p>\n<p><img decoding=\"async\" src=\"https:\/\/www.wsisp.com\/helps\/wp-content\/uploads\/2025\/04\/20250428074941-680f33159bb66.png\" alt=\"\u56fe\u7247\" \/><\/p>\n<p>Encoder Block<\/p>\n<p>As shown on the right of the figure above, the Encoder Block consists, from bottom to top, mainly of LayerNorm, Multi-Head Attention, and an MLP.<\/p>\n<p><span class=\"token keyword\">class<\/span> <span class=\"token class-name\">Block<\/span><span class=\"token punctuation\">(<\/span>nn<span class=\"token punctuation\">.<\/span>Module<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n    <span class=\"token keyword\">def<\/span> <span class=\"token function\">__init__<\/span><span class=\"token punctuation\">(<\/span><br \/>\n        self<span class=\"token punctuation\">,<\/span><br \/>\n        dim<span class=\"token punctuation\">:<\/span> <span class=\"token builtin\">int<\/span><span class=\"token punctuation\">,<\/span>                           <span class=\"token comment\"># Number of input channels<\/span><br \/>\n        num_heads<span class=\"token punctuation\">:<\/span> <span class=\"token builtin\">int<\/span><span class=\"token punctuation\">,<\/span>                     <span class=\"token comment\"># Number of attention heads<\/span><br \/>\n        mlp_ratio<span class=\"token punctuation\">:<\/span> <span class=\"token builtin\">float<\/span> <span class=\"token operator\">&#061;<\/span> <span class=\"token number\">4.0<\/span><span class=\"token punctuation\">,<\/span>             <span class=\"token comment\"># Ratio of the MLP hidden channels to the input channels<\/span><br \/>\n        qkv_bias<span class=\"token punctuation\">:<\/span> <span class=\"token builtin\">bool<\/span> <span class=\"token operator\">&#061;<\/span> <span class=\"token boolean\">True<\/span><span class=\"token punctuation\">,<\/span>              <span class=\"token comment\"># If True, the QKV linear layer includes a bias<\/span><br \/>\n        norm_layer<span class=\"token punctuation\">:<\/span> Type<span class=\"token 
punctuation\">[<\/span>nn<span class=\"token punctuation\">.<\/span>Module<span class=\"token punctuation\">]<\/span> <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>LayerNorm<span class=\"token punctuation\">,<\/span>     <span class=\"token comment\"># Normalization layer<\/span><br \/>\n        act_layer<span class=\"token punctuation\">:<\/span> Type<span class=\"token punctuation\">[<\/span>nn<span class=\"token punctuation\">.<\/span>Module<span class=\"token punctuation\">]<\/span> <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>GELU<span class=\"token punctuation\">,<\/span>           <span class=\"token comment\"># Activation layer<\/span><br \/>\n        use_rel_pos<span class=\"token punctuation\">:<\/span> <span class=\"token builtin\">bool<\/span> <span class=\"token operator\">&#061;<\/span> <span class=\"token boolean\">False<\/span><span class=\"token punctuation\">,<\/span>                      <span class=\"token comment\"># Whether to use relative positional encoding<\/span><br \/>\n        rel_pos_zero_init<span class=\"token punctuation\">:<\/span> <span class=\"token builtin\">bool<\/span> <span class=\"token operator\">&#061;<\/span> <span class=\"token boolean\">True<\/span><span class=\"token punctuation\">,<\/span>                 <span class=\"token comment\"># Whether the relative positional parameters are zero-initialized<\/span><br \/>\n        window_size<span class=\"token punctuation\">:<\/span> <span class=\"token builtin\">int<\/span> <span class=\"token operator\">&#061;<\/span> <span class=\"token number\">0<\/span><span class=\"token punctuation\">,<\/span>                           <span class=\"token comment\"># Window size of the attention layer (0 means global attention)<\/span><br \/>\n        input_size<span class=\"token punctuation\">:<\/span> Optional<span class=\"token punctuation\">[<\/span>Tuple<span 
class=\"token punctuation\">[<\/span><span class=\"token builtin\">int<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token builtin\">int<\/span><span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">]<\/span> <span class=\"token operator\">&#061;<\/span> <span class=\"token boolean\">None<\/span><span class=\"token punctuation\">,<\/span>   <span class=\"token comment\"># Size of the input features<\/span><br \/>\n    <span class=\"token punctuation\">)<\/span> <span class=\"token operator\">-<\/span><span class=\"token operator\">&gt;<\/span> <span class=\"token boolean\">None<\/span><span class=\"token punctuation\">:<\/span><br \/>\n        <span class=\"token builtin\">super<\/span><span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>__init__<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>norm1 <span class=\"token operator\">&#061;<\/span> norm_layer<span class=\"token punctuation\">(<\/span>dim<span class=\"token punctuation\">)<\/span>         <span class=\"token comment\"># First normalization layer, applied before attention<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>attn <span class=\"token operator\">&#061;<\/span> Attention<span class=\"token punctuation\">(<\/span>               <span class=\"token comment\"># Multi-Head Attention<\/span><br \/>\n            dim<span class=\"token punctuation\">,<\/span><br \/>\n            num_heads<span class=\"token operator\">&#061;<\/span>num_heads<span class=\"token punctuation\">,<\/span><br \/>\n            qkv_bias<span class=\"token operator\">&#061;<\/span>qkv_bias<span class=\"token punctuation\">,<\/span><br \/>\n            use_rel_pos<span class=\"token operator\">&#061;<\/span>use_rel_pos<span class=\"token 
punctuation\">,<\/span><br \/>\n            rel_pos_zero_init<span class=\"token operator\">&#061;<\/span>rel_pos_zero_init<span class=\"token punctuation\">,<\/span><br \/>\n            input_size<span class=\"token operator\">&#061;<\/span>input_size <span class=\"token keyword\">if<\/span> window_size <span class=\"token operator\">&#061;&#061;<\/span> <span class=\"token number\">0<\/span> <span class=\"token keyword\">else<\/span> <span class=\"token punctuation\">(<\/span>window_size<span class=\"token punctuation\">,<\/span> window_size<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span><br \/>\n        <span class=\"token punctuation\">)<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>norm2 <span class=\"token operator\">&#061;<\/span> norm_layer<span class=\"token punctuation\">(<\/span>dim<span class=\"token punctuation\">)<\/span>      <span class=\"token comment\"># Second normalization layer, applied before the MLP<\/span><br \/>\n        <span class=\"token comment\"># MLP<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>mlp <span class=\"token operator\">&#061;<\/span> MLPBlock<span class=\"token punctuation\">(<\/span>embedding_dim<span class=\"token operator\">&#061;<\/span>dim<span class=\"token punctuation\">,<\/span> mlp_dim<span class=\"token operator\">&#061;<\/span><span class=\"token builtin\">int<\/span><span class=\"token punctuation\">(<\/span>dim <span class=\"token operator\">*<\/span> mlp_ratio<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span> act<span class=\"token operator\">&#061;<\/span>act_layer<span class=\"token punctuation\">)<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>window_size <span class=\"token operator\">&#061;<\/span> window_size<br \/>\n    <span class=\"token comment\"># Forward pass<\/span><br \/>\n    <span class=\"token 
keyword\">def<\/span> <span class=\"token function\">forward<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> x<span class=\"token punctuation\">:<\/span> torch<span class=\"token punctuation\">.<\/span>Tensor<span class=\"token punctuation\">)<\/span> <span class=\"token operator\">-<\/span><span class=\"token operator\">&gt;<\/span> torch<span class=\"token punctuation\">.<\/span>Tensor<span class=\"token punctuation\">:<\/span><br \/>\n        <span class=\"token comment\"># Keep a reference to the input for the residual connection<\/span><br \/>\n        shortcut <span class=\"token operator\">&#061;<\/span> x<br \/>\n        <span class=\"token comment\"># Apply the first normalization layer to the input<\/span><br \/>\n        x <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>norm1<span class=\"token punctuation\">(<\/span>x<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># Window partition: pad x and split it into windows<\/span><br \/>\n        <span class=\"token keyword\">if<\/span> self<span class=\"token punctuation\">.<\/span>window_size <span class=\"token operator\">&gt;<\/span> <span class=\"token number\">0<\/span><span class=\"token punctuation\">:<\/span><br \/>\n            H<span class=\"token punctuation\">,<\/span> W <span class=\"token operator\">&#061;<\/span> x<span class=\"token punctuation\">.<\/span>shape<span class=\"token punctuation\">[<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">,<\/span> x<span class=\"token punctuation\">.<\/span>shape<span class=\"token punctuation\">[<\/span><span class=\"token number\">2<\/span><span class=\"token punctuation\">]<\/span><br \/>\n            x<span class=\"token punctuation\">,<\/span> pad_hw <span class=\"token operator\">&#061;<\/span> 
window_partition<span class=\"token punctuation\">(<\/span>x<span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">.<\/span>window_size<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># Multi-Head Attention<\/span><br \/>\n        x <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>attn<span class=\"token punctuation\">(<\/span>x<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># If window_size &gt; 0, use window_unpartition to merge the windows, strip the padding and restore the original size<\/span><br \/>\n        <span class=\"token keyword\">if<\/span> self<span class=\"token punctuation\">.<\/span>window_size <span class=\"token operator\">&gt;<\/span> <span class=\"token number\">0<\/span><span class=\"token punctuation\">:<\/span><br \/>\n            x <span class=\"token operator\">&#061;<\/span> window_unpartition<span class=\"token punctuation\">(<\/span>x<span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">.<\/span>window_size<span class=\"token punctuation\">,<\/span> pad_hw<span class=\"token punctuation\">,<\/span> <span class=\"token punctuation\">(<\/span>H<span class=\"token punctuation\">,<\/span> W<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># Add the attention output to the input tensor (residual connection)<\/span><br \/>\n        x <span class=\"token operator\">&#061;<\/span> shortcut <span class=\"token operator\">&#043;<\/span> x<br \/>\n        <span class=\"token comment\"># Apply the second normalization layer and the MLP, then add a second residual connection<\/span><br \/>\n        x <span 
class=\"token operator\">&#061;<\/span> x <span class=\"token operator\">&#043;<\/span> self<span class=\"token punctuation\">.<\/span>mlp<span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">.<\/span>norm2<span class=\"token punctuation\">(<\/span>x<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># Return the final tensor x<\/span><br \/>\n        <span class=\"token keyword\">return<\/span> x<\/p>\n<p>Partition operation<\/p>\n<p>In blocks that do not use global attention, the input feature map has to be padded and split so that it fits the 14&#215;14 window size. The procedure is as follows:<\/p>\n<li>\n<p>Input feature map: the input feature map initially has size 1x64x64x768.<\/p>\n<\/li>\n<li>\n<p>Find the smallest divisible size: the window is 14&#215;14, so we need the smallest feature-map size that is divisible by 14. For both the height and the width, this is the smallest number that is at least 64 and divisible by 14, namely 70 (64&#043;6), so the padded feature map has size 1x70x70x768.<\/p>\n<\/li>\n<li>\n<p>Padding: to grow the feature map from 64&#215;64 to 70&#215;70, 6 rows are added at the bottom and 6 columns on the right, 
because 70-64&#061;6. Padding this way lets the windows tile the feature map correctly all the way to its edges.<\/p>\n<\/li>\n<li>\n<p>Split the feature map: the padded 1x70x70x768 feature map is split into 14&#215;14 windows. Since 70\/14&#061;5, it is split into a 5&#215;5 grid of 14&#215;14 windows, 25 windows in total (5&#215;5&#061;25). Each window has size 14x14x768.<\/p>\n<\/li>\n<p>This is shown in the figure below.<\/p>\n<p><img decoding=\"async\" src=\"https:\/\/www.wsisp.com\/helps\/wp-content\/uploads\/2025\/04\/20250428074941-680f3315dfb7c.png\" alt=\"\u56fe\u7247\" \/><\/p>\n<p><span class=\"token comment\"># Split the input tensor x into windows of the given size<\/span><br \/>\n<span class=\"token keyword\">def<\/span> <span class=\"token function\">window_partition<\/span><span class=\"token punctuation\">(<\/span>x<span class=\"token punctuation\">:<\/span> torch<span class=\"token punctuation\">.<\/span>Tensor<span class=\"token punctuation\">,<\/span> window_size<span class=\"token punctuation\">:<\/span> <span class=\"token builtin\">int<\/span><span class=\"token punctuation\">)<\/span> <span class=\"token operator\">-<\/span><span class=\"token operator\">&gt;<\/span> Tuple<span class=\"token punctuation\">[<\/span>torch<span class=\"token punctuation\">.<\/span>Tensor<span class=\"token punctuation\">,<\/span> Tuple<span class=\"token punctuation\">[<\/span><span class=\"token builtin\">int<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token builtin\">int<\/span><span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">:<\/span><br 
\/>\n    <span class=\"token comment\"># Get the shape of the input tensor<\/span><br \/>\n    <span class=\"token comment\"># B is the batch size, H and W are the height and width, C is the number of channels<\/span><br \/>\n    B<span class=\"token punctuation\">,<\/span> H<span class=\"token punctuation\">,<\/span> W<span class=\"token punctuation\">,<\/span> C <span class=\"token operator\">&#061;<\/span> x<span class=\"token punctuation\">.<\/span>shape<br \/>\n    <span class=\"token comment\"># Compute the padding pad_h and pad_w so that the input size is divisible by window_size<\/span><br \/>\n    <span class=\"token comment\"># This avoids incomplete windows when splitting<\/span><br \/>\n    pad_h <span class=\"token operator\">&#061;<\/span> <span class=\"token punctuation\">(<\/span>window_size <span class=\"token operator\">-<\/span> H <span class=\"token operator\">%<\/span> window_size<span class=\"token punctuation\">)<\/span> <span class=\"token operator\">%<\/span> window_size<br \/>\n    pad_w <span class=\"token operator\">&#061;<\/span> <span class=\"token punctuation\">(<\/span>window_size <span class=\"token operator\">-<\/span> W <span class=\"token operator\">%<\/span> window_size<span class=\"token punctuation\">)<\/span> <span class=\"token operator\">%<\/span> window_size<br \/>\n    <span class=\"token comment\"># If padding is needed, use F.pad to pad along the width and height<\/span><br \/>\n    <span class=\"token keyword\">if<\/span> pad_h <span class=\"token operator\">&gt;<\/span> <span class=\"token number\">0<\/span> <span class=\"token keyword\">or<\/span> pad_w <span class=\"token operator\">&gt;<\/span> <span class=\"token number\">0<\/span><span class=\"token 
punctuation\">:<\/span><br \/>\n        x <span class=\"token operator\">&#061;<\/span> F<span class=\"token punctuation\">.<\/span>pad<span class=\"token punctuation\">(<\/span>x<span class=\"token punctuation\">,<\/span> <span class=\"token punctuation\">(<\/span><span class=\"token number\">0<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">0<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">0<\/span><span class=\"token punctuation\">,<\/span> pad_w<span class=\"token punctuation\">,<\/span> <span class=\"token number\">0<\/span><span class=\"token punctuation\">,<\/span> pad_h<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token comment\"># Height and width after padding: Hp and Wp<\/span><br \/>\n    Hp<span class=\"token punctuation\">,<\/span> Wp <span class=\"token operator\">&#061;<\/span> H <span class=\"token operator\">&#043;<\/span> pad_h<span class=\"token punctuation\">,<\/span> W <span class=\"token operator\">&#043;<\/span> pad_w<br \/>\n    <span class=\"token comment\"># Reshape the tensor to B,Hp\/S,S,Wp\/S,S,C, which splits the input tensor into windows<\/span><br \/>\n    x <span class=\"token operator\">&#061;<\/span> x<span class=\"token punctuation\">.<\/span>view<span class=\"token punctuation\">(<\/span>B<span class=\"token punctuation\">,<\/span> Hp <span class=\"token operator\">\/\/<\/span> window_size<span class=\"token punctuation\">,<\/span> window_size<span class=\"token punctuation\">,<\/span> Wp <span class=\"token operator\">\/\/<\/span> window_size<span class=\"token punctuation\">,<\/span> window_size<span class=\"token punctuation\">,<\/span> C<span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token comment\"># 
Permute to B,Hp\/S,Wp\/S,S,S,C, then flatten to B*Hp*Wp\/(S*S),S,S,C<\/span><br \/>\n    <span class=\"token comment\"># so that each window occupies a contiguous part of the tensor<\/span><br \/>\n    windows <span class=\"token operator\">&#061;<\/span> x<span class=\"token punctuation\">.<\/span>permute<span class=\"token punctuation\">(<\/span><span class=\"token number\">0<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">3<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">2<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">4<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">5<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>contiguous<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>view<span class=\"token punctuation\">(<\/span><span class=\"token operator\">-<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> window_size<span class=\"token punctuation\">,<\/span> window_size<span class=\"token punctuation\">,<\/span> C<span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token comment\"># Return the tensor holding all windows, plus the padded size (Hp, Wp) of the input<\/span><br \/>\n    <span class=\"token keyword\">return<\/span> windows<span class=\"token punctuation\">,<\/span> <span class=\"token punctuation\">(<\/span>Hp<span class=\"token punctuation\">,<\/span> Wp<span class=\"token 
punctuation\">)<\/span><\/p>\n<p>Unpartition operation<\/p>\n<p>In blocks that do not use global attention, the windows output by the attention layer are merged back into a 1x70x70x768 feature map, which is then converted to a 1x64x64x768 feature map. This is done with the slice x &#061; x[:, :64, :64, :], which keeps the top-left 1x64x64x768 part of the 1x70x70x768 feature map.<\/p>\n<p><img decoding=\"async\" src=\"https:\/\/www.wsisp.com\/helps\/wp-content\/uploads\/2025\/04\/20250428074942-680f3316196ad.png\" alt=\"\u56fe\u7247\" \/><\/p>\n<p><span class=\"token comment\"># Merge the windows produced by window_partition back into a tensor of the original size<\/span><br \/>\n<span class=\"token keyword\">def<\/span> <span class=\"token function\">window_unpartition<\/span><span class=\"token punctuation\">(<\/span><br \/>\n    <span class=\"token comment\"># Arguments: the window tensor, the window size, the padded size (Hp, Wp) and the original size (H, W)<\/span><br \/>\n    windows<span class=\"token punctuation\">:<\/span> torch<span class=\"token punctuation\">.<\/span>Tensor<span class=\"token punctuation\">,<\/span> window_size<span class=\"token punctuation\">:<\/span> <span class=\"token builtin\">int<\/span><span class=\"token punctuation\">,<\/span> pad_hw<span class=\"token punctuation\">:<\/span> Tuple<span class=\"token punctuation\">[<\/span><span class=\"token builtin\">int<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token builtin\">int<\/span><span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">,<\/span> hw<span class=\"token punctuation\">:<\/span> Tuple<span class=\"token punctuation\">[<\/span><span class=\"token builtin\">int<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token 
builtin\">int<\/span><span class=\"token punctuation\">]<\/span><br \/>\n<span class=\"token punctuation\">)<\/span> <span class=\"token operator\">-<\/span><span class=\"token operator\">&gt;<\/span> torch<span class=\"token punctuation\">.<\/span>Tensor<span class=\"token punctuation\">:<\/span><br \/>\n    <span class=\"token comment\"># Padded height and width<\/span><br \/>\n    Hp<span class=\"token punctuation\">,<\/span> Wp <span class=\"token operator\">&#061;<\/span> pad_hw<br \/>\n    <span class=\"token comment\"># Original (unpadded) height and width<\/span><br \/>\n    H<span class=\"token punctuation\">,<\/span> W <span class=\"token operator\">&#061;<\/span> hw<br \/>\n    <span class=\"token comment\"># Recover the original batch size B from the total number of windows<\/span><br \/>\n    B <span class=\"token operator\">&#061;<\/span> windows<span class=\"token punctuation\">.<\/span>shape<span class=\"token punctuation\">[<\/span><span class=\"token number\">0<\/span><span class=\"token punctuation\">]<\/span> <span class=\"token operator\">\/\/<\/span> <span class=\"token punctuation\">(<\/span>Hp <span class=\"token operator\">*<\/span> Wp <span class=\"token operator\">\/\/<\/span> window_size <span class=\"token operator\">\/\/<\/span> window_size<span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token comment\"># Reshape the window tensor: B*Hp*Wp\/(S*S),S,S,C -&gt; B,Hp\/S,Wp\/S,S,S,C<\/span><br \/>\n    x <span class=\"token operator\">&#061;<\/span> windows<span class=\"token punctuation\">.<\/span>view<span class=\"token punctuation\">(<\/span>B<span class=\"token punctuation\">,<\/span> Hp <span class=\"token operator\">\/\/<\/span> window_size<span class=\"token punctuation\">,<\/span> Wp <span class=\"token operator\">\/\/<\/span> window_size<span 
class=\"token punctuation\">,<\/span> window_size<span class=\"token punctuation\">,<\/span> window_size<span class=\"token punctuation\">,<\/span> <span class=\"token operator\">-<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token comment\"># Permute and reshape again: B,Hp\/S,Wp\/S,S,S,C -&gt; B,Hp,Wp,C<\/span><br \/>\n    x <span class=\"token operator\">&#061;<\/span> x<span class=\"token punctuation\">.<\/span>permute<span class=\"token punctuation\">(<\/span><span class=\"token number\">0<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">3<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">2<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">4<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">5<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>contiguous<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>view<span class=\"token punctuation\">(<\/span>B<span class=\"token punctuation\">,<\/span> Hp<span class=\"token punctuation\">,<\/span> Wp<span class=\"token punctuation\">,<\/span> <span class=\"token operator\">-<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token comment\"># If the original size is smaller than the padded size<\/span><br \/>\n    <span class=\"token keyword\">if<\/span> Hp <span class=\"token operator\">&gt;<\/span> H <span class=\"token keyword\">or<\/span> Wp <span class=\"token operator\">&gt;<\/span> W<span class=\"token punctuation\">:<\/span><br \/>\n        <span class=\"token comment\"># Slice with 
x[:, :H, :W, :] to drop the padding and keep only the region of the original size<\/span><br \/>\n        x <span class=\"token operator\">&#061;<\/span> x<span class=\"token punctuation\">[<\/span><span class=\"token punctuation\">:<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token punctuation\">:<\/span>H<span class=\"token punctuation\">,<\/span> <span class=\"token punctuation\">:<\/span>W<span class=\"token punctuation\">,<\/span> <span class=\"token punctuation\">:<\/span><span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">.<\/span>contiguous<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token comment\"># B,H,W,C<\/span><br \/>\n    <span class=\"token comment\"># Return the merged tensor of shape (B,H,W,C): the original batch size, height, width and number of channels<\/span><br \/>\n    <span class=\"token keyword\">return<\/span> x<\/p>\n<p>The Encoder Block workflow is shown in the figure below:<\/p>\n<p><img decoding=\"async\" src=\"https:\/\/www.wsisp.com\/helps\/wp-content\/uploads\/2025\/04\/20250428074942-680f33164c20f.png\" alt=\"\u56fe\u7247\" width=\"700\" \/><\/p>\n<p>window_partition rearranges the input features from size (H, W) into windows of size (S, S), where S is the window size. This is done so that Multi-Head Attention can add the relative position embeddings to the attention map (attn). However, not every Transformer Block needs relative position information embedded in its attention map. The window_unpartition 
function merges the window features back to the original size after attention has been computed (S&#215;S -&gt; H&#215;W). Hp and Wp are integer multiples of S.<\/p>\n<p>Multi-Head Attention<\/p>\n<p>Let us look at Attention first. Its structure is shown in the figure below.<\/p>\n<p><img decoding=\"async\" src=\"https:\/\/www.wsisp.com\/helps\/wp-content\/uploads\/2025\/04\/20250428074942-680f33169ac2d.png\" alt=\"\u56fe\u7247\" width=\"200\" \/><\/p>\n<p>The roles of q, k and v in Attention:<\/p>\n<p><img decoding=\"async\" src=\"https:\/\/www.wsisp.com\/helps\/wp-content\/uploads\/2025\/04\/20250428074942-680f3316cb514.png\" alt=\"\u56fe\u7247\" width=\"700\" \/><\/p>\n<p>The code is as follows:<\/p>\n<p><span class=\"token keyword\">class<\/span> <span class=\"token class-name\">Attention<\/span><span class=\"token punctuation\">(<\/span>nn<span class=\"token punctuation\">.<\/span>Module<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n    <span class=\"token triple-quoted-string string\">&#034;&#034;&#034;Multi-head Attention block with relative position embeddings.&#034;&#034;&#034;<\/span><br \/>\n    <span class=\"token keyword\">def<\/span> <span class=\"token function\">__init__<\/span><span class=\"token punctuation\">(<\/span><br \/>\n        self<span class=\"token punctuation\">,<\/span><br \/>\n        dim<span class=\"token punctuation\">:<\/span> <span class=\"token builtin\">int<\/span><span class=\"token punctuation\">,<\/span>               <span class=\"token comment\"># Number of input channels<\/span><br \/>\n        num_heads<span class=\"token punctuation\">:<\/span> <span class=\"token builtin\">int<\/span> <span class=\"token operator\">&#061;<\/span> <span class=\"token number\">8<\/span><span class=\"token punctuation\">,<\/span>     <span class=\"token 
comment\"># Number of attention heads<\/span><br \/>\n        qkv_bias<span class=\"token punctuation\">:<\/span> <span class=\"token builtin\">bool<\/span> <span class=\"token operator\">&#061;<\/span> <span class=\"token boolean\">True<\/span><span class=\"token punctuation\">,<\/span>  <span class=\"token comment\"># Whether to use a bias term in the QKV linear projection; defaults to True<\/span><br \/>\n        use_rel_pos<span class=\"token punctuation\">:<\/span> <span class=\"token builtin\">bool<\/span> <span class=\"token operator\">&#061;<\/span> <span class=\"token boolean\">False<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token comment\"># Whether to use relative position encoding; defaults to False<\/span><br \/>\n        rel_pos_zero_init<span class=\"token punctuation\">:<\/span> <span class=\"token builtin\">bool<\/span> <span class=\"token operator\">&#061;<\/span> <span class=\"token boolean\">True<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token comment\"># If relative position encoding is used, whether to zero-initialize it; defaults to True<\/span><br \/>\n        input_size<span class=\"token punctuation\">:<\/span> Optional<span class=\"token punctuation\">[<\/span>Tuple<span class=\"token punctuation\">[<\/span><span class=\"token builtin\">int<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token builtin\">int<\/span><span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">]<\/span> <span class=\"token operator\">&#061;<\/span> <span class=\"token boolean\">None<\/span><span class=\"token punctuation\">,<\/span>       <span class=\"token comment\"># 
Optional; sets the size used for the relative position encoding and is only needed when relative position encoding is enabled<\/span><br \/>\n    <span class=\"token punctuation\">)<\/span> <span class=\"token operator\">-<\/span><span class=\"token operator\">&gt;<\/span> <span class=\"token boolean\">None<\/span><span class=\"token punctuation\">:<\/span><br \/>\n        <span class=\"token builtin\">super<\/span><span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>__init__<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>num_heads <span class=\"token operator\">&#061;<\/span> num_heads <span class=\"token comment\"># Number of attention heads<\/span><br \/>\n        head_dim <span class=\"token operator\">&#061;<\/span> dim <span class=\"token operator\">\/\/<\/span> num_heads <span class=\"token comment\"># Dimension of each head<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>scale <span class=\"token operator\">&#061;<\/span> head_dim<span class=\"token operator\">**<\/span><span class=\"token operator\">-<\/span><span class=\"token number\">0.5<\/span> <span class=\"token comment\"># Scaling factor for the attention scores, equal to 1\/sqrt(head_dim); keeps the scores from growing too large<\/span><br \/>\n        <span class=\"token comment\"># A fully connected layer (nn.Linear) that maps the input to the concatenated Q, K and V<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>qkv <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>Linear<span class=\"token 
punctuation\">(<\/span>dim<span class=\"token punctuation\">,<\/span> dim <span class=\"token operator\">*<\/span> <span class=\"token number\">3<\/span><span class=\"token punctuation\">,<\/span> bias<span class=\"token operator\">&#061;<\/span>qkv_bias<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># A fully connected layer that projects the attention output back to the original dimension<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>proj <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>Linear<span class=\"token punctuation\">(<\/span>dim<span class=\"token punctuation\">,<\/span> dim<span class=\"token punctuation\">)<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>use_rel_pos <span class=\"token operator\">&#061;<\/span> use_rel_pos<br \/>\n        <span class=\"token keyword\">if<\/span> self<span class=\"token punctuation\">.<\/span>use_rel_pos<span class=\"token punctuation\">:<\/span>        <span class=\"token comment\"># Use relative position encoding<\/span><br \/>\n            <span class=\"token keyword\">assert<\/span> <span class=\"token punctuation\">(<\/span><br \/>\n                input_size <span class=\"token keyword\">is<\/span> <span class=\"token keyword\">not<\/span> <span class=\"token boolean\">None<\/span><br \/>\n            <span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token string\">&#034;Input size must be provided if using relative positional encoding.&#034;<\/span><br \/>\n            <span class=\"token comment\"># Initialize the relative position embeddings for the height axis (rel_pos_h) and the width axis (rel_pos_w)<\/span><br \/>\n            <span class=\"token comment\"># 2S-1,Epos<\/span><br \/>\n            <span class=\"token 
comment\"># For an input of size (H, W), the height-axis embedding table has 2*H-1 entries and the width-axis table has 2*W-1 entries<\/span><br \/>\n            <span class=\"token comment\"># Each position embedding has dimension head_dim<\/span><br \/>\n            <span class=\"token comment\"># The embeddings are defined as model parameters (nn.Parameter), so they are learned and updated during training<\/span><br \/>\n            self<span class=\"token punctuation\">.<\/span>rel_pos_h <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>Parameter<span class=\"token punctuation\">(<\/span>torch<span class=\"token punctuation\">.<\/span>zeros<span class=\"token punctuation\">(<\/span><span class=\"token number\">2<\/span> <span class=\"token operator\">*<\/span> input_size<span class=\"token punctuation\">[<\/span><span class=\"token number\">0<\/span><span class=\"token punctuation\">]<\/span> <span class=\"token operator\">-<\/span> <span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> head_dim<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">)<\/span><br \/>\n            self<span class=\"token punctuation\">.<\/span>rel_pos_w <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>Parameter<span class=\"token punctuation\">(<\/span>torch<span class=\"token punctuation\">.<\/span>zeros<span class=\"token punctuation\">(<\/span><span class=\"token number\">2<\/span> <span class=\"token operator\">*<\/span> input_size<span class=\"token punctuation\">[<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">]<\/span> <span class=\"token operator\">-<\/span> <span 
class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> head_dim<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">)<\/span><\/p>\n<p>    <span class=\"token keyword\">def<\/span> <span class=\"token function\">forward<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> x<span class=\"token punctuation\">:<\/span> torch<span class=\"token punctuation\">.<\/span>Tensor<span class=\"token punctuation\">)<\/span> <span class=\"token operator\">-<\/span><span class=\"token operator\">&gt;<\/span> torch<span class=\"token punctuation\">.<\/span>Tensor<span class=\"token punctuation\">:<\/span><br \/>\n        <span class=\"token comment\"># The input tensor x has shape (B, H, W, C): B is the batch size, H and W are the height and width, C is the number of channels (i.e. dim)<\/span><br \/>\n        B<span class=\"token punctuation\">,<\/span> H<span class=\"token punctuation\">,<\/span> W<span class=\"token punctuation\">,<\/span> _ <span class=\"token operator\">&#061;<\/span> x<span class=\"token punctuation\">.<\/span>shape<br \/>\n        <span class=\"token comment\"># Project x to the concatenated Q, K and V with the qkv layer, then reshape and permute to prepare for multi-head attention<\/span><br \/>\n        <span class=\"token comment\"># qkv with shape (3, B, nHead, H * W, C)<\/span><br \/>\n        qkv <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>qkv<span class=\"token punctuation\">(<\/span>x<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>reshape<span class=\"token punctuation\">(<\/span>B<span class=\"token punctuation\">,<\/span> H <span class=\"token operator\">*<\/span> W<span class=\"token 
punctuation\">,<\/span> <span class=\"token number\">3<\/span><span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">.<\/span>num_heads<span class=\"token punctuation\">,<\/span> <span class=\"token operator\">-<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>permute<span class=\"token punctuation\">(<\/span><span class=\"token number\">2<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">0<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">3<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">4<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># q, k, v with shape (B * nHead, H * W, C)<\/span><br \/>\n        q<span class=\"token punctuation\">,<\/span> k<span class=\"token punctuation\">,<\/span> v <span class=\"token operator\">&#061;<\/span> qkv<span class=\"token punctuation\">.<\/span>reshape<span class=\"token punctuation\">(<\/span><span class=\"token number\">3<\/span><span class=\"token punctuation\">,<\/span> B <span class=\"token operator\">*<\/span> self<span class=\"token punctuation\">.<\/span>num_heads<span class=\"token punctuation\">,<\/span> H <span class=\"token operator\">*<\/span> W<span class=\"token punctuation\">,<\/span> <span class=\"token operator\">-<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>unbind<span class=\"token punctuation\">(<\/span><span class=\"token number\">0<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># attn with shape (B * nHead, H * W,  H * W)<\/span><br \/>\n        <span class=\"token comment\"># Compute the attention scores<\/span><br 
\/>\n        <span class=\"token comment\"># q * self.scale: q holds the query vectors with shape (B * nHead, H * W, C), where B is the batch size, nHead is the number of attention heads, H * W is the sequence length and C is the per-head feature dimension<\/span><br \/>\n        <span class=\"token comment\"># self.scale is the factor used to scale the attention scores, usually the reciprocal of the square root of head_dim, which keeps the values from growing too large<\/span><br \/>\n        <span class=\"token comment\"># Scaling keeps the scores small so that the softmax does not saturate, which would otherwise shrink the gradients<\/span><br \/>\n        <span class=\"token comment\"># k.transpose(-2, -1): k holds the key vectors and has the same shape as q. transpose(-2, -1) swaps the last two dimensions so that q and k are compatible for the matrix product. The transposed k has shape (B * nHead, C, H * W)<\/span><br \/>\n        <span class=\"token comment\"># Matrix-multiply q with the transposed k. This computes the dot product of every query position with every key position and yields the attention score matrix attn with shape (B * nHead, H * W, H * W). The score at (i, j) measures the similarity between q_i and k_j<\/span><br \/>\n        attn <span class=\"token operator\">&#061;<\/span> <span class=\"token punctuation\">(<\/span>q <span class=\"token
operator\">*<\/span> self<span class=\"token punctuation\">.<\/span>scale<span class=\"token punctuation\">)<\/span> &#064; k<span class=\"token punctuation\">.<\/span>transpose<span class=\"token punctuation\">(<\/span><span class=\"token operator\">-<\/span><span class=\"token number\">2<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token operator\">-<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># If relative position embeddings are enabled<\/span><br \/>\n        <span class=\"token keyword\">if<\/span> self<span class=\"token punctuation\">.<\/span>use_rel_pos<span class=\"token punctuation\">:<\/span><br \/>\n            <span class=\"token comment\"># (H, W) is the size of the input grid. Here H and W are assumed to be equal (S\u00d7S), i.e. the input is a square grid (for example, the grid of image patch tokens)<\/span><br \/>\n            <span class=\"token comment\"># attn: the attention score matrix computed above, with shape (B * nHead, H * W, H * W)<\/span><br \/>\n            <span class=\"token comment\"># q: the query vectors, with shape (B * nHead, H * W, C)<\/span><br \/>\n            <span class=\"token comment\"># self.rel_pos_h and self.rel_pos_w: the relative position embeddings along the vertical (height) and horizontal (width) axes, each with shape (2 * S - 1, head_dim)<\/span><br \/>\n            <span class=\"token comment\"># (H, W): the size of the input grid, passed as both the query size and the key size<\/span><br \/>\n            attn <span class=\"token operator\">&#061;<\/span>
add_decomposed_rel_pos<span class=\"token punctuation\">(<\/span>attn<span class=\"token punctuation\">,<\/span> q<span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">.<\/span>rel_pos_h<span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">.<\/span>rel_pos_w<span class=\"token punctuation\">,<\/span> <span class=\"token punctuation\">(<\/span>H<span class=\"token punctuation\">,<\/span> W<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token punctuation\">(<\/span>H<span class=\"token punctuation\">,<\/span> W<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># The attention scores then go through softmax along the last dimension, so every row lies in [0, 1] and sums to 1, forming a probability distribution over key positions<\/span><br \/>\n        attn <span class=\"token operator\">&#061;<\/span> attn<span class=\"token punctuation\">.<\/span>softmax<span class=\"token punctuation\">(<\/span>dim<span class=\"token operator\">&#061;<\/span><span class=\"token operator\">-<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># Weighted sum:<\/span><br \/>\n        <span class=\"token comment\"># attn &#064; v computes the weighted sum, where &#064; is matrix multiplication and v holds the value vectors with shape (B * nHead, H * W, C)<\/span><br \/>\n        <span class=\"token comment\"># The attention weight matrix attn (shape (B * nHead, H * W, H *
W)) is matrix-multiplied with v (no element-wise product is involved), giving the weighted value vectors with shape (B * nHead, H * W, C)<\/span><br \/>\n        <span class=\"token comment\"># .view() reshapes the weighted values to (B, self.num_heads, H, W, -1), and .permute(0, 2, 3, 1, 4) moves self.num_heads to the fourth dimension. Finally, .reshape(B, H, W, -1) merges the heads, giving (B, H, W, -1), which matches the input shape with the per-head outputs concatenated along the channel dimension<\/span><br \/>\n        x <span class=\"token operator\">&#061;<\/span> <span class=\"token punctuation\">(<\/span>attn &#064; v<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>view<span class=\"token punctuation\">(<\/span>B<span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">.<\/span>num_heads<span class=\"token punctuation\">,<\/span> H<span class=\"token punctuation\">,<\/span> W<span class=\"token punctuation\">,<\/span> <span class=\"token operator\">-<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>permute<span class=\"token punctuation\">(<\/span><span class=\"token number\">0<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">2<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">3<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">4<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>reshape<span
class=\"token punctuation\">(<\/span>B<span class=\"token punctuation\">,<\/span> H<span class=\"token punctuation\">,<\/span> W<span class=\"token punctuation\">,<\/span> <span class=\"token operator\">-<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># Apply self.proj (a fully connected layer of shape (dim, dim)) as a linear projection that maps the result back to the original feature dimension<\/span><br \/>\n        x <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>proj<span class=\"token punctuation\">(<\/span>x<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># Finally, return the projected tensor x as the output of the attention module<\/span><br \/>\n        <span class=\"token keyword\">return<\/span> x<\/p>\n<p>In the Multi-Head Attention module, the input feature F (N\u00d7E) represents a sequence, where N is the number of elements in the sequence and E is the feature dimension of each element. The procedure is as follows.<\/p>\n<li>First, the embed_dim-sized q, k and v features of each token are split evenly across the heads.<\/li>\n<p><img decoding=\"async\" src=\"https:\/\/www.wsisp.com\/helps\/wp-content\/uploads\/2025\/04\/20250428074943-680f33170c924.png\" alt=\"\u56fe\u7247\" width=\"700\"
\/><\/p>\n<li>Each head computes its weights w from q and k and applies w to v to obtain its own output. The outputs of all heads are then concatenated to form the final output.<\/li>\n<p><img decoding=\"async\" src=\"https:\/\/www.wsisp.com\/helps\/wp-content\/uploads\/2025\/04\/20250428074943-680f33175d930.png\" alt=\"\u56fe\u7247\" width=\"700\" \/><\/p>\n<p>get_rel_pos looks up the relative position embeddings between query and key positions along one spatial axis (it is called once for height and once for width), as shown in the figure below.<\/p>\n<p><img decoding=\"async\" src=\"https:\/\/www.wsisp.com\/helps\/wp-content\/uploads\/2025\/04\/20250428074943-680f3317a0082.png\" alt=\"\u56fe\u7247\" width=\"300\" \/><\/p>\n<p>Implementation:<\/p>\n<p><span class=\"token keyword\">def<\/span> <span class=\"token function\">get_rel_pos<\/span><span class=\"token punctuation\">(<\/span>q_size<span class=\"token punctuation\">:<\/span> <span class=\"token builtin\">int<\/span><span class=\"token punctuation\">,<\/span> k_size<span class=\"token punctuation\">:<\/span> <span class=\"token builtin\">int<\/span><span class=\"token punctuation\">,<\/span> rel_pos<span class=\"token punctuation\">:<\/span> torch<span class=\"token punctuation\">.<\/span>Tensor<span class=\"token punctuation\">)<\/span> <span class=\"token operator\">-<\/span><span class=\"token operator\">&gt;<\/span> torch<span class=\"token punctuation\">.<\/span>Tensor<span class=\"token punctuation\">:<\/span><br \/>\n    <span class=\"token comment\"># max_rel_dist is the number of distinct relative offsets between query and key positions along this axis<\/span><br \/>\n    <span class=\"token comment\"># max(q_size,
k_size): take the larger of the query length q_size and the key length k_size<\/span><br \/>\n    <span class=\"token comment\"># If q_size and k_size are both S, the relative offsets range from -(S - 1) to S - 1<\/span><br \/>\n    <span class=\"token comment\"># That gives (S - 1) + (S - 1) + 1 = 2 * S - 1 distinct values, where the extra 1 accounts for offset 0<\/span><br \/>\n    max_rel_dist <span class=\"token operator\">&#061;<\/span> <span class=\"token builtin\">int<\/span><span class=\"token punctuation\">(<\/span><span class=\"token number\">2<\/span> <span class=\"token operator\">*<\/span> <span class=\"token builtin\">max<\/span><span class=\"token punctuation\">(<\/span>q_size<span class=\"token punctuation\">,<\/span> k_size<span class=\"token punctuation\">)<\/span> <span class=\"token operator\">-<\/span> <span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token comment\"># If dimension 0 of rel_pos (its length) differs from max_rel_dist, the table has to be interpolated<\/span><br \/>\n    <span class=\"token keyword\">if<\/span> rel_pos<span class=\"token punctuation\">.<\/span>shape<span class=\"token punctuation\">[<\/span><span class=\"token number\">0<\/span><span class=\"token punctuation\">]<\/span> <span class=\"token operator\">!&#061;<\/span> max_rel_dist<span class=\"token punctuation\">:<\/span><br \/>\n        <span class=\"token comment\"># Use F.interpolate for linear interpolation<\/span><br \/>\n        rel_pos_resized <span class=\"token operator\">&#061;<\/span>
F<span class=\"token punctuation\">.<\/span>interpolate<span class=\"token punctuation\">(<\/span><br \/>\n            <span class=\"token comment\"># 1,N,Ep -&gt; 1,Ep,N -&gt; 1,Ep,2S-1<\/span><br \/>\n            <span class=\"token comment\"># Reshape rel_pos to (1, N, Ep), where N is the original length and Ep is the feature dimension of each position embedding<\/span><br \/>\n            <span class=\"token comment\"># permute(0, 2, 1) then transposes it to (1, Ep, N)<\/span><br \/>\n            rel_pos<span class=\"token punctuation\">.<\/span>reshape<span class=\"token punctuation\">(<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> rel_pos<span class=\"token punctuation\">.<\/span>shape<span class=\"token punctuation\">[<\/span><span class=\"token number\">0<\/span><span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token operator\">-<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>permute<span class=\"token punctuation\">(<\/span><span class=\"token number\">0<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">2<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span><br \/>\n            <span class=\"token comment\"># Target length of the interpolation: max_rel_dist<\/span><br \/>\n            size<span class=\"token operator\">&#061;<\/span>max_rel_dist<span class=\"token punctuation\">,<\/span><br \/>\n            <span class=\"token comment\"># Interpolation method: linear<\/span><br \/>\n            mode<span class=\"token
operator\">&#061;<\/span><span class=\"token string\">&#034;linear&#034;<\/span><span class=\"token punctuation\">,<\/span><br \/>\n        <span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># Ep,2S-1 -&gt; 2S-1,Ep<\/span><br \/>\n        <span class=\"token comment\"># After interpolation rel_pos has shape (1, Ep, max_rel_dist); reshape(-1, max_rel_dist) turns it into (Ep, max_rel_dist)<\/span><br \/>\n        <span class=\"token comment\"># permute(1, 0) then transposes it to (max_rel_dist, Ep)<\/span><br \/>\n        rel_pos_resized <span class=\"token operator\">&#061;<\/span> rel_pos_resized<span class=\"token punctuation\">.<\/span>reshape<span class=\"token punctuation\">(<\/span><span class=\"token operator\">-<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> max_rel_dist<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>permute<span class=\"token punctuation\">(<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">0<\/span><span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token keyword\">else<\/span><span class=\"token punctuation\">:<\/span><br \/>\n        <span class=\"token comment\"># If the length of rel_pos already equals max_rel_dist, it covers every possible relative offset, so rel_pos is used as is<\/span><br \/>\n        rel_pos_resized <span class=\"token operator\">&#061;<\/span> rel_pos<\/p>\n<p>    <span class=\"token comment\"># If the q and k lengths differ, scale the coordinates of the shorter side<\/span><br \/>\n    <span class=\"token comment\">#
Create the query coordinates q_coords<\/span><br \/>\n    <span class=\"token comment\"># torch.arange(q_size) produces the integers from 0 to q_size - 1, one per query position<\/span><br \/>\n    <span class=\"token comment\"># [:, None] appends a trailing dimension, giving shape (q_size, 1), so that q_coords broadcasts against k_coords of shape (1, k_size)<\/span><br \/>\n    <span class=\"token comment\"># max(k_size \/ q_size, 1.0) is the scale factor: k_size \/ q_size when k_size is larger than q_size, otherwise 1.0<\/span><br \/>\n    <span class=\"token comment\"># This stretches q_coords when q_size is smaller than k_size, so that they match the scale of k_coords<\/span><br \/>\n    q_coords <span class=\"token operator\">&#061;<\/span> torch<span class=\"token punctuation\">.<\/span>arange<span class=\"token punctuation\">(<\/span>q_size<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">[<\/span><span class=\"token punctuation\">:<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token boolean\">None<\/span><span class=\"token punctuation\">]<\/span> <span class=\"token operator\">*<\/span> <span class=\"token builtin\">max<\/span><span class=\"token punctuation\">(<\/span>k_size <span class=\"token operator\">\/<\/span> q_size<span class=\"token punctuation\">,<\/span> <span class=\"token number\">1.0<\/span><span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token comment\"># Create the key coordinates k_coords<\/span><br \/>\n    k_coords <span class=\"token operator\">&#061;<\/span> torch<span class=\"token punctuation\">.<\/span>arange<span
class=\"token punctuation\">(<\/span>k_size<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">[<\/span><span class=\"token boolean\">None<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token punctuation\">:<\/span><span class=\"token punctuation\">]<\/span> <span class=\"token operator\">*<\/span> <span class=\"token builtin\">max<\/span><span class=\"token punctuation\">(<\/span>q_size <span class=\"token operator\">\/<\/span> k_size<span class=\"token punctuation\">,<\/span> <span class=\"token number\">1.0<\/span><span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token comment\"># S,S<\/span><br \/>\n    <span class=\"token comment\"># Compute relative_coords, the relative coordinate of every query position with respect to every key position along this axis<\/span><br \/>\n    <span class=\"token comment\"># (q_coords - k_coords): the signed offset of each query position from each key position<\/span><br \/>\n    <span class=\"token comment\"># (k_size - 1) * max(q_size \/ k_size, 1.0): a shift that makes the smallest offset map to index 0, so all indices are non-negative<\/span><br \/>\n    <span class=\"token comment\"># (q_coords - k_coords) &#043; (k_size - 1) * max(q_size \/ k_size, 1.0): the offset plus the shift gives the final relative_coords<\/span><br \/>\n    relative_coords <span class=\"token operator\">&#061;<\/span> <span class=\"token punctuation\">(<\/span>q_coords <span class=\"token operator\">-<\/span> k_coords<span class=\"token punctuation\">)<\/span> <span class=\"token operator\">&#043;<\/span> <span class=\"token punctuation\">(<\/span>k_size <span class=\"token operator\">-<\/span>
<span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span> <span class=\"token operator\">*<\/span> <span class=\"token builtin\">max<\/span><span class=\"token punctuation\">(<\/span>q_size <span class=\"token operator\">\/<\/span> k_size<span class=\"token punctuation\">,<\/span> <span class=\"token number\">1.0<\/span><span class=\"token punctuation\">)<\/span><\/p>\n<p>    <span class=\"token comment\"># Indexing a tensor with an integer tensor, i.e. tensor1[tensor2]:<\/span><br \/>\n    <span class=\"token comment\"># every entry of tensor2 is replaced by the slice of tensor1 at that index; for example, an entry equal to 2 is replaced by tensor1[2]<\/span><br \/>\n    <span class=\"token comment\"># tensor1-&gt;shape 5,5,3 tensor2-&gt;shape 2,2,3 tensor1 slice-&gt;shape 5,3 tensor1[tensor2]-&gt;shape 2,2,3,5,3<\/span><br \/>\n    <span class=\"token comment\"># tensor1-&gt;shape 5,5 tensor2-&gt;shape 3,2,3 tensor1 slice-&gt;shape 5 tensor1[tensor2]-&gt;shape 3,2,3,5<\/span><\/p>\n<p>    <span class=\"token comment\"># 2S-1,Ep -&gt; S,S,Ep<\/span><br \/>\n    <span class=\"token keyword\">return<\/span> rel_pos_resized<span class=\"token punctuation\">[<\/span>relative_coords<span class=\"token punctuation\">.<\/span><span class=\"token builtin\">long<\/span><span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">]<\/span><\/p>\n<p>add_decomposed_rel_pos adds the relative position embeddings to the attention scores attn, as shown in the figure below.<\/p>\n<p><img decoding=\"async\" src=\"https:\/\/www.wsisp.com\/helps\/wp-content\/uploads\/2025\/04\/20250428074943-680f3317d110d.png\" alt=\"\u56fe\u7247\" width=\"700\" \/><\/p>\n<p><span class=\"token keyword\">def<\/span> <span class=\"token function\">add_decomposed_rel_pos<\/span><span class=\"token punctuation\">(<\/span><br \/>\n    
<span class=\"token comment\"># Attention score matrix<\/span><br \/>\n    attn<span class=\"token punctuation\">:<\/span> torch<span class=\"token punctuation\">.<\/span>Tensor<span class=\"token punctuation\">,<\/span><br \/>\n    q<span class=\"token punctuation\">:<\/span> torch<span class=\"token punctuation\">.<\/span>Tensor<span class=\"token punctuation\">,<\/span><br \/>\n    rel_pos_h<span class=\"token punctuation\">:<\/span> torch<span class=\"token punctuation\">.<\/span>Tensor<span class=\"token punctuation\">,<\/span><br \/>\n    rel_pos_w<span class=\"token punctuation\">:<\/span> torch<span class=\"token punctuation\">.<\/span>Tensor<span class=\"token punctuation\">,<\/span><br \/>\n    q_size<span class=\"token punctuation\">:<\/span> Tuple<span class=\"token punctuation\">[<\/span><span class=\"token builtin\">int<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token builtin\">int<\/span><span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">,<\/span><br \/>\n    k_size<span class=\"token punctuation\">:<\/span> Tuple<span class=\"token punctuation\">[<\/span><span class=\"token builtin\">int<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token builtin\">int<\/span><span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">,<\/span><br \/>\n<span class=\"token punctuation\">)<\/span> <span class=\"token operator\">-<\/span><span class=\"token operator\">&gt;<\/span> torch<span class=\"token punctuation\">.<\/span>Tensor<span class=\"token punctuation\">:<\/span><br \/>\n    <span class=\"token comment\"># S,S<\/span><br \/>\n    q_h<span class=\"token punctuation\">,<\/span> q_w <span class=\"token operator\">&#061;<\/span> q_size<br \/>\n    k_h<span class=\"token punctuation\">,<\/span> k_w <span class=\"token operator\">&#061;<\/span> k_size<br \/>\n    <span class=\"token comment\"># rel_pos_h -&gt; 2S-1\u00d7Epos<\/span><br \/>\n    <span
class=\"token comment\"># Relative position embeddings between query and key along the height axis<\/span><br \/>\n    Rh <span class=\"token operator\">&#061;<\/span> get_rel_pos<span class=\"token punctuation\">(<\/span>q_h<span class=\"token punctuation\">,<\/span> k_h<span class=\"token punctuation\">,<\/span> rel_pos_h<span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token comment\"># Relative position embeddings between query and key along the width axis<\/span><br \/>\n    Rw <span class=\"token operator\">&#061;<\/span> get_rel_pos<span class=\"token punctuation\">(<\/span>q_w<span class=\"token punctuation\">,<\/span> k_w<span class=\"token punctuation\">,<\/span> rel_pos_w<span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token comment\"># Reshape q to (B, q_h, q_w, dim)<\/span><br \/>\n    B<span class=\"token punctuation\">,<\/span> _<span class=\"token punctuation\">,<\/span> dim <span class=\"token operator\">&#061;<\/span> q<span class=\"token punctuation\">.<\/span>shape<br \/>\n    r_q <span class=\"token operator\">&#061;<\/span> q<span class=\"token punctuation\">.<\/span>reshape<span class=\"token punctuation\">(<\/span>B<span class=\"token punctuation\">,<\/span> q_h<span class=\"token punctuation\">,<\/span> q_w<span class=\"token punctuation\">,<\/span> dim<span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token comment\"># Compute the relative position terms<\/span><br \/>\n    <span class=\"token comment\"># rel_h and rel_w hold, for every query position, the dot products between the query and the relative position embeddings along each axis<\/span><br \/>\n    <span class=\"token comment\"># B,q_h,q_w,k_h<\/span><br \/>\n    rel_h <span class=\"token
operator\">&#061;<\/span> torch<span class=\"token punctuation\">.<\/span>einsum<span class=\"token punctuation\">(<\/span><span class=\"token string\">&#034;bhwc,hkc-&gt;bhwk&#034;<\/span><span class=\"token punctuation\">,<\/span> r_q<span class=\"token punctuation\">,<\/span> Rh<span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token comment\"># B,q_h, q_w, k_w<\/span><br \/>\n    rel_w <span class=\"token operator\">&#061;<\/span> torch<span class=\"token punctuation\">.<\/span>einsum<span class=\"token punctuation\">(<\/span><span class=\"token string\">&#034;bhwc,wkc-&gt;bhwk&#034;<\/span><span class=\"token punctuation\">,<\/span> r_q<span class=\"token punctuation\">,<\/span> Rw<span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token comment\"># Combine the attention scores with the relative position terms<\/span><br \/>\n    <span class=\"token comment\"># Reshape attn to (B, q_h, q_w, k_h, k_w), then add rel_h and rel_w with broadcasting<\/span><br \/>\n    attn <span class=\"token operator\">&#061;<\/span> <span class=\"token punctuation\">(<\/span><br \/>\n    <span class=\"token comment\"># B,q_h, q_w, k_h, k_w<\/span><br \/>\n        attn<span class=\"token punctuation\">.<\/span>view<span class=\"token punctuation\">(<\/span>B<span class=\"token punctuation\">,<\/span> q_h<span class=\"token punctuation\">,<\/span> q_w<span class=\"token punctuation\">,<\/span> k_h<span class=\"token punctuation\">,<\/span> k_w<span class=\"token punctuation\">)<\/span> <span class=\"token operator\">&#043;<\/span> rel_h<span class=\"token punctuation\">[<\/span><span class=\"token punctuation\">:<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token punctuation\">:<\/span><span
class=\"token punctuation\">,<\/span> <span class=\"token punctuation\">:<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token punctuation\">:<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token boolean\">None<\/span><span class=\"token punctuation\">]<\/span> <span class=\"token operator\">&#043;<\/span> rel_w<span class=\"token punctuation\">[<\/span><span class=\"token punctuation\">:<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token punctuation\">:<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token punctuation\">:<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token boolean\">None<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token punctuation\">:<\/span><span class=\"token punctuation\">]<\/span><br \/>\n    <span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>view<span class=\"token punctuation\">(<\/span>B<span class=\"token punctuation\">,<\/span> q_h <span class=\"token operator\">*<\/span> q_w<span class=\"token punctuation\">,<\/span> k_h <span class=\"token operator\">*<\/span> k_w<span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token keyword\">return<\/span> attn<\/p>\n<p>The Multi-Head Attention module injects the relative position information into the attention scores (add_decomposed_rel_pos):<\/p>\n<p><img decoding=\"async\" src=\"https:\/\/www.wsisp.com\/helps\/wp-content\/uploads\/2025\/04\/20250428074944-680f331827827.png\" alt=\"\u56fe\u7247\" width=\"700\" \/><\/p>\n<h5>Neck Convolution<\/h5>\n<p>Finally, two convolution layers (the Neck) reduce the channel count to 256 and produce the final Image Embedding. The structure is shown below.<\/p>\n<p><img decoding=\"async\"
src=\"https:\/\/www.wsisp.com\/helps\/wp-content\/uploads\/2025\/04\/20250428074944-680f33186b4b4.png\" alt=\"\u56fe\u7247\" \/><\/p>\n<p>Implementation:<\/p>\n<p><span class=\"token comment\"># neck: an nn.Sequential containing two convolution layers and two LayerNorm2d layers<\/span><br \/>\nself<span class=\"token punctuation\">.<\/span>neck <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>Sequential<span class=\"token punctuation\">(<\/span><br \/>\n    <span class=\"token comment\"># 1&#215;1 convolution that reduces the channel count from embed_dim to out_chans<\/span><br \/>\n    <span class=\"token comment\"># A 1&#215;1 convolution mixes information across channels without changing the spatial size of the feature map<\/span><br \/>\n    nn<span class=\"token punctuation\">.<\/span>Conv2d<span class=\"token punctuation\">(<\/span><br \/>\n        embed_dim<span class=\"token punctuation\">,<\/span><br \/>\n        out_chans<span class=\"token punctuation\">,<\/span><br \/>\n        kernel_size<span class=\"token operator\">&#061;<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span><br \/>\n        <span class=\"token comment\"># No bias term<\/span><br \/>\n        bias<span class=\"token operator\">&#061;<\/span><span class=\"token boolean\">False<\/span><span class=\"token punctuation\">,<\/span><br \/>\n    <span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span><br \/>\n    <span class=\"token comment\"># Normalization layer that normalizes the mean and variance across the output channels, which improves stability and convergence speed<\/span><br \/>\n    <span class=\"token
comment\"># out_chans: number of channels of the normalization layer<\/span><br \/>\n    LayerNorm2d<span class=\"token punctuation\">(<\/span>out_chans<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span><br \/>\n    <span class=\"token comment\"># 3&#215;3 convolution<\/span><br \/>\n    nn<span class=\"token punctuation\">.<\/span>Conv2d<span class=\"token punctuation\">(<\/span><br \/>\n        <span class=\"token comment\"># out_chans is used as both the input and the output channel count<\/span><br \/>\n        out_chans<span class=\"token punctuation\">,<\/span><br \/>\n        out_chans<span class=\"token punctuation\">,<\/span><br \/>\n        kernel_size<span class=\"token operator\">&#061;<\/span><span class=\"token number\">3<\/span><span class=\"token punctuation\">,<\/span><br \/>\n        <span class=\"token comment\"># A padding of 1 keeps the feature-map size unchanged (no shrinkage)<\/span><br \/>\n        padding<span class=\"token operator\">&#061;<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span><br \/>\n        <span class=\"token comment\"># No bias term<\/span><br \/>\n        bias<span class=\"token operator\">&#061;<\/span><span class=\"token boolean\">False<\/span><span class=\"token punctuation\">,<\/span><br \/>\n    <span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span><br \/>\n    <span class=\"token comment\"># Second normalization layer, normalizing the output again<\/span><br \/>\n    LayerNorm2d<span class=\"token punctuation\">(<\/span>out_chans<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span><br \/>\n<span class=\"token punctuation\">)<\/span><\/p>\n<p><span class=\"token comment\"># Normalization over the channel dimension of an NCHW tensor<\/span><br 
\/>\n<span class=\"token keyword\">class<\/span> <span class=\"token class-name\">LayerNorm2d<\/span><span class=\"token punctuation\">(<\/span>nn<span class=\"token punctuation\">.<\/span>Module<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n    <span class=\"token keyword\">def<\/span> <span class=\"token function\">__init__<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> num_channels<span class=\"token punctuation\">:<\/span> <span class=\"token builtin\">int<\/span><span class=\"token punctuation\">,<\/span> eps<span class=\"token punctuation\">:<\/span> <span class=\"token builtin\">float<\/span> <span class=\"token operator\">&#061;<\/span> <span class=\"token number\">1e-6<\/span><span class=\"token punctuation\">)<\/span> <span class=\"token operator\">-<\/span><span class=\"token operator\">&gt;<\/span> <span class=\"token boolean\">None<\/span><span class=\"token punctuation\">:<\/span><br \/>\n        <span class=\"token builtin\">super<\/span><span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>__init__<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># Two learnable parameters: weight and bias<\/span><br \/>\n        <span class=\"token comment\"># weight is initialized to all ones, bias to all zeros<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>weight <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>Parameter<span class=\"token punctuation\">(<\/span>torch<span class=\"token punctuation\">.<\/span>ones<span class=\"token punctuation\">(<\/span>num_channels<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">)<\/span><br \/>\n     
   self<span class=\"token punctuation\">.<\/span>bias <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>Parameter<span class=\"token punctuation\">(<\/span>torch<span class=\"token punctuation\">.<\/span>zeros<span class=\"token punctuation\">(<\/span>num_channels<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>eps <span class=\"token operator\">&#061;<\/span> eps<\/p>\n<p>    <span class=\"token keyword\">def<\/span> <span class=\"token function\">forward<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> x<span class=\"token punctuation\">:<\/span> torch<span class=\"token punctuation\">.<\/span>Tensor<span class=\"token punctuation\">)<\/span> <span class=\"token operator\">-<\/span><span class=\"token operator\">&gt;<\/span> torch<span class=\"token punctuation\">.<\/span>Tensor<span class=\"token punctuation\">:<\/span><br \/>\n        <span class=\"token comment\"># Mean over the channel dimension; keepdim&#061;True keeps that dimension, so u has the same shape as x except that the channel dimension has size 1<\/span><br \/>\n        u <span class=\"token operator\">&#061;<\/span> x<span class=\"token punctuation\">.<\/span>mean<span class=\"token punctuation\">(<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> keepdim<span class=\"token operator\">&#061;<\/span><span class=\"token boolean\">True<\/span><span class=\"token punctuation\">)<\/span>                 <span class=\"token comment\"># mean over dim&#061;1, channel dimension kept<\/span><br \/>\n        <span class=\"token comment\"># Normalization factor 
s: the mean of the squared deviations from u (the variance), with the channel dimension kept<\/span><br \/>\n        s <span class=\"token operator\">&#061;<\/span> <span class=\"token punctuation\">(<\/span>x <span class=\"token operator\">-<\/span> u<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span><span class=\"token builtin\">pow<\/span><span class=\"token punctuation\">(<\/span><span class=\"token number\">2<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>mean<span class=\"token punctuation\">(<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> keepdim<span class=\"token operator\">&#061;<\/span><span class=\"token boolean\">True<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># Normalize: subtract the mean u from each value, then divide by sqrt(s &#043; eps); the small constant eps keeps the division numerically stable<\/span><br \/>\n        x <span class=\"token operator\">&#061;<\/span> <span class=\"token punctuation\">(<\/span>x <span class=\"token operator\">-<\/span> u<span class=\"token punctuation\">)<\/span> <span class=\"token operator\">\/<\/span> torch<span class=\"token punctuation\">.<\/span>sqrt<span class=\"token punctuation\">(<\/span>s <span class=\"token operator\">&#043;<\/span> self<span class=\"token punctuation\">.<\/span>eps<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># Apply the learnable weight and bias<\/span><br \/>\n        x <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>weight<span class=\"token punctuation\">[<\/span><span class=\"token 
punctuation\">:<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token boolean\">None<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token boolean\">None<\/span><span class=\"token punctuation\">]<\/span> <span class=\"token operator\">*<\/span> x <span class=\"token operator\">&#043;<\/span> self<span class=\"token punctuation\">.<\/span>bias<span class=\"token punctuation\">[<\/span><span class=\"token punctuation\">:<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token boolean\">None<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token boolean\">None<\/span><span class=\"token punctuation\">]<\/span><br \/>\n        <span class=\"token keyword\">return<\/span> x<\/p>\n<h4>Prompt Encoder<\/h4>\n<p>The network structure of the Prompt Encoder in SAM is shown in the figure below. It has three main parts:<\/p>\n<ul>\n<li>\n<p>Embed_Points: point-prompt encoding (each prompt point is converted into a vector)<\/p>\n<\/li>\n<li>\n<p>Embed_Boxes: box-prompt encoding (the corner points of each box are converted into vectors)<\/p>\n<\/li>\n<li>\n<p>Embed_Masks: mask encoding (the mask is downsampled to match the Image Encoder output)<\/p>\n<\/li>\n<\/ul>\n<p><img decoding=\"async\" src=\"2025-04-28gksv0e0ispk.png\" alt=\"Figure\" width=\"700\" \/><\/p>\n<h5>Embed_Points<\/h5>\n<p>The structure of Embed_Points is shown below.<\/p>\n<p><img decoding=\"async\" src=\"2025-04-2810tfclcfyu0.png\" alt=\"Figure\" 
\/><\/p>\n<p>Point preprocessing maps the channel dimension from 2 to embed_dim (MatMul: forward_with_coords) and then adds the learned embedding weight that matches the label of each point. Here:<\/p>\n<ul>\n<li>\n<p>2: the point coordinates (x, y)<\/p>\n<\/li>\n<li>\n<p>embed_dim: the channel dimension of the prompt encoding<\/p>\n<\/li>\n<\/ul>\n<p>Code implementation:<\/p>\n<p><span class=\"token comment\"># Convert the input point coordinates and their labels into high-dimensional embeddings for the rest of the model<\/span><br \/>\n<span class=\"token keyword\">def<\/span> <span class=\"token function\">_embed_points<\/span><span class=\"token punctuation\">(<\/span><br \/>\n    self<span class=\"token punctuation\">,<\/span><br \/>\n    points<span class=\"token punctuation\">:<\/span> torch<span class=\"token punctuation\">.<\/span>Tensor<span class=\"token punctuation\">,<\/span><br \/>\n    labels<span class=\"token punctuation\">:<\/span> torch<span class=\"token punctuation\">.<\/span>Tensor<span class=\"token punctuation\">,<\/span><br \/>\n    pad<span class=\"token punctuation\">:<\/span> <span class=\"token builtin\">bool<\/span><span class=\"token punctuation\">,<\/span><br \/>\n<span class=\"token punctuation\">)<\/span> <span class=\"token operator\">-<\/span><span class=\"token operator\">&gt;<\/span> torch<span class=\"token punctuation\">.<\/span>Tensor<span class=\"token punctuation\">:<\/span><br \/>\n    <span class=\"token comment\"># Add 0.5 to every coordinate to shift each point from the top-left corner of its pixel to the pixel center<\/span><br \/>\n    points <span class=\"token operator\">&#061;<\/span> points <span class=\"token operator\">&#043;<\/span> <span class=\"token 
number\">0.5<\/span><br \/>\n    <span class=\"token comment\"># No padding is needed when points are combined with boxes<\/span><br \/>\n    <span class=\"token keyword\">if<\/span> pad<span class=\"token punctuation\">:<\/span><br \/>\n        <span class=\"token comment\"># Append one padding entry to both points and labels<\/span><br \/>\n        <span class=\"token comment\"># so that batched processing stays consistent even when some samples have fewer points than the maximum.<\/span><br \/>\n        <span class=\"token comment\"># The padding point has coordinates (0,0) and label -1<\/span><br \/>\n        padding_point <span class=\"token operator\">&#061;<\/span> torch<span class=\"token punctuation\">.<\/span>zeros<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">(<\/span>points<span class=\"token punctuation\">.<\/span>shape<span class=\"token punctuation\">[<\/span><span class=\"token number\">0<\/span><span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">2<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span> device<span class=\"token operator\">&#061;<\/span>points<span class=\"token punctuation\">.<\/span>device<span class=\"token punctuation\">)<\/span>  <span class=\"token comment\"># B,1,2<\/span><br \/>\n        padding_label <span class=\"token operator\">&#061;<\/span> <span class=\"token operator\">-<\/span>torch<span class=\"token punctuation\">.<\/span>ones<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">(<\/span>labels<span class=\"token punctuation\">.<\/span>shape<span class=\"token punctuation\">[<\/span><span class=\"token number\">0<\/span><span 
class=\"token punctuation\">]<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span> device<span class=\"token operator\">&#061;<\/span>labels<span class=\"token punctuation\">.<\/span>device<span class=\"token punctuation\">)<\/span>     <span class=\"token comment\"># B,1<\/span><br \/>\n        points <span class=\"token operator\">&#061;<\/span> torch<span class=\"token punctuation\">.<\/span>cat<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">[<\/span>points<span class=\"token punctuation\">,<\/span> padding_point<span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">,<\/span> dim<span class=\"token operator\">&#061;<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span>                          <span class=\"token comment\"># B,N&#043;1,2<\/span><br \/>\n        labels <span class=\"token operator\">&#061;<\/span> torch<span class=\"token punctuation\">.<\/span>cat<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">[<\/span>labels<span class=\"token punctuation\">,<\/span> padding_label<span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">,<\/span> dim<span class=\"token operator\">&#061;<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span>                          <span class=\"token comment\"># B,N&#043;1<\/span><br \/>\n    <span class=\"token comment\"># Generate positional encodings from the shifted point coordinates and the input image size<\/span><br \/>\n    <span class=\"token comment\"># Shape of the resulting embedding: B,N&#043;1,2f<\/span><br \/>\n    <span class=\"token comment\"># 2f 
is the positional-encoding dimension of each point, expanded from the original 2D coordinates with sine and cosine functions<\/span><br \/>\n    point_embedding <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>pe_layer<span class=\"token punctuation\">.<\/span>forward_with_coords<span class=\"token punctuation\">(<\/span>points<span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">.<\/span>input_image_size<span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token comment\"># Adjust each point embedding according to its label.<\/span><\/p>\n<p>    <span class=\"token comment\"># label -1 marks a padding (not-a-point) entry: replace its embedding with the not-a-point weight<\/span><br \/>\n    point_embedding<span class=\"token punctuation\">[<\/span>labels <span class=\"token operator\">&#061;&#061;<\/span> <span class=\"token operator\">-<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">]<\/span> <span class=\"token operator\">&#061;<\/span> <span class=\"token number\">0.0<\/span><br \/>\n    point_embedding<span class=\"token punctuation\">[<\/span>labels <span class=\"token operator\">&#061;&#061;<\/span> <span class=\"token operator\">-<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">]<\/span> <span class=\"token operator\">&#043;&#061;<\/span> self<span class=\"token punctuation\">.<\/span>not_a_point_embed<span class=\"token punctuation\">.<\/span>weight<br \/>\n    <span class=\"token comment\"># label 0 marks a background point: add the background-point weight<\/span><br \/>\n    point_embedding<span class=\"token punctuation\">[<\/span>labels <span class=\"token 
operator\">&#061;&#061;<\/span> <span class=\"token number\">0<\/span><span class=\"token punctuation\">]<\/span> <span class=\"token operator\">&#043;&#061;<\/span> self<span class=\"token punctuation\">.<\/span>point_embeddings<span class=\"token punctuation\">[<\/span><span class=\"token number\">0<\/span><span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">.<\/span>weight<br \/>\n    <span class=\"token comment\"># label 1 marks a foreground (target) point: add the foreground-point weight<\/span><br \/>\n    point_embedding<span class=\"token punctuation\">[<\/span>labels <span class=\"token operator\">&#061;&#061;<\/span> <span class=\"token number\">1<\/span><span class=\"token punctuation\">]<\/span> <span class=\"token operator\">&#043;&#061;<\/span> self<span class=\"token punctuation\">.<\/span>point_embeddings<span class=\"token punctuation\">[<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">.<\/span>weight<br \/>\n    <span class=\"token keyword\">return<\/span> point_embedding<\/p>\n<h5>Embed_Boxes<\/h5>\n<p>The structure of Embed_Boxes is shown below.<\/p>\n<p><img decoding=\"async\" src=\"2025-04-28zqk5hpysy4z.png\" alt=\"Figure\" \/><\/p>\n<p>A bounding box is described by two corner points. The encoding steps are:<\/p>\n<ol>\n<li>\n<p>Reshape the input box tensor boxes from BxNx4 to (B*N)x2x2;<\/p>\n<\/li>\n<li>\n<p>Encode the corner points with the same point 
embedding scheme to obtain corner_embedding;<\/p>\n<\/li>\n<li>\n<p>Add the learnable embedding vectors created earlier.<\/p>\n<\/li>\n<\/ol>\n<p>The resulting corner_embedding has shape (B*N)x2x256.<\/p>\n<p>Code implementation:<\/p>\n<p><span class=\"token comment\"># Convert the input bounding boxes into high-dimensional embeddings<\/span><br \/>\n<span class=\"token keyword\">def<\/span> <span class=\"token function\">_embed_boxes<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> boxes<span class=\"token punctuation\">:<\/span> torch<span class=\"token punctuation\">.<\/span>Tensor<span class=\"token punctuation\">)<\/span> <span class=\"token operator\">-<\/span><span class=\"token operator\">&gt;<\/span> torch<span class=\"token punctuation\">.<\/span>Tensor<span class=\"token punctuation\">:<\/span><br \/>\n    <span class=\"token comment\"># Shift the coordinates from the top-left corner of the pixel to the pixel center<\/span><br \/>\n    boxes <span class=\"token operator\">&#061;<\/span> boxes <span class=\"token operator\">&#043;<\/span> <span class=\"token number\">0.5<\/span><br \/>\n    <span class=\"token comment\"># Reshape boxes from BxNx4 to (B*N)x2x2<\/span><br \/>\n    <span class=\"token comment\"># where B is the batch size and N is the number of boxes per sample<\/span><br \/>\n    coords <span class=\"token operator\">&#061;<\/span> boxes<span class=\"token punctuation\">.<\/span>reshape<span class=\"token punctuation\">(<\/span><span class=\"token operator\">-<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> <span 
class=\"token number\">2<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">2<\/span><span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token comment\"># Positionally encode the corner coordinates of every box<\/span><br \/>\n    corner_embedding <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>pe_layer<span class=\"token punctuation\">.<\/span>forward_with_coords<span class=\"token punctuation\">(<\/span>coords<span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">.<\/span>input_image_size<span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token comment\"># Add a dedicated learned weight to the embeddings of the first and the second corner of each box<\/span><br \/>\n    corner_embedding<span class=\"token punctuation\">[<\/span><span class=\"token punctuation\">:<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">0<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token punctuation\">:<\/span><span class=\"token punctuation\">]<\/span> <span class=\"token operator\">&#043;&#061;<\/span> self<span class=\"token punctuation\">.<\/span>point_embeddings<span class=\"token punctuation\">[<\/span><span class=\"token number\">2<\/span><span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">.<\/span>weight<br \/>\n    corner_embedding<span class=\"token punctuation\">[<\/span><span class=\"token punctuation\">:<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token punctuation\">:<\/span><span class=\"token punctuation\">]<\/span> <span class=\"token operator\">&#043;&#061;<\/span> self<span class=\"token 
punctuation\">.<\/span>point_embeddings<span class=\"token punctuation\">[<\/span><span class=\"token number\">3<\/span><span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">.<\/span>weight<br \/>\n    <span class=\"token comment\"># Return the resulting embeddings, shape (B*N)x2xembed_dim, where embed_dim is the positional-encoding dimension<\/span><br \/>\n    <span class=\"token keyword\">return<\/span> corner_embedding<\/p>\n<h5>Embed_Masks<\/h5>\n<p>A mask prompt lets us mark the region of interest directly on the image to guide the model. The mask is converted by convolutions into features that match the image-embedding space and is then added element-wise to the image embedding, which gives the model precise spatial information for the segmentation.<\/p>\n<p>If no mask prompt is given, a learnable vector (no_mask_embed, 1&#215;256) is expanded to 1&#215;256&#215;64&#215;64 and used in its place, so the downstream code handles both cases the same way even when no mask is available.<\/p>\n<p><img decoding=\"async\" src=\"2025-04-28x1d5vyrc4gb.png\" alt=\"Figure\" \/><\/p>\n<p><span class=\"token comment\"># From PromptEncoder: no_mask_embed is created in __init__, dense_embeddings is computed in forward<\/span><br \/>\n<span class=\"token triple-quoted-string string\">&#039;&#039;&#039;<br \/>\nFirst take the no_mask_embed weight matrix and reshape it into a 4-D tensor of shape (1, embed_dim, 1, 
1).<\/p>\n<p>Then .expand broadcasts this tensor to the same spatial size as the image embedding. bs is the batch size; in expand, -1 leaves that dimension (here the channel dimension) unchanged.<br \/>\nself.image_embedding_size[0] and self.image_embedding_size[1] are the height and width of the image embedding.<br \/>\n&#039;&#039;&#039;<\/span><br \/>\nself<span class=\"token punctuation\">.<\/span>no_mask_embed <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>Embedding<span class=\"token punctuation\">(<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> embed_dim<span class=\"token punctuation\">)<\/span>      <span class=\"token comment\"># embed_dim&#061;256<\/span><br \/>\ndense_embeddings <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>no_mask_embed<span class=\"token punctuation\">.<\/span>weight<span class=\"token punctuation\">.<\/span>reshape<span class=\"token punctuation\">(<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token operator\">-<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>expand<span class=\"token punctuation\">(<\/span><br \/>\n                bs<span class=\"token punctuation\">,<\/span> <span class=\"token operator\">-<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> self<span class=\"token 
punctuation\">.<\/span>image_embedding_size<span class=\"token punctuation\">[<\/span><span class=\"token number\">0<\/span><span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">.<\/span>image_embedding_size<span class=\"token punctuation\">[<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">)<\/span><\/p>\n<p>When a mask prompt is provided, the structure of Embed_Masks is as shown below.<\/p>\n<p><img decoding=\"async\" src=\"2025-04-28v4c445cupp1.png\" alt=\"Figure\" width=\"200\" \/><\/p>\n<p>The input mask has shape Nx1x256x256; after three convolution layers it has the same spatial size as the Image Embedding:<\/p>\n<p>The mask first passes through a 2x2 convolution with stride 2 (1 to 4 channels), followed by LN and an activation; then through a second 2x2 convolution with stride 2 (4 to 16 channels), again followed by LN and an activation; and finally through a 1x1 convolution (16 to 256 channels). The resulting mask_embedding has shape Nx256x64x64 and is returned as dense_embedding.<\/p>\n<p>The input mask (256x256) is 4 times the spatial size of the image embedding produced by the Image Encoder (64x64), so it is downsampled by a factor of 4 to match.<\/p>\n<p>Code implementation:<\/p>\n<p><span class=\"token comment\"># Convert the input mask tensor into a lower-resolution embedding<\/span><br \/>\n<span class=\"token comment\"># masks 
is a tensor of shape BxCxHxW<\/span><br \/>\n<span class=\"token comment\"># where B is the batch size, C is the channel count (normally 1, since a mask has a single channel), and H and W are the height and width.<\/span><br \/>\n<span class=\"token keyword\">def<\/span> <span class=\"token function\">_embed_masks<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> masks<span class=\"token punctuation\">:<\/span> torch<span class=\"token punctuation\">.<\/span>Tensor<span class=\"token punctuation\">)<\/span> <span class=\"token operator\">-<\/span><span class=\"token operator\">&gt;<\/span> torch<span class=\"token punctuation\">.<\/span>Tensor<span class=\"token punctuation\">:<\/span><br \/>\n    <span class=\"token comment\"># Downsample the mask by a factor of 4<\/span><br \/>\n    mask_embedding <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>mask_downscaling<span class=\"token punctuation\">(<\/span>masks<span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token comment\"># Return the downsampled mask embedding, shape Bxembed_dimxH&#039;xW&#039;, where H&#039; and W&#039; are the height and width after downsampling<\/span><br \/>\n    <span class=\"token keyword\">return<\/span> mask_embedding<\/p>\n<p><span class=\"token comment\"># mask_downscaling stacks convolutions, layer normalization (LayerNorm2d) and activations; it reduces the spatial size of the mask while increasing the channel count<\/span><br \/>\nself<span class=\"token punctuation\">.<\/span>mask_downscaling <span 
class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>Sequential<span class=\"token punctuation\">(<\/span><br \/>\n    <span class=\"token comment\"># Map the channel count from 1 to mask_in_chans\/\/4, using a 2&#215;2 kernel with stride 2 to halve the spatial resolution<\/span><br \/>\n    nn<span class=\"token punctuation\">.<\/span>Conv2d<span class=\"token punctuation\">(<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> mask_in_chans <span class=\"token operator\">\/\/<\/span> <span class=\"token number\">4<\/span><span class=\"token punctuation\">,<\/span> kernel_size<span class=\"token operator\">&#061;<\/span><span class=\"token number\">2<\/span><span class=\"token punctuation\">,<\/span> stride<span class=\"token operator\">&#061;<\/span><span class=\"token number\">2<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span><br \/>\n    <span class=\"token comment\"># Normalize the features over the channel dimension<\/span><br \/>\n    LayerNorm2d<span class=\"token punctuation\">(<\/span>mask_in_chans <span class=\"token operator\">\/\/<\/span> <span class=\"token number\">4<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span><br \/>\n    <span class=\"token comment\"># Activation function, adds non-linearity<\/span><br \/>\n    activation<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span><br \/>\n    <span class=\"token comment\"># Increase the channel count to mask_in_chans, again with a 2&#215;2 kernel and stride 2, halving the spatial resolution once more<\/span><br 
\/>\n    nn<span class=\"token punctuation\">.<\/span>Conv2d<span class=\"token punctuation\">(<\/span>mask_in_chans <span class=\"token operator\">\/\/<\/span> <span class=\"token number\">4<\/span><span class=\"token punctuation\">,<\/span> mask_in_chans<span class=\"token punctuation\">,<\/span> kernel_size<span class=\"token operator\">&#061;<\/span><span class=\"token number\">2<\/span><span class=\"token punctuation\">,<\/span> stride<span class=\"token operator\">&#061;<\/span><span class=\"token number\">2<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span><br \/>\n    <span class=\"token comment\"># LayerNorm2d and activation<\/span><br \/>\n    LayerNorm2d<span class=\"token punctuation\">(<\/span>mask_in_chans<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span><br \/>\n    activation<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span><br \/>\n    <span class=\"token comment\"># 1&#215;1 convolution that projects to embed_dim channels, so the result is consistent with the rest of the model<\/span><br \/>\n    nn<span class=\"token punctuation\">.<\/span>Conv2d<span class=\"token punctuation\">(<\/span>mask_in_chans<span class=\"token punctuation\">,<\/span> embed_dim<span class=\"token punctuation\">,<\/span> kernel_size<span class=\"token operator\">&#061;<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span><br \/>\n        <span class=\"token 
punctuation\">)<\/span><\/p>\n<p>\u300cPositionEmbeddingRandom\u300d<\/p>\n<p>Positionally encodes the coordinates of prompt points and boxes before prompt encoding. For the dense grid, the 64&#215;64 coordinates are normalized, multiplied by a random Gaussian matrix (2&#215;128), passed through sin and cos separately, and the two results are concatenated. The output has size 256x64x64, the same as image_embedding apart from the batch dimension.<\/p>\n<p><span class=\"token keyword\">class<\/span> <span class=\"token class-name\">PositionEmbeddingRandom<\/span><span class=\"token punctuation\">(<\/span>nn<span class=\"token punctuation\">.<\/span>Module<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n    <span class=\"token triple-quoted-string string\">&#034;&#034;&#034;<br \/>\n    Positional encoding using random spatial frequencies.<br \/>\n    &#034;&#034;&#034;<\/span><br \/>\n    <span class=\"token keyword\">def<\/span> <span class=\"token function\">__init__<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> num_pos_feats<span class=\"token punctuation\">:<\/span> <span class=\"token builtin\">int<\/span> <span class=\"token operator\">&#061;<\/span> <span class=\"token number\">64<\/span><span class=\"token punctuation\">,<\/span> scale<span class=\"token punctuation\">:<\/span> Optional<span class=\"token punctuation\">[<\/span><span class=\"token builtin\">float<\/span><span class=\"token punctuation\">]<\/span> <span class=\"token operator\">&#061;<\/span> <span class=\"token boolean\">None<\/span><span class=\"token punctuation\">)<\/span> <span class=\"token operator\">-<\/span><span class=\"token operator\">&gt;<\/span> <span class=\"token boolean\">None<\/span><span class=\"token 
punctuation\">:<\/span><br \/>\n        <span class=\"token builtin\">super<\/span><span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>__init__<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token keyword\">if<\/span> scale <span class=\"token keyword\">is<\/span> <span class=\"token boolean\">None<\/span> <span class=\"token keyword\">or<\/span> scale <span class=\"token operator\">&lt;&#061;<\/span> <span class=\"token number\">0.0<\/span><span class=\"token punctuation\">:<\/span><br \/>\n            scale <span class=\"token operator\">&#061;<\/span> <span class=\"token number\">1.0<\/span><br \/>\n        <span class=\"token comment\"># Random Gaussian matrix of shape 2 x num_pos_feats used for the positional encoding (2&#215;128 when num_pos_feats is 128)<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>register_buffer<span class=\"token punctuation\">(<\/span><br \/>\n            <span class=\"token string\">&#034;positional_encoding_gaussian_matrix&#034;<\/span><span class=\"token punctuation\">,<\/span><br \/>\n            scale <span class=\"token operator\">*<\/span> torch<span class=\"token punctuation\">.<\/span>randn<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">(<\/span><span class=\"token number\">2<\/span><span class=\"token punctuation\">,<\/span> num_pos_feats<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span><br \/>\n        <span class=\"token punctuation\">)<\/span><\/p>\n<p>    <span class=\"token keyword\">def<\/span> <span class=\"token function\">_pe_encoding<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> coords<span class=\"token punctuation\">:<\/span> torch<span class=\"token punctuation\">.<\/span>Tensor<span 
class=\"token punctuation\">)<\/span> <span class=\"token operator\">-<\/span><span class=\"token operator\">&gt;<\/span> torch<span class=\"token punctuation\">.<\/span>Tensor<span class=\"token punctuation\">:<\/span><br \/>\n        <span class=\"token triple-quoted-string string\">&#034;&#034;&#034;Positionally encode points that are normalized to [0,1].&#034;&#034;&#034;<\/span><br \/>\n        <span class=\"token comment\"># assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape<\/span><br \/>\n        coords <span class=\"token operator\">&#061;<\/span> <span class=\"token number\">2<\/span> <span class=\"token operator\">*<\/span> coords <span class=\"token operator\">-<\/span> <span class=\"token number\">1<\/span><\/p>\n<p>        <span class=\"token comment\"># Matrix product: 64x64x2 &#064; 2&#215;128 -&gt; 64x64x128<\/span><br \/>\n        coords <span class=\"token operator\">&#061;<\/span> coords &#064; self<span class=\"token punctuation\">.<\/span>positional_encoding_gaussian_matrix<br \/>\n        coords <span class=\"token operator\">&#061;<\/span> <span class=\"token number\">2<\/span> <span class=\"token operator\">*<\/span> np<span class=\"token punctuation\">.<\/span>pi <span class=\"token operator\">*<\/span> coords<\/p>\n<p>        <span class=\"token comment\"># outputs d_1 x ... x d_n x C shape<\/span><br \/>\n        <span class=\"token comment\"># Concatenate sin and cos along the last dimension: 64x64x256<\/span><br \/>\n        <span class=\"token keyword\">return<\/span> torch<span class=\"token punctuation\">.<\/span>cat<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">[<\/span>torch<span class=\"token punctuation\">.<\/span>sin<span class=\"token punctuation\">(<\/span>coords<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span> torch<span class=\"token punctuation\">.<\/span>cos<span 
class=\"token punctuation\">(<\/span>coords<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">,<\/span> dim<span class=\"token operator\">&#061;<\/span><span class=\"token operator\">-<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span><\/p>\n<p>    <span class=\"token keyword\">def<\/span> <span class=\"token function\">forward<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> size<span class=\"token punctuation\">:<\/span> Tuple<span class=\"token punctuation\">[<\/span><span class=\"token builtin\">int<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token builtin\">int<\/span><span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">)<\/span> <span class=\"token operator\">-<\/span><span class=\"token operator\">&gt;<\/span> torch<span class=\"token punctuation\">.<\/span>Tensor<span class=\"token punctuation\">:<\/span><br \/>\n        <span class=\"token triple-quoted-string string\">&#034;&#034;&#034;Generate positional encoding for a grid of the specified size.&#034;&#034;&#034;<\/span><br \/>\n        h<span class=\"token punctuation\">,<\/span> w <span class=\"token operator\">&#061;<\/span> size<br \/>\n        device<span class=\"token punctuation\">:<\/span> Any <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>positional_encoding_gaussian_matrix<span class=\"token punctuation\">.<\/span>device<\/p>\n<p>        <span class=\"token comment\"># Build an h x w matrix of ones (64&#215;64 here)<\/span><br \/>\n        grid <span class=\"token operator\">&#061;<\/span> torch<span class=\"token punctuation\">.<\/span>ones<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">(<\/span>h<span class=\"token punctuation\">,<\/span> w<span class=\"token punctuation\">)<\/span><span 
class=\"token punctuation\">,<\/span> device<span class=\"token operator\">&#061;<\/span>device<span class=\"token punctuation\">,<\/span> dtype<span class=\"token operator\">&#061;<\/span>torch<span class=\"token punctuation\">.<\/span>float32<span class=\"token punctuation\">)<\/span><\/p>\n<p>        <span class=\"token comment\"># Cumulative sums along rows and columns give pixel-center coordinates (0.5, 1.5, ...)<\/span><br \/>\n        y_embed <span class=\"token operator\">&#061;<\/span> grid<span class=\"token punctuation\">.<\/span>cumsum<span class=\"token punctuation\">(<\/span>dim<span class=\"token operator\">&#061;<\/span><span class=\"token number\">0<\/span><span class=\"token punctuation\">)<\/span> <span class=\"token operator\">-<\/span> <span class=\"token number\">0.5<\/span><br \/>\n        x_embed <span class=\"token operator\">&#061;<\/span> grid<span class=\"token punctuation\">.<\/span>cumsum<span class=\"token punctuation\">(<\/span>dim<span class=\"token operator\">&#061;<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span> <span class=\"token operator\">-<\/span> <span class=\"token number\">0.5<\/span><\/p>\n<p>        <span class=\"token comment\"># Normalize the cumulative sums to the range [0, 1]<\/span><br \/>\n        y_embed <span class=\"token operator\">&#061;<\/span> y_embed <span class=\"token operator\">\/<\/span> h<br \/>\n        x_embed <span class=\"token operator\">&#061;<\/span> x_embed <span class=\"token operator\">\/<\/span> w<\/p>\n<p>        <span class=\"token comment\"># Stack x and y into 64x64x2; the encoded result is 64x64x256<\/span><br \/>\n        pe <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>_pe_encoding<span class=\"token punctuation\">(<\/span>torch<span class=\"token punctuation\">.<\/span>stack<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">[<\/span>x_embed<span class=\"token 
punctuation\">,<\/span> y_embed<span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">,<\/span> dim<span class=\"token operator\">&#061;<\/span><span class=\"token operator\">-<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">)<\/span><\/p>\n<p>        <span class=\"token comment\"># Final output: 256x64x64<\/span><br \/>\n        <span class=\"token keyword\">return<\/span> pe<span class=\"token punctuation\">.<\/span>permute<span class=\"token punctuation\">(<\/span><span class=\"token number\">2<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">0<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span>  <span class=\"token comment\"># C x H x W<\/span><\/p>\n<h4>Mask Decoder<\/h4>\n<p>The Mask Decoder is configured with the following parameters.<\/p>\n<p><span class=\"token keyword\">def<\/span> <span class=\"token function\">__init__<\/span><span class=\"token punctuation\">(<\/span><br \/>\n    self<span class=\"token punctuation\">,<\/span><br \/>\n    <span class=\"token operator\">*<\/span><span class=\"token punctuation\">,<\/span><br \/>\n    <span class=\"token comment\"># Channel dimension of the transformer<\/span><br \/>\n    transformer_dim<span class=\"token punctuation\">:<\/span> <span class=\"token builtin\">int<\/span><span class=\"token punctuation\">,<\/span><br \/>\n    <span class=\"token comment\"># Transformer module used to predict masks<\/span><br \/>\n    transformer<span class=\"token punctuation\">:<\/span> nn<span class=\"token punctuation\">.<\/span>Module<span class=\"token punctuation\">,<\/span><br \/>\n    <span class=\"token comment\"># Number of masks predicted to disambiguate a prompt, default 3<\/span><br \/>\n    
num_multimask_outputs<span class=\"token punctuation\">:<\/span> <span class=\"token builtin\">int<\/span> <span class=\"token operator\">&#061;<\/span> <span class=\"token number\">3<\/span><span class=\"token punctuation\">,<\/span><br \/>\n    <span class=\"token comment\"># Activation function, default GELU<\/span><br \/>\n    activation<span class=\"token punctuation\">:<\/span> Type<span class=\"token punctuation\">[<\/span>nn<span class=\"token punctuation\">.<\/span>Module<span class=\"token punctuation\">]<\/span> <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>GELU<span class=\"token punctuation\">,<\/span><br \/>\n    <span class=\"token comment\"># Depth of the MLP that predicts mask quality<\/span><br \/>\n    iou_head_depth<span class=\"token punctuation\">:<\/span> <span class=\"token builtin\">int<\/span> <span class=\"token operator\">&#061;<\/span> <span class=\"token number\">3<\/span><span class=\"token punctuation\">,<\/span><br \/>\n    <span class=\"token comment\"># Hidden dimension of that MLP<\/span><br \/>\n    iou_head_hidden_dim<span class=\"token punctuation\">:<\/span> <span class=\"token builtin\">int<\/span> <span class=\"token operator\">&#061;<\/span> <span class=\"token number\">256<\/span><span class=\"token punctuation\">,<\/span><br \/>\n<span class=\"token punctuation\">)<\/span> <span class=\"token operator\">-<\/span><span class=\"token operator\">&gt;<\/span> <span class=\"token boolean\">None<\/span><span class=\"token punctuation\">:<\/span><br \/>\n    <span class=\"token builtin\">super<\/span><span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>__init__<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><br \/>\n    self<span class=\"token punctuation\">.<\/span>transformer_dim <span class=\"token 
operator\">&#061;<\/span> transformer_dim <span class=\"token comment\"># Store the transformer_dim that was passed in<\/span><br \/>\n    <span class=\"token comment\"># Store the transformer module that was passed in<\/span><br \/>\n    self<span class=\"token punctuation\">.<\/span>transformer <span class=\"token operator\">&#061;<\/span> transformer<br \/>\n    <span class=\"token comment\"># Store the number of multi-mask outputs<\/span><br \/>\n    self<span class=\"token punctuation\">.<\/span>num_multimask_outputs <span class=\"token operator\">&#061;<\/span> num_multimask_outputs<br \/>\n    <span class=\"token comment\"># Embedding layer for the IoU (Intersection over Union) token, of size 1 x transformer_dim<\/span><br \/>\n    <span class=\"token comment\"># Learnable IoU token: 1&#215;256<\/span><br \/>\n    self<span class=\"token punctuation\">.<\/span>iou_token <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>Embedding<span class=\"token punctuation\">(<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> transformer_dim<span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token comment\"># Total number of mask tokens: the multi-mask outputs plus one single-mask output (the IoU token is kept separately)<\/span><br \/>\n    <span class=\"token comment\"># num_mask_tokens &#061; 3 &#043; 1 &#061; 4, transformer_dim &#061; 256<\/span><br \/>\n    <span class=\"token comment\"># so mask_tokens is a 4&#215;256 matrix<\/span><br \/>\n    self<span class=\"token punctuation\">.<\/span>num_mask_tokens <span class=\"token operator\">&#061;<\/span> num_multimask_outputs <span class=\"token operator\">&#043;<\/span> <span class=\"token number\">1<\/span><br \/>\n    <span class=\"token comment\"># Embedding layer holding all mask 
tokens, of size num_mask_tokens x transformer_dim<\/span><br \/>\n    self<span class=\"token punctuation\">.<\/span>mask_tokens <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>Embedding<span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">.<\/span>num_mask_tokens<span class=\"token punctuation\">,<\/span> transformer_dim<span class=\"token punctuation\">)<\/span><\/p>\n<p>    <span class=\"token comment\"># ----- upscaled -----<\/span><br \/>\n    <span class=\"token comment\"># 4x upsampling block: two transposed convolutions, each upsampling by 2, with LayerNorm2d and activations in between<\/span><br \/>\n    self<span class=\"token punctuation\">.<\/span>output_upscaling <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>Sequential<span class=\"token punctuation\">(<\/span><br \/>\n        nn<span class=\"token punctuation\">.<\/span>ConvTranspose2d<span class=\"token punctuation\">(<\/span>transformer_dim<span class=\"token punctuation\">,<\/span> transformer_dim <span class=\"token operator\">\/\/<\/span> <span class=\"token number\">4<\/span><span class=\"token punctuation\">,<\/span> kernel_size<span class=\"token operator\">&#061;<\/span><span class=\"token number\">2<\/span><span class=\"token punctuation\">,<\/span> stride<span class=\"token operator\">&#061;<\/span><span class=\"token number\">2<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span>     <span class=\"token comment\"># Transposed convolution, upsamples by 2<\/span><br \/>\n        LayerNorm2d<span class=\"token punctuation\">(<\/span>transformer_dim <span class=\"token operator\">\/\/<\/span> <span class=\"token number\">4<\/span><span class=\"token 
punctuation\">)<\/span><span class=\"token punctuation\">,<\/span><br \/>\n        activation<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span><br \/>\n        nn<span class=\"token punctuation\">.<\/span>ConvTranspose2d<span class=\"token punctuation\">(<\/span>transformer_dim <span class=\"token operator\">\/\/<\/span> <span class=\"token number\">4<\/span><span class=\"token punctuation\">,<\/span> transformer_dim <span class=\"token operator\">\/\/<\/span> <span class=\"token number\">8<\/span><span class=\"token punctuation\">,<\/span> kernel_size<span class=\"token operator\">&#061;<\/span><span class=\"token number\">2<\/span><span class=\"token punctuation\">,<\/span> stride<span class=\"token operator\">&#061;<\/span><span class=\"token number\">2<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span><br \/>\n        activation<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span><br \/>\n    <span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token comment\"># ----- upscaled -----<\/span><\/p>\n<p>    <span class=\"token comment\"># Multi-layer perceptron (MLP) modules<\/span><br \/>\n    <span class=\"token comment\"># A module list with num_mask_tokens MLPs, one per mask token, each handling the output for a different mask<\/span><br \/>\n    self<span class=\"token punctuation\">.<\/span>output_hypernetworks_mlps <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>ModuleList<span class=\"token punctuation\">(<\/span><br \/>\n        <span class=\"token punctuation\">[<\/span><br \/>\n            MLP<span class=\"token punctuation\">(<\/span>transformer_dim<span class=\"token punctuation\">,<\/span> 
transformer_dim<span class=\"token punctuation\">,<\/span> transformer_dim <span class=\"token operator\">\/\/<\/span> <span class=\"token number\">8<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">3<\/span><span class=\"token punctuation\">)<\/span><br \/>\n            <span class=\"token keyword\">for<\/span> i <span class=\"token keyword\">in<\/span> <span class=\"token builtin\">range<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">.<\/span>num_mask_tokens<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token punctuation\">]<\/span><br \/>\n    <span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token comment\"># ----- MLP -----<\/span><\/p>\n<p>    <span class=\"token comment\"># ----- MLP -----<\/span><br \/>\n    <span class=\"token comment\"># A single MLP that predicts IoU: input size transformer_dim, hidden size iou_head_hidden_dim, output size num_mask_tokens<\/span><br \/>\n    self<span class=\"token punctuation\">.<\/span>iou_prediction_head <span class=\"token operator\">&#061;<\/span> MLP<span class=\"token punctuation\">(<\/span><br \/>\n        transformer_dim<span class=\"token punctuation\">,<\/span> iou_head_hidden_dim<span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">.<\/span>num_mask_tokens<span class=\"token punctuation\">,<\/span> iou_head_depth<br \/>\n    <span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token comment\"># ----- MLP -----<\/span><\/p>\n<p>The network structure of the SAM Mask Decoder is shown in the figure below.<\/p>\n<p><img decoding=\"async\" src=\"2025-04-282v1m4afrahp.png\" alt=\"Mask Decoder network structure\" width=\"700\" \/><\/p>\n<ul>\n<li>\n<p>spa_pro_emb(sparse 
embedding), iou_token and mask_token are concatenated into a single tokens tensor, which serves as point_embeddings.<\/p>\n<\/li>\n<li>\n<p>spa_pro_emb: the result of merging the point and bbox prompts, typically of shape NxXx256.<\/p>\n<\/li>\n<li>\n<p>iou_token: learnable parameter of size 1&#215;256.<\/p>\n<\/li>\n<li>\n<p>mask_token: learnable parameter of size 4&#215;256.<\/p>\n<\/li>\n<\/ul>\n<p>The diagram from the original paper showing each part of the Mask Decoder module is given below.<\/p>\n<p><img decoding=\"async\" src=\"2025-04-28irxhuvmd50i.png\" alt=\"Mask Decoder structure from the original paper\" width=\"700\" \/><\/p>\n<p>The basic steps the Mask Decoder follows during feature extraction are:<\/p>\n<li>\n<p>transformer: fuses the image features from the encoder with the prompt information (such as mask prompts or query tokens) to capture context about the target region.<\/p>\n<\/li>\n<li>\n<p>upscaled: upsamples the coarse mask features 
src by 4x, bringing them closer to the original image size so that finer masks can be predicted.<\/p>\n<\/li>\n<li>\n<p>mask_MLP: passes each mask token through a stack of fully connected layers to obtain per-channel weights, which are combined with the upsampled features to give a per-pixel score for each output mask.<\/p>\n<\/li>\n<li>\n<p>iou_MLP: estimates how well each generated mask overlaps the ground-truth mask, i.e. predicts mask quality.<\/p>\n<\/li>\n<p><span class=\"token keyword\">def<\/span> <span class=\"token function\">forward<\/span><span class=\"token punctuation\">(<\/span><br \/>\n    self<span class=\"token punctuation\">,<\/span><br \/>\n    <span class=\"token comment\"># Image features from the image encoder<\/span><br \/>\n    image_embeddings<span class=\"token punctuation\">:<\/span> torch<span class=\"token punctuation\">.<\/span>Tensor<span class=\"token punctuation\">,<\/span><br \/>\n    <span class=\"token comment\"># Positional encoding<\/span><br \/>\n    <span class=\"token comment\"># 256x64x64<\/span><br \/>\n    image_pe<span class=\"token punctuation\">:<\/span> torch<span class=\"token punctuation\">.<\/span>Tensor<span class=\"token punctuation\">,<\/span><br \/>\n    <span class=\"token comment\"># Embeddings of the prompt points and boxes<\/span><br \/>\n    sparse_prompt_embeddings<span class=\"token punctuation\">:<\/span> torch<span class=\"token punctuation\">.<\/span>Tensor<span class=\"token punctuation\">,<\/span><br \/>\n    <span class=\"token comment\"># 
Embedding of the input mask<\/span><br \/>\n    dense_prompt_embeddings<span class=\"token punctuation\">:<\/span> torch<span class=\"token punctuation\">.<\/span>Tensor<span class=\"token punctuation\">,<\/span><br \/>\n    <span class=\"token comment\"># Whether to return multiple masks<\/span><br \/>\n    multimask_output<span class=\"token punctuation\">:<\/span> <span class=\"token builtin\">bool<\/span><span class=\"token punctuation\">,<\/span><br \/>\n<span class=\"token punctuation\">)<\/span> <span class=\"token operator\">-<\/span><span class=\"token operator\">&gt;<\/span> Tuple<span class=\"token punctuation\">[<\/span>torch<span class=\"token punctuation\">.<\/span>Tensor<span class=\"token punctuation\">,<\/span> torch<span class=\"token punctuation\">.<\/span>Tensor<span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">:<\/span><br \/>\n    <span class=\"token comment\"># Fuse these features through the Transformer, then the upsampling and MLP layers, to produce mask predictions and IoU scores<\/span><br \/>\n    masks<span class=\"token punctuation\">,<\/span> iou_pred <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>predict_masks<span class=\"token punctuation\">(<\/span><br \/>\n        image_embeddings<span class=\"token operator\">&#061;<\/span>image_embeddings<span class=\"token punctuation\">,<\/span><br \/>\n        image_pe<span class=\"token operator\">&#061;<\/span>image_pe<span class=\"token punctuation\">,<\/span><br \/>\n        sparse_prompt_embeddings<span class=\"token operator\">&#061;<\/span>sparse_prompt_embeddings<span class=\"token punctuation\">,<\/span><br \/>\n        dense_prompt_embeddings<span class=\"token operator\">&#061;<\/span>dense_prompt_embeddings<span class=\"token punctuation\">,<\/span><br \/>\n    <span class=\"token 
punctuation\">)<\/span><br \/>\n    <span class=\"token comment\"># If multimask_output is True, multiple masks are requested: select the masks at indices 1 through num_multimask_outputs<\/span><br \/>\n    <span class=\"token keyword\">if<\/span> multimask_output<span class=\"token punctuation\">:<\/span><br \/>\n        mask_slice <span class=\"token operator\">&#061;<\/span> <span class=\"token builtin\">slice<\/span><span class=\"token punctuation\">(<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token boolean\">None<\/span><span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token comment\"># Otherwise (multimask_output is False) return only the first mask, the one produced by mask token 0, which serves as the single-mask output<\/span><br \/>\n    <span class=\"token keyword\">else<\/span><span class=\"token punctuation\">:<\/span><br \/>\n        mask_slice <span class=\"token operator\">&#061;<\/span> <span class=\"token builtin\">slice<\/span><span class=\"token punctuation\">(<\/span><span class=\"token number\">0<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token comment\"># Masks selected according to multimask_output, with shape (batch_size, num_selected_masks, height, width)<\/span><br \/>\n    masks <span class=\"token operator\">&#061;<\/span> masks<span class=\"token punctuation\">[<\/span><span class=\"token punctuation\">:<\/span><span class=\"token punctuation\">,<\/span> mask_slice<span class=\"token punctuation\">,<\/span> <span class=\"token punctuation\">:<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token punctuation\">:<\/span><span class=\"token 
punctuation\">]<\/span><br \/>\n    <span class=\"token comment\"># IoU predictions selected the same way, with shape (batch_size, num_selected_masks)<\/span><br \/>\n    iou_pred <span class=\"token operator\">&#061;<\/span> iou_pred<span class=\"token punctuation\">[<\/span><span class=\"token punctuation\">:<\/span><span class=\"token punctuation\">,<\/span> mask_slice<span class=\"token punctuation\">]<\/span><br \/>\n    <span class=\"token keyword\">return<\/span> masks<span class=\"token punctuation\">,<\/span> iou_pred<\/p>\n<p><span class=\"token keyword\">def<\/span> <span class=\"token function\">predict_masks<\/span><span class=\"token punctuation\">(<\/span><br \/>\n    self<span class=\"token punctuation\">,<\/span><br \/>\n    <span class=\"token comment\"># image embedding: output of the image encoder, of size 1x256x64x64<\/span><br \/>\n    image_embeddings<span class=\"token punctuation\">:<\/span> torch<span class=\"token punctuation\">.<\/span>Tensor<span class=\"token punctuation\">,<\/span><br \/>\n    <span class=\"token comment\"># The image_pe positional encoding is likewise expanded to an Nx256x64x64 tensor<\/span><br \/>\n    image_pe<span class=\"token punctuation\">:<\/span> torch<span class=\"token punctuation\">.<\/span>Tensor<span class=\"token punctuation\">,<\/span><br \/>\n    sparse_prompt_embeddings<span class=\"token punctuation\">:<\/span> torch<span class=\"token punctuation\">.<\/span>Tensor<span class=\"token punctuation\">,<\/span><br \/>\n    dense_prompt_embeddings<span class=\"token punctuation\">:<\/span> torch<span class=\"token punctuation\">.<\/span>Tensor<span class=\"token punctuation\">,<\/span><br \/>\n<span class=\"token punctuation\">)<\/span> <span class=\"token operator\">-<\/span><span class=\"token operator\">&gt;<\/span> Tuple<span class=\"token punctuation\">[<\/span>torch<span class=\"token 
punctuation\">.<\/span>Tensor<span class=\"token punctuation\">,<\/span> torch<span class=\"token punctuation\">.<\/span>Tensor<span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">:<\/span><br \/>\n    <span class=\"token comment\"># First concatenate the IoU token and the mask tokens into a 5x256 matrix<\/span><br \/>\n    <span class=\"token comment\"># 1,E and 4,E --&gt; 5,E<\/span><br \/>\n    output_tokens <span class=\"token operator\">&#061;<\/span> torch<span class=\"token punctuation\">.<\/span>cat<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">[<\/span>self<span class=\"token punctuation\">.<\/span>iou_token<span class=\"token punctuation\">.<\/span>weight<span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">.<\/span>mask_tokens<span class=\"token punctuation\">.<\/span>weight<span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">,<\/span> dim<span class=\"token operator\">&#061;<\/span><span class=\"token number\">0<\/span><span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token comment\"># Then expand that matrix along the batch dimension to match the sparse embeddings: Nx5x256<\/span><br \/>\n    <span class=\"token comment\"># 5,E --&gt; B,5,E<\/span><br \/>\n    output_tokens <span class=\"token operator\">&#061;<\/span> output_tokens<span class=\"token punctuation\">.<\/span>unsqueeze<span class=\"token punctuation\">(<\/span><span class=\"token number\">0<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>expand<span class=\"token punctuation\">(<\/span>sparse_prompt_embeddings<span class=\"token punctuation\">.<\/span>size<span class=\"token punctuation\">(<\/span><span class=\"token
number\">0<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token operator\">-<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token operator\">-<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token comment\"># Concatenate with the sparse prompt embeddings; if they hold only points with shape Nx2x256, the result is Nx(5&#043;2)x256<\/span><br \/>\n    <span class=\"token comment\"># B,5,E and B,N,E --&gt; B,5&#043;N,E       N is the number of prompt points (clicked points plus box corner points)<\/span><br \/>\n    tokens <span class=\"token operator\">&#061;<\/span> torch<span class=\"token punctuation\">.<\/span>cat<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">(<\/span>output_tokens<span class=\"token punctuation\">,<\/span> sparse_prompt_embeddings<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span> dim<span class=\"token operator\">&#061;<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span><\/p>\n<p>    <span class=\"token comment\"># Expand the image embedding (1x256x64x64) to the batch size of the dense prompt: Nx256x64x64<\/span><br \/>\n    <span class=\"token comment\"># B,C,H,W<\/span><br \/>\n    src <span class=\"token operator\">&#061;<\/span> torch<span class=\"token punctuation\">.<\/span>repeat_interleave<span class=\"token punctuation\">(<\/span>image_embeddings<span class=\"token punctuation\">,<\/span> tokens<span class=\"token punctuation\">.<\/span>shape<span class=\"token punctuation\">[<\/span><span class=\"token number\">0<\/span><span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">,<\/span> dim<span class=\"token
operator\">&#061;<\/span><span class=\"token number\">0<\/span><span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token comment\"># Add the dense prompt embeddings directly to the expanded image embedding: Nx256x64x64<\/span><br \/>\n    <span class=\"token comment\"># B,C,H,W &#043; 1,C,H,W ---&gt; B,C,H,W<\/span><br \/>\n    src <span class=\"token operator\">&#061;<\/span> src <span class=\"token operator\">&#043;<\/span> dense_prompt_embeddings<br \/>\n    <span class=\"token comment\"># Expand the 1x256x64x64 positional encoding to Nx256x64x64<\/span><br \/>\n    <span class=\"token comment\"># 1,C,H,W ---&gt; B,C,H,W<\/span><br \/>\n    pos_src <span class=\"token operator\">&#061;<\/span> torch<span class=\"token punctuation\">.<\/span>repeat_interleave<span class=\"token punctuation\">(<\/span>image_pe<span class=\"token punctuation\">,<\/span> tokens<span class=\"token punctuation\">.<\/span>shape<span class=\"token punctuation\">[<\/span><span class=\"token number\">0<\/span><span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">,<\/span> dim<span class=\"token operator\">&#061;<\/span><span class=\"token number\">0<\/span><span class=\"token punctuation\">)<\/span><br \/>\n    b<span class=\"token punctuation\">,<\/span> c<span class=\"token punctuation\">,<\/span> h<span class=\"token punctuation\">,<\/span> w <span class=\"token operator\">&#061;<\/span> src<span class=\"token punctuation\">.<\/span>shape<\/p>\n<p>    <span class=\"token comment\"># ----- transformer -----<\/span><br \/>\n    <span class=\"token comment\"># Run the transformer: a TwoWayTransformer is used here, so its inputs are worth restating<\/span><br \/>\n    <span class=\"token comment\"># src: image_embedding &#043; dense_prompt (mask), Nx256x64x64<\/span><br \/>\n    <span class=\"token
comment\"># pos_src: positional encoding, Nx256x64x64<\/span><br \/>\n    <span class=\"token comment\"># tokens: iou_token &#043; mask_tokens &#043; sparse_prompt (point\/bbox), Nx(5&#043;x)x256<\/span><br \/>\n    <span class=\"token comment\"># B,N,C<\/span><br \/>\n    hs<span class=\"token punctuation\">,<\/span> src <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>transformer<span class=\"token punctuation\">(<\/span>src<span class=\"token punctuation\">,<\/span> pos_src<span class=\"token punctuation\">,<\/span> tokens<span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token comment\"># ----- transformer -----<\/span><br \/>\n    <span class=\"token comment\"># Post-processing<\/span><br \/>\n    iou_token_out <span class=\"token operator\">&#061;<\/span> hs<span class=\"token punctuation\">[<\/span><span class=\"token punctuation\">:<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">0<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token punctuation\">:<\/span><span class=\"token punctuation\">]<\/span><br \/>\n    mask_tokens_out <span class=\"token operator\">&#061;<\/span> hs<span class=\"token punctuation\">[<\/span><span class=\"token punctuation\">:<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">1<\/span><span class=\"token punctuation\">:<\/span> <span class=\"token punctuation\">(<\/span><span class=\"token number\">1<\/span> <span class=\"token operator\">&#043;<\/span> self<span class=\"token punctuation\">.<\/span>num_mask_tokens<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token punctuation\">:<\/span><span class=\"token punctuation\">]<\/span><\/p>\n<p>    <span class=\"token comment\"># Reshape the image tokens returned by the transformer back to (batch_size,
channels, height, width) before upsampling<\/span><br \/>\n    <span class=\"token comment\"># B,N,C --&gt; B,C,H,W<\/span><br \/>\n    src <span class=\"token operator\">&#061;<\/span> src<span class=\"token punctuation\">.<\/span>transpose<span class=\"token punctuation\">(<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">2<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>view<span class=\"token punctuation\">(<\/span>b<span class=\"token punctuation\">,<\/span> c<span class=\"token punctuation\">,<\/span> h<span class=\"token punctuation\">,<\/span> w<span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token comment\"># ----- upscaled -----<\/span><br \/>\n    <span class=\"token comment\"># 4x upsampling<\/span><br \/>\n    upscaled_embedding <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>output_upscaling<span class=\"token punctuation\">(<\/span>src<span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token comment\"># ----- upscaled -----<\/span><\/p>\n<p>    <span class=\"token comment\"># For each mask token, its own MLP produces a weight vector; multiplying these weights with the upscaled feature map gives the mask predictions of shape (batch_size, num_mask_tokens, height, width)<\/span><br \/>\n    hyper_in_list<span class=\"token punctuation\">:<\/span> List<span class=\"token punctuation\">[<\/span>torch<span class=\"token punctuation\">.<\/span>Tensor<span class=\"token punctuation\">]<\/span> <span class=\"token operator\">&#061;<\/span> <span class=\"token punctuation\">[<\/span><span class=\"token punctuation\">]<\/span><\/p>\n<p>    <span class=\"token comment\">#
----- mlp -----<\/span><br \/>\n    <span class=\"token keyword\">for<\/span> i <span class=\"token keyword\">in<\/span> <span class=\"token builtin\">range<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">.<\/span>num_mask_tokens<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n        <span class=\"token comment\"># mask_tokens_out[:, i, :]: B,C (integer indexing drops the token dimension)<\/span><br \/>\n        <span class=\"token comment\"># output of output_hypernetworks_mlps[i]: B,c<\/span><br \/>\n        hyper_in_list<span class=\"token punctuation\">.<\/span>append<span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">.<\/span>output_hypernetworks_mlps<span class=\"token punctuation\">[<\/span>i<span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">(<\/span>mask_tokens_out<span class=\"token punctuation\">[<\/span><span class=\"token punctuation\">:<\/span><span class=\"token punctuation\">,<\/span> i<span class=\"token punctuation\">,<\/span> <span class=\"token punctuation\">:<\/span><span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token comment\"># B,n,c<\/span><br \/>\n    hyper_in <span class=\"token operator\">&#061;<\/span> torch<span class=\"token punctuation\">.<\/span>stack<span class=\"token punctuation\">(<\/span>hyper_in_list<span class=\"token punctuation\">,<\/span> dim<span class=\"token operator\">&#061;<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token comment\"># ----- mlp -----<\/span><\/p>\n<p>    b<span class=\"token punctuation\">,<\/span> c<span class=\"token punctuation\">,<\/span> h<span class=\"token punctuation\">,<\/span> w <span class=\"token operator\">&#061;<\/span> upscaled_embedding<span class=\"token punctuation\">.<\/span>shape<br
\/>\n    <span class=\"token comment\"># B,n,c \u00d7 B,c,h*w --&gt; B,n,h*w --&gt; B,n,h,w<\/span><br \/>\n    masks <span class=\"token operator\">&#061;<\/span> <span class=\"token punctuation\">(<\/span>hyper_in &#064; upscaled_embedding<span class=\"token punctuation\">.<\/span>view<span class=\"token punctuation\">(<\/span>b<span class=\"token punctuation\">,<\/span> c<span class=\"token punctuation\">,<\/span> h <span class=\"token operator\">*<\/span> w<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>view<span class=\"token punctuation\">(<\/span>b<span class=\"token punctuation\">,<\/span> <span class=\"token operator\">-<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> h<span class=\"token punctuation\">,<\/span> w<span class=\"token punctuation\">)<\/span><\/p>\n<p>    <span class=\"token comment\"># ----- mlp -----<\/span><br \/>\n    <span class=\"token comment\"># Pass the IoU token output through the IoU prediction head (an MLP) to get IoU scores of shape (batch_size, num_mask_tokens)<\/span><br \/>\n    <span class=\"token comment\"># iou_token_out: B,C<\/span><br \/>\n    iou_pred <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>iou_prediction_head<span class=\"token punctuation\">(<\/span>iou_token_out<span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token comment\"># ----- mlp -----<\/span><br \/>\n    <span class=\"token comment\"># Return the predicted masks and the IoU scores<\/span><br \/>\n    <span class=\"token comment\"># masks: B,n,h,w<\/span><br \/>\n    <span class=\"token comment\"># iou_pred: B,n<\/span><br \/>\n    <span class=\"token keyword\">return<\/span> masks<span class=\"token punctuation\">,<\/span> iou_pred<\/p>\n<h5>1.
transformer<\/h5>\n<p>The transformer of the Mask Decoder consists of several stacked TwoWayAttention Blocks followed by one Multi-Head Attention layer.<\/p>\n<p><img decoding=\"async\" src=\"2025-04-28zuezlu4n1jl.png\" alt=\"Image\" width=\"700\" \/><\/p>\n<p>\u300cTwoWayAttention Block\u300d<\/p>\n<p>A TwoWayAttention Block is built from LayerNorm, Multi-Head Attention, and MLP layers. &#034;TwoWay&#034; refers to attention running in both directions: the prompt tokens (queries) first go through self-attention and then attend to the image features, after which the image features attend back to the updated queries. Each block takes the queries produced by the previous block as its input.<\/p>\n<p><img decoding=\"async\" src=\"2025-04-28atbxfk1xtiq.png\" alt=\"Image\" width=\"700\" \/><\/p>\n<p>Diagram of the TwoWayAttention part from the original paper.<\/p>\n<p><img decoding=\"async\" src=\"2025-04-282c2osoq23md.png\" alt=\"Image\" width=\"700\" \/><\/p>\n<p><span class=\"token keyword\">class<\/span> <span class=\"token class-name\">TwoWayAttentionBlock<\/span><span class=\"token punctuation\">(<\/span>nn<span class=\"token punctuation\">.<\/span>Module<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n    <span class=\"token keyword\">def<\/span> <span class=\"token function\">__init__<\/span><span class=\"token punctuation\">(<\/span><br \/>\n        self<span class=\"token punctuation\">,<\/span><br \/>\n        embedding_dim<span class=\"token punctuation\">:<\/span> <span class=\"token builtin\">int<\/span><span class=\"token punctuation\">,<\/span>         <span class=\"token comment\"># input feature dimension<\/span><br \/>\n        num_heads<span class=\"token punctuation\">:<\/span> <span class=\"token builtin\">int<\/span><span class=\"token punctuation\">,<\/span>             <span class=\"token comment\">#
number of attention heads, which sets how many attention computations run in parallel<\/span><br \/>\n        mlp_dim<span class=\"token punctuation\">:<\/span> <span class=\"token builtin\">int<\/span> <span class=\"token operator\">&#061;<\/span> <span class=\"token number\">2048<\/span><span class=\"token punctuation\">,<\/span>        <span class=\"token comment\"># hidden dimension of the MLP (multi-layer perceptron), used for feature transformation and added non-linearity<\/span><br \/>\n        activation<span class=\"token punctuation\">:<\/span> Type<span class=\"token punctuation\">[<\/span>nn<span class=\"token punctuation\">.<\/span>Module<span class=\"token punctuation\">]<\/span> <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>ReLU<span class=\"token punctuation\">,<\/span>      <span class=\"token comment\"># activation function type, ReLU by default<\/span><br \/>\n        attention_downsample_rate<span class=\"token punctuation\">:<\/span> <span class=\"token builtin\">int<\/span> <span class=\"token operator\">&#061;<\/span> <span class=\"token number\">2<\/span><span class=\"token punctuation\">,<\/span>         <span class=\"token comment\"># downsampling rate of the internal dimension in the cross-attention layers<\/span><br \/>\n        <span class=\"token comment\"># Whether the self-attention of the first layer skips the positional encoding and the residual connection<\/span><br \/>\n        skip_first_layer_pe<span class=\"token punctuation\">:<\/span> <span class=\"token builtin\">bool<\/span> <span class=\"token operator\">&#061;<\/span> <span class=\"token boolean\">False<\/span><span class=\"token punctuation\">,<\/span><br \/>\n    <span class=\"token punctuation\">)<\/span> <span class=\"token operator\">-<\/span><span class=\"token operator\">&gt;<\/span> <span class=\"token
boolean\">None<\/span><span class=\"token punctuation\">:<\/span><br \/>\n        <span class=\"token builtin\">super<\/span><span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>__init__<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># Self-attention module: lets the queries exchange information with each other<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>self_attn <span class=\"token operator\">&#061;<\/span> Attention<span class=\"token punctuation\">(<\/span>embedding_dim<span class=\"token punctuation\">,<\/span> num_heads<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># norm1\/2\/3\/4: LayerNorm layers that stabilize training and speed up convergence<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>norm1 <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>LayerNorm<span class=\"token punctuation\">(<\/span>embedding_dim<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># cross_attn_token_to_image and cross_attn_image_to_token: cross-attention modules; the first lets the prompt-token features attend to the image features, the second lets the image features attend back to the prompt-token features<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>cross_attn_token_to_image <span class=\"token operator\">&#061;<\/span> Attention<span class=\"token punctuation\">(<\/span><br \/>\n            embedding_dim<span class=\"token punctuation\">,<\/span> num_heads<span class=\"token punctuation\">,<\/span> downsample_rate<span class=\"token operator\">&#061;<\/span>attention_downsample_rate<br
\/>\n        <span class=\"token punctuation\">)<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>norm2 <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>LayerNorm<span class=\"token punctuation\">(<\/span>embedding_dim<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># mlp: multi-layer perceptron block that adds expressive capacity to the model<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>mlp <span class=\"token operator\">&#061;<\/span> MLPBlock<span class=\"token punctuation\">(<\/span>embedding_dim<span class=\"token punctuation\">,<\/span> mlp_dim<span class=\"token punctuation\">,<\/span> activation<span class=\"token punctuation\">)<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>norm3 <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>LayerNorm<span class=\"token punctuation\">(<\/span>embedding_dim<span class=\"token punctuation\">)<\/span><\/p>\n<p>        self<span class=\"token punctuation\">.<\/span>norm4 <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>LayerNorm<span class=\"token punctuation\">(<\/span>embedding_dim<span class=\"token punctuation\">)<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>cross_attn_image_to_token <span class=\"token operator\">&#061;<\/span> Attention<span class=\"token punctuation\">(<\/span><br \/>\n            embedding_dim<span class=\"token punctuation\">,<\/span> num_heads<span class=\"token punctuation\">,<\/span> downsample_rate<span class=\"token operator\">&#061;<\/span>attention_downsample_rate<br \/>\n        <span class=\"token punctuation\">)<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>skip_first_layer_pe <span class=\"token operator\">&#061;<\/span> skip_first_layer_pe<br \/>\n    <span
class=\"token comment\"># Forward pass<\/span><br \/>\n    <span class=\"token keyword\">def<\/span> <span class=\"token function\">forward<\/span><span class=\"token punctuation\">(<\/span><br \/>\n        self<span class=\"token punctuation\">,<\/span> queries<span class=\"token punctuation\">:<\/span> Tensor<span class=\"token punctuation\">,<\/span> keys<span class=\"token punctuation\">:<\/span> Tensor<span class=\"token punctuation\">,<\/span> query_pe<span class=\"token punctuation\">:<\/span> Tensor<span class=\"token punctuation\">,<\/span> key_pe<span class=\"token punctuation\">:<\/span> Tensor<br \/>\n    <span class=\"token punctuation\">)<\/span> <span class=\"token operator\">-<\/span><span class=\"token operator\">&gt;<\/span> Tuple<span class=\"token punctuation\">[<\/span>Tensor<span class=\"token punctuation\">,<\/span> Tensor<span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">:<\/span><\/p>\n<p>        <span class=\"token comment\"># queries: prompt-token features (the original prompt-token embeddings after the preceding feature extraction)<\/span><br \/>\n        <span class=\"token comment\"># keys: image features (the original image embedding after the preceding feature extraction)<\/span><br \/>\n        <span class=\"token comment\"># query_pe: the original prompt-token embeddings<\/span><br \/>\n        <span class=\"token comment\"># key_pe: the original positional encoding of the image<\/span><br \/>\n        <span class=\"token comment\"># In the first layer queries &#061;&#061; query_pe, so adding query_pe again and applying the &#034;residual&#034; is unnecessary<\/span><\/p>\n<p>        <span class=\"token comment\">#
Self-attention on the queries comes first: if skip_first_layer_pe is True, run self-attention on the queries directly; otherwise add query_pe to the queries for q and k, run self-attention, and add the result back to the queries as a residual. LayerNorm follows in both cases<\/span><br \/>\n        <span class=\"token keyword\">if<\/span> self<span class=\"token punctuation\">.<\/span>skip_first_layer_pe<span class=\"token punctuation\">:<\/span><br \/>\n            queries <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>self_attn<span class=\"token punctuation\">(<\/span>q<span class=\"token operator\">&#061;<\/span>queries<span class=\"token punctuation\">,<\/span> k<span class=\"token operator\">&#061;<\/span>queries<span class=\"token punctuation\">,<\/span> v<span class=\"token operator\">&#061;<\/span>queries<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token keyword\">else<\/span><span class=\"token punctuation\">:<\/span><br \/>\n            q <span class=\"token operator\">&#061;<\/span> queries <span class=\"token operator\">&#043;<\/span> query_pe<br \/>\n            attn_out <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>self_attn<span class=\"token punctuation\">(<\/span>q<span class=\"token operator\">&#061;<\/span>q<span class=\"token punctuation\">,<\/span> k<span class=\"token operator\">&#061;<\/span>q<span class=\"token punctuation\">,<\/span> v<span class=\"token operator\">&#061;<\/span>queries<span class=\"token punctuation\">)<\/span><br \/>\n            queries <span class=\"token operator\">&#061;<\/span> queries <span class=\"token operator\">&#043;<\/span> attn_out<br \/>\n        queries <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>norm1<span
class=\"token punctuation\">(<\/span>queries<span class=\"token punctuation\">)<\/span><\/p>\n<p>        <span class=\"token comment\"># Cross-attention (tokens to image): add the respective positional encodings to queries and keys (the image features), then run cross_attn_token_to_image so the prompt-token features attend to the image features; the result is added to the queries as a residual and passed through LayerNorm<\/span><br \/>\n        q <span class=\"token operator\">&#061;<\/span> queries <span class=\"token operator\">&#043;<\/span> query_pe<br \/>\n        k <span class=\"token operator\">&#061;<\/span> keys <span class=\"token operator\">&#043;<\/span> key_pe<br \/>\n        attn_out <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>cross_attn_token_to_image<span class=\"token punctuation\">(<\/span>q<span class=\"token operator\">&#061;<\/span>q<span class=\"token punctuation\">,<\/span> k<span class=\"token operator\">&#061;<\/span>k<span class=\"token punctuation\">,<\/span> v<span class=\"token operator\">&#061;<\/span>keys<span class=\"token punctuation\">)<\/span><br \/>\n        queries <span class=\"token operator\">&#061;<\/span> queries <span class=\"token operator\">&#043;<\/span> attn_out<br \/>\n        queries <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>norm2<span class=\"token punctuation\">(<\/span>queries<span class=\"token punctuation\">)<\/span><\/p>\n<p>        <span class=\"token comment\"># MLP block: apply the MLP as a non-linear transformation to the updated queries, add the result to the queries as a residual, and apply LayerNorm<\/span><br \/>\n        mlp_out <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>mlp<span
class=\"token punctuation\">(<\/span>queries<span class=\"token punctuation\">)<\/span><br \/>\n        queries <span class=\"token operator\">&#061;<\/span> queries <span class=\"token operator\">&#043;<\/span> mlp_out<br \/>\n        queries <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>norm3<span class=\"token punctuation\">(<\/span>queries<span class=\"token punctuation\">)<\/span><\/p>\n<p>        <span class=\"token comment\"># Cross-attention (image to tokens): add the positional encodings to queries and keys again, but this time run cross_attn_image_to_token so the image features attend to the prompt-token features; the result is added to the keys as a residual and passed through LayerNorm<\/span><br \/>\n        q <span class=\"token operator\">&#061;<\/span> queries <span class=\"token operator\">&#043;<\/span> query_pe<br \/>\n        k <span class=\"token operator\">&#061;<\/span> keys <span class=\"token operator\">&#043;<\/span> key_pe<br \/>\n        attn_out <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>cross_attn_image_to_token<span class=\"token punctuation\">(<\/span>q<span class=\"token operator\">&#061;<\/span>k<span class=\"token punctuation\">,<\/span> k<span class=\"token operator\">&#061;<\/span>q<span class=\"token punctuation\">,<\/span> v<span class=\"token operator\">&#061;<\/span>queries<span class=\"token punctuation\">)<\/span><br \/>\n        keys <span class=\"token operator\">&#061;<\/span> keys <span class=\"token operator\">&#043;<\/span> attn_out<br \/>\n        keys <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>norm4<span class=\"token punctuation\">(<\/span>keys<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token
keyword\">return<\/span> queries<span class=\"token punctuation\">,<\/span> keys<\/p>\n<p>\u300cAttention\u300d<\/p>\n<p>The Attention used in the Mask Decoder differs slightly from the Attention in ViT:<\/p>\n<ul>\n<li>\n<p>The Attention in the Mask Decoder uses 3 separate FC layers that take 3 inputs and produce q, k, and v respectively.<\/p>\n<\/li>\n<li>\n<p>The Attention in ViT uses 1 FC layer that takes 1 input and splits the result evenly into q, k, and v.<\/p>\n<\/li>\n<\/ul>\n<p>This is shown in the figure below.<\/p>\n<p><img decoding=\"async\" src=\"2025-04-28q4u4zsyyoqd.png\" alt=\"Image\" width=\"700\" \/><\/p>\n<p>Diagram of the Attention part from the original paper.<\/p>\n<p><img decoding=\"async\" src=\"2025-04-281yaal3xvabo.png\" alt=\"Image\" width=\"600\" \/><\/p>\n<p><span class=\"token keyword\">class<\/span> <span class=\"token class-name\">Attention<\/span><span class=\"token punctuation\">(<\/span>nn<span class=\"token punctuation\">.<\/span>Module<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><\/p>\n<p>    <span class=\"token keyword\">def<\/span> <span class=\"token function\">__init__<\/span><span class=\"token punctuation\">(<\/span><br \/>\n        self<span class=\"token punctuation\">,<\/span><br \/>\n        embedding_dim<span class=\"token punctuation\">:<\/span> <span class=\"token builtin\">int<\/span><span class=\"token punctuation\">,<\/span>         <span class=\"token comment\"># input feature dimension<\/span><br \/>\n        num_heads<span class=\"token punctuation\">:<\/span> <span class=\"token builtin\">int<\/span><span class=\"token punctuation\">,<\/span>             <span class=\"token comment\"># number of attention heads<\/span><br \/>\n        downsample_rate<span class=\"token punctuation\">:<\/span> <span class=\"token builtin\">int<\/span> <span
class=\"token operator\">&#061;<\/span> <span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span>   <span class=\"token comment\"># downsampling rate of the internal dimension<\/span><br \/>\n    <span class=\"token punctuation\">)<\/span> <span class=\"token operator\">-<\/span><span class=\"token operator\">&gt;<\/span> <span class=\"token boolean\">None<\/span><span class=\"token punctuation\">:<\/span><br \/>\n        <span class=\"token builtin\">super<\/span><span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>__init__<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>embedding_dim <span class=\"token operator\">&#061;<\/span> embedding_dim<br \/>\n        <span class=\"token comment\"># Internal dimension used inside the attention<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>internal_dim <span class=\"token operator\">&#061;<\/span> embedding_dim <span class=\"token operator\">\/\/<\/span> downsample_rate<br \/>\n        self<span class=\"token punctuation\">.<\/span>num_heads <span class=\"token operator\">&#061;<\/span> num_heads<br \/>\n        <span class=\"token keyword\">assert<\/span> self<span class=\"token punctuation\">.<\/span>internal_dim <span class=\"token operator\">%<\/span> num_heads <span class=\"token operator\">&#061;&#061;<\/span> <span class=\"token number\">0<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token string\">&#034;num_heads must divide embedding_dim.&#034;<\/span><br \/>\n        <span class=\"token comment\"># Four linear (fully connected) layers: the first three produce the query, key, and value vectors; the fourth is the output projection<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>q_proj <span class=\"token operator\">&#061;<\/span> nn<span class=\"token
punctuation\">.<\/span>Linear<span class=\"token punctuation\">(<\/span>embedding_dim<span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">.<\/span>internal_dim<span class=\"token punctuation\">)<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>k_proj <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>Linear<span class=\"token punctuation\">(<\/span>embedding_dim<span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">.<\/span>internal_dim<span class=\"token punctuation\">)<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>v_proj <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>Linear<span class=\"token punctuation\">(<\/span>embedding_dim<span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">.<\/span>internal_dim<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># projects the attention output back to the original feature dimension<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>out_proj <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>Linear<span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">.<\/span>internal_dim<span class=\"token punctuation\">,<\/span> embedding_dim<span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token comment\"># reshape the input tensor into the layout needed for multi-head attention<\/span><br \/>\n    <span class=\"token keyword\">def<\/span> <span class=\"token function\">_separate_heads<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> x<span class=\"token punctuation\">:<\/span> Tensor<span class=\"token punctuation\">,<\/span> 
num_heads<span class=\"token punctuation\">:<\/span> <span class=\"token builtin\">int<\/span><span class=\"token punctuation\">)<\/span> <span class=\"token operator\">-<\/span><span class=\"token operator\">&gt;<\/span> Tensor<span class=\"token punctuation\">:<\/span><br \/>\n        b<span class=\"token punctuation\">,<\/span> n<span class=\"token punctuation\">,<\/span> c <span class=\"token operator\">&#061;<\/span> x<span class=\"token punctuation\">.<\/span>shape<br \/>\n        x <span class=\"token operator\">&#061;<\/span> x<span class=\"token punctuation\">.<\/span>reshape<span class=\"token punctuation\">(<\/span>b<span class=\"token punctuation\">,<\/span> n<span class=\"token punctuation\">,<\/span> num_heads<span class=\"token punctuation\">,<\/span> c <span class=\"token operator\">\/\/<\/span> num_heads<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token keyword\">return<\/span> x<span class=\"token punctuation\">.<\/span>transpose<span class=\"token punctuation\">(<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">2<\/span><span class=\"token punctuation\">)<\/span>  <span class=\"token comment\"># B x N_heads x N_tokens x C_per_head<\/span><br \/>\n    <span class=\"token comment\"># recombine the heads after the attention computation<\/span><br \/>\n    <span class=\"token keyword\">def<\/span> <span class=\"token function\">_recombine_heads<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> x<span class=\"token punctuation\">:<\/span> Tensor<span class=\"token punctuation\">)<\/span> <span class=\"token operator\">-<\/span><span class=\"token operator\">&gt;<\/span> Tensor<span class=\"token punctuation\">:<\/span><br \/>\n        b<span class=\"token punctuation\">,<\/span> n_heads<span class=\"token punctuation\">,<\/span> n_tokens<span class=\"token 
punctuation\">,<\/span> c_per_head <span class=\"token operator\">&#061;<\/span> x<span class=\"token punctuation\">.<\/span>shape<br \/>\n        x <span class=\"token operator\">&#061;<\/span> x<span class=\"token punctuation\">.<\/span>transpose<span class=\"token punctuation\">(<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">2<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token keyword\">return<\/span> x<span class=\"token punctuation\">.<\/span>reshape<span class=\"token punctuation\">(<\/span>b<span class=\"token punctuation\">,<\/span> n_tokens<span class=\"token punctuation\">,<\/span> n_heads <span class=\"token operator\">*<\/span> c_per_head<span class=\"token punctuation\">)<\/span>  <span class=\"token comment\"># B x N_tokens x C<\/span><\/p>\n<p>    <span class=\"token keyword\">def<\/span> <span class=\"token function\">forward<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> q<span class=\"token punctuation\">:<\/span> Tensor<span class=\"token punctuation\">,<\/span> k<span class=\"token punctuation\">:<\/span> Tensor<span class=\"token punctuation\">,<\/span> v<span class=\"token punctuation\">:<\/span> Tensor<span class=\"token punctuation\">)<\/span> <span class=\"token operator\">-<\/span><span class=\"token operator\">&gt;<\/span> Tensor<span class=\"token punctuation\">:<\/span><br \/>\n        <span class=\"token comment\"># Input projections: apply q_proj, k_proj, and v_proj to the query, key, and value<\/span><br \/>\n        q <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>q_proj<span class=\"token punctuation\">(<\/span>q<span class=\"token punctuation\">)<\/span><br \/>\n        k <span class=\"token operator\">&#061;<\/span> self<span class=\"token 
punctuation\">.<\/span>k_proj<span class=\"token punctuation\">(<\/span>k<span class=\"token punctuation\">)<\/span><br \/>\n        v <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>v_proj<span class=\"token punctuation\">(<\/span>v<span class=\"token punctuation\">)<\/span><\/p>\n<p>        <span class=\"token comment\"># Separate heads: reshape the projected query, key, and value tensors by num_heads for multi-head attention<\/span><br \/>\n        <span class=\"token comment\"># B,N_heads,N_tokens,C_per_head<\/span><br \/>\n        q <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>_separate_heads<span class=\"token punctuation\">(<\/span>q<span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">.<\/span>num_heads<span class=\"token punctuation\">)<\/span><br \/>\n        k <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>_separate_heads<span class=\"token punctuation\">(<\/span>k<span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">.<\/span>num_heads<span class=\"token punctuation\">)<\/span><br \/>\n        v <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>_separate_heads<span class=\"token punctuation\">(<\/span>v<span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">.<\/span>num_heads<span class=\"token punctuation\">)<\/span><\/p>\n<p>        <span class=\"token comment\"># Attention computation:<\/span><br \/>\n        <span class=\"token comment\"># take the dot product of query and key, then divide by the square root of c_per_head to keep the values from growing too large<\/span><br \/>\n        _<span 
class=\"token punctuation\">,<\/span> _<span class=\"token punctuation\">,<\/span> _<span class=\"token punctuation\">,<\/span> c_per_head <span class=\"token operator\">&#061;<\/span> q<span class=\"token punctuation\">.<\/span>shape<br \/>\n        attn <span class=\"token operator\">&#061;<\/span> q &#064; k<span class=\"token punctuation\">.<\/span>permute<span class=\"token punctuation\">(<\/span><span class=\"token number\">0<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">3<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">2<\/span><span class=\"token punctuation\">)<\/span>  <span class=\"token comment\"># B,N_heads,N_tokens,N_tokens<\/span><br \/>\n        <span class=\"token comment\"># scale<\/span><br \/>\n        attn <span class=\"token operator\">&#061;<\/span> attn <span class=\"token operator\">\/<\/span> math<span class=\"token punctuation\">.<\/span>sqrt<span class=\"token punctuation\">(<\/span>c_per_head<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># apply softmax to obtain the attention weights<\/span><br \/>\n        attn <span class=\"token operator\">&#061;<\/span> torch<span class=\"token punctuation\">.<\/span>softmax<span class=\"token punctuation\">(<\/span>attn<span class=\"token punctuation\">,<\/span> dim<span class=\"token operator\">&#061;<\/span><span class=\"token operator\">-<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># a weighted sum of the values using the attention weights gives the attention output<\/span><br \/>\n        out <span class=\"token operator\">&#061;<\/span> attn &#064; v<br \/>\n        <span class=\"token comment\"># 
B,N_tokens,C<\/span><br \/>\n        <span class=\"token comment\"># Recombine heads: merge the multi-head attention output back into a single feature dimension<\/span><br \/>\n        out <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>_recombine_heads<span class=\"token punctuation\">(<\/span>out<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># Output projection: finally, out_proj maps the output back to the original embedding_dim<\/span><br \/>\n        out <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>out_proj<span class=\"token punctuation\">(<\/span>out<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token keyword\">return<\/span> out<\/p>\n<p>\u300ctransformer_MLP\u300d<\/p>\n<p>The structure of the MLP in the transformer is shown in the figure below.<\/p>\n<p><img decoding=\"async\" src=\"2025-04-28aeabmpd3oad.png\" alt=\"\u56fe\u7247\" width=\"200\" \/><\/p>\n<p><span class=\"token comment\"># MLPBlock is a simple multilayer perceptron (MLP) module made of two fully connected (Linear) layers and one activation function<\/span><br \/>\n<span class=\"token keyword\">class<\/span> <span class=\"token class-name\">MLPBlock<\/span><span class=\"token punctuation\">(<\/span>nn<span class=\"token punctuation\">.<\/span>Module<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n    <span class=\"token keyword\">def<\/span> <span class=\"token function\">__init__<\/span><span class=\"token punctuation\">(<\/span><br \/>\n        self<span class=\"token punctuation\">,<\/span><br \/>\n        <span class=\"token comment\"># 
input dimension, usually the length of the feature vector<\/span><br \/>\n        embedding_dim<span class=\"token punctuation\">:<\/span> <span class=\"token builtin\">int<\/span><span class=\"token punctuation\">,<\/span><br \/>\n        <span class=\"token comment\"># width of the MLP hidden layer; it can be set larger than the input dimension to increase model capacity<\/span><br \/>\n        mlp_dim<span class=\"token punctuation\">:<\/span> <span class=\"token builtin\">int<\/span><span class=\"token punctuation\">,<\/span><br \/>\n        <span class=\"token comment\"># activation function, GELU by default<\/span><br \/>\n        act<span class=\"token punctuation\">:<\/span> Type<span class=\"token punctuation\">[<\/span>nn<span class=\"token punctuation\">.<\/span>Module<span class=\"token punctuation\">]<\/span> <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>GELU<span class=\"token punctuation\">,<\/span><br \/>\n    <span class=\"token punctuation\">)<\/span> <span class=\"token operator\">-<\/span><span class=\"token operator\">&gt;<\/span> <span class=\"token boolean\">None<\/span><span class=\"token punctuation\">:<\/span><br \/>\n        <span class=\"token builtin\">super<\/span><span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>__init__<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># first fully connected layer: maps the input from embedding_dim to mlp_dim<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>lin1 <span class=\"token operator\">&#061;<\/span> nn<span class=\"token 
punctuation\">.<\/span>Linear<span class=\"token punctuation\">(<\/span>embedding_dim<span class=\"token punctuation\">,<\/span> mlp_dim<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># second fully connected layer: maps the mlp_dim result back to embedding_dim so the output matches the input dimension<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>lin2 <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>Linear<span class=\"token punctuation\">(<\/span>mlp_dim<span class=\"token punctuation\">,<\/span> embedding_dim<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># activation instance that adds non-linearity between the two fully connected layers<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>act <span class=\"token operator\">&#061;<\/span> act<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token comment\"># Pass the input tensor x through lin1, then apply the activation act.<\/span><br \/>\n    <span class=\"token comment\"># Pass the activation output through lin2 to get the final output tensor<\/span><br \/>\n    <span class=\"token keyword\">def<\/span> <span class=\"token function\">forward<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> x<span class=\"token punctuation\">:<\/span> torch<span class=\"token punctuation\">.<\/span>Tensor<span class=\"token punctuation\">)<\/span> <span class=\"token operator\">-<\/span><span class=\"token operator\">&gt;<\/span> torch<span class=\"token 
punctuation\">.<\/span>Tensor<span class=\"token punctuation\">:<\/span><br \/>\n        <span class=\"token keyword\">return<\/span> self<span class=\"token punctuation\">.<\/span>lin2<span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">.<\/span>act<span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">.<\/span>lin1<span class=\"token punctuation\">(<\/span>x<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">)<\/span><\/p>\n<p>\u300cupscaled\u300d<\/p>\n<p>This upsampling step restores the Transformer output feature map to a resolution closer to that of the input image so that mask predictions can be generated. The structure of upscaled is shown in the figure below.<\/p>\n<p><img decoding=\"async\" src=\"2025-04-28bqxesapdh23.png\" alt=\"\u56fe\u7247\" width=\"200\" \/><\/p>\n<p><span class=\"token comment\"># Defined in MaskDecoder.__init__<\/span><br \/>\n<span class=\"token comment\"># output_upscaling is a sequential module that upsamples the feature map output by the Transformer<\/span><br \/>\nself<span class=\"token punctuation\">.<\/span>output_upscaling <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>Sequential<span class=\"token punctuation\">(<\/span><br \/>\n    <span class=\"token comment\"># nn.ConvTranspose2d with transformer_dim input channels, transformer_dim \/\/ 4 output channels, kernel size 2, and stride 2<\/span><br \/>\n    <span class=\"token comment\"># doubles the spatial size of the feature map and reduces the channel count to a quarter<\/span><br \/>\n    <span class=\"token comment\"># 
a transposed convolution with kernel size 2 and stride 2 upsamples by 2x, so the output size doubles<\/span><br \/>\n    nn<span class=\"token punctuation\">.<\/span>ConvTranspose2d<span class=\"token punctuation\">(<\/span>transformer_dim<span class=\"token punctuation\">,<\/span> transformer_dim <span class=\"token operator\">\/\/<\/span> <span class=\"token number\">4<\/span><span class=\"token punctuation\">,<\/span> kernel_size<span class=\"token operator\">&#061;<\/span><span class=\"token number\">2<\/span><span class=\"token punctuation\">,<\/span> stride<span class=\"token operator\">&#061;<\/span><span class=\"token number\">2<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span>     <span class=\"token comment\"># transposed convolution, 2x upsampling<\/span><br \/>\n    <span class=\"token comment\"># layer normalization (LayerNorm2d)<\/span><br \/>\n    LayerNorm2d<span class=\"token punctuation\">(<\/span>transformer_dim <span class=\"token operator\">\/\/<\/span> <span class=\"token number\">4<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span><br \/>\n    <span class=\"token comment\"># activation function<\/span><br \/>\n    activation<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span><br \/>\n    <span class=\"token comment\"># a second nn.ConvTranspose2d with transformer_dim \/\/ 4 input channels, transformer_dim \/\/ 8 output channels, kernel size 2, and stride 2; it doubles the feature map size again and halves the channel count<\/span><br \/>\n    nn<span class=\"token 
punctuation\">.<\/span>ConvTranspose2d<span class=\"token punctuation\">(<\/span>transformer_dim <span class=\"token operator\">\/\/<\/span> <span class=\"token number\">4<\/span><span class=\"token punctuation\">,<\/span> transformer_dim <span class=\"token operator\">\/\/<\/span> <span class=\"token number\">8<\/span><span class=\"token punctuation\">,<\/span> kernel_size<span class=\"token operator\">&#061;<\/span><span class=\"token number\">2<\/span><span class=\"token punctuation\">,<\/span> stride<span class=\"token operator\">&#061;<\/span><span class=\"token number\">2<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span><br \/>\n    <span class=\"token comment\"># apply the activation again to add further non-linearity<\/span><br \/>\n    activation<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span><br \/>\n<span class=\"token punctuation\">)<\/span><br \/>\n<span class=\"token comment\"># Used in MaskDecoder.predict_masks<\/span><br \/>\nupscaled_embedding <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>output_upscaling<span class=\"token punctuation\">(<\/span>src<span class=\"token punctuation\">)<\/span><\/p>\n<p>\u300cmask_MLP\u300d<\/p>\n<p>The MLP building block used here is different from the MLP building block of ViT (transformer_MLP).<\/p>\n<p><span class=\"token comment\"># Defined in MaskDecoder.__init__<\/span><br \/>\n<span class=\"token comment\"># 
output_hypernetworks_mlps is an nn.ModuleList holding several multilayer perceptrons (MLPs). Each MLP turns its mask_tokens_out input into the hypernetwork weights for one specific mask<\/span><br \/>\nself<span class=\"token punctuation\">.<\/span>output_hypernetworks_mlps <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>ModuleList<span class=\"token punctuation\">(<\/span><br \/>\n    <span class=\"token punctuation\">[<\/span><br \/>\n        <span class=\"token comment\"># transformer_dim: the Transformer output dimension, used here as both the MLP input dimension and its hidden dimension<\/span><br \/>\n        <span class=\"token comment\"># transformer_dim \/\/ 8: MLP output dimension, which produces the hypernetwork weights<\/span><br \/>\n        <span class=\"token comment\"># 3: number of layers in the MLP (num_layers)<\/span><br \/>\n        MLP<span class=\"token punctuation\">(<\/span>transformer_dim<span class=\"token punctuation\">,<\/span> transformer_dim<span class=\"token punctuation\">,<\/span> transformer_dim <span class=\"token operator\">\/\/<\/span> <span class=\"token number\">8<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">3<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token keyword\">for<\/span> i <span class=\"token keyword\">in<\/span> <span class=\"token builtin\">range<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">.<\/span>num_mask_tokens<span class=\"token punctuation\">)<\/span><br \/>\n    <span class=\"token punctuation\">]<\/span><br \/>\n<span class=\"token punctuation\">)<\/span><br \/>\n<span class=\"token comment\"># 
Used in MaskDecoder.predict_masks<\/span><br \/>\n<span class=\"token comment\"># iterate over output_hypernetworks_mlps, one MLP per mask token (self.num_mask_tokens in total)<\/span><br \/>\n<span class=\"token keyword\">for<\/span> i <span class=\"token keyword\">in<\/span> <span class=\"token builtin\">range<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">.<\/span>num_mask_tokens<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n    <span class=\"token comment\"># mask_tokens_out[:, i, :]: B,C<\/span><br \/>\n    <span class=\"token comment\"># output_hypernetworks_mlps[i](...): B,c<\/span><br \/>\n    <span class=\"token comment\"># apply the matching MLP to each mask token: the input is the feature at that position in mask_tokens_out and the output is a tensor of shape B, c, where c is the number of hypernetwork output channels<\/span><br \/>\n    <span class=\"token comment\"># collect each MLP output in hyper_in_list<\/span><br \/>\n    hyper_in_list<span class=\"token punctuation\">.<\/span>append<span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">.<\/span>output_hypernetworks_mlps<span class=\"token punctuation\">[<\/span>i<span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">(<\/span>mask_tokens_out<span class=\"token punctuation\">[<\/span><span class=\"token punctuation\">:<\/span><span class=\"token punctuation\">,<\/span> i<span class=\"token punctuation\">,<\/span> <span class=\"token punctuation\">:<\/span><span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">)<\/span><br \/>\n<span class=\"token comment\"># B,n,c<\/span><br \/>\n<span 
class=\"token comment\"># stack hyper_in_list into a tensor hyper_in of shape B, n, c, where n is the number of mask tokens<\/span><br \/>\nhyper_in <span class=\"token operator\">&#061;<\/span> torch<span class=\"token punctuation\">.<\/span>stack<span class=\"token punctuation\">(<\/span>hyper_in_list<span class=\"token punctuation\">,<\/span> dim<span class=\"token operator\">&#061;<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span><br \/>\n<span class=\"token comment\"># read the shape b, c, h, w of upscaled_embedding: b is the batch size, c the channel count, h and w the height and width<\/span><br \/>\nb<span class=\"token punctuation\">,<\/span> c<span class=\"token punctuation\">,<\/span> h<span class=\"token punctuation\">,<\/span> w <span class=\"token operator\">&#061;<\/span> upscaled_embedding<span class=\"token punctuation\">.<\/span>shape<br \/>\n<span class=\"token comment\"># B,n,c \u00d7 B,c,N --&gt; B,n,h,w (N &#061; h * w)<\/span><br \/>\n<span class=\"token comment\"># matrix-multiply (&#064; operator) hyper_in (B, n, c) with upscaled_embedding, whose spatial dimensions are flattened to give B, c, h * w<\/span><br \/>\n<span class=\"token comment\"># this takes the dot product of the hypernetwork weights of each mask token with the upsampled feature map, giving a tensor of shape B, n, h * w<\/span><br \/>\n<span class=\"token comment\"># view reshapes the result back to B, n, h, w, producing the masks tensor that holds the predicted mask for each mask token<\/span><br \/>\nmasks <span class=\"token operator\">&#061;<\/span> <span 
class=\"token punctuation\">(<\/span>hyper_in &#064; upscaled_embedding<span class=\"token punctuation\">.<\/span>view<span class=\"token punctuation\">(<\/span>b<span class=\"token punctuation\">,<\/span> c<span class=\"token punctuation\">,<\/span> h <span class=\"token operator\">*<\/span> w<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>view<span class=\"token punctuation\">(<\/span>b<span class=\"token punctuation\">,<\/span> <span class=\"token operator\">-<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> h<span class=\"token punctuation\">,<\/span> w<span class=\"token punctuation\">)<\/span><\/p>\n<p>\u300ciou_MLP\u300d<\/p>\n<p>The MLP building block used here is different from the MLP building block of ViT (transformer_MLP).<\/p>\n<p><span class=\"token comment\"># Defined in MaskDecoder.__init__<\/span><br \/>\n<span class=\"token comment\"># an MLP module that predicts the IoU (Intersection over Union) value for each mask token, as an estimate of how well the predicted mask overlaps the ground-truth mask<\/span><br \/>\nself<span class=\"token punctuation\">.<\/span>iou_prediction_head <span class=\"token operator\">&#061;<\/span> MLP<span class=\"token punctuation\">(<\/span><br \/>\n    <span class=\"token comment\"># transformer_dim: feature dimension fed into the MLP, usually equal to the Transformer output dimension<\/span><br \/>\n    <span class=\"token comment\"># iou_head_hidden_dim: dimension of the MLP hidden layers<\/span><br \/>\n    <span class=\"token comment\"># 
self.num_mask_tokens: output dimension, i.e. the number of mask tokens, with one predicted IoU value per token<\/span><br \/>\n    transformer_dim<span class=\"token punctuation\">,<\/span> iou_head_hidden_dim<span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">.<\/span>num_mask_tokens<span class=\"token punctuation\">,<\/span> iou_head_depth<br \/>\n<span class=\"token punctuation\">)<\/span><br \/>\n<span class=\"token comment\"># Used in MaskDecoder.predict_masks<\/span><br \/>\niou_pred <span class=\"token operator\">&#061;<\/span> self<span class=\"token punctuation\">.<\/span>iou_prediction_head<span class=\"token punctuation\">(<\/span>iou_token_out<span class=\"token punctuation\">)<\/span><\/p>\n<p>\u300cMaskDeco_MLP\u300d<\/p>\n<p>The structure of the MLP in the Mask Decoder is shown in the figure below.<\/p>\n<p><img decoding=\"async\" src=\"2025-04-282ymsdopmzpm.png\" alt=\"\u56fe\u7247\" width=\"200\" \/><\/p>\n<p><span class=\"token triple-quoted-string string\">&#039;&#039;&#039;<br \/>\nDefines a multilayer perceptron with a configurable number of layers and configurable input and output dimensions, and an option to apply a Sigmoid activation at the output layer<br \/>\n&#039;&#039;&#039;<\/span><br \/>\n<span class=\"token keyword\">class<\/span> <span class=\"token class-name\">MLP<\/span><span class=\"token punctuation\">(<\/span>nn<span class=\"token punctuation\">.<\/span>Module<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n    <span class=\"token keyword\">def<\/span> <span class=\"token function\">__init__<\/span><span class=\"token punctuation\">(<\/span><br \/>\n        self<span class=\"token 
punctuation\">,<\/span><br \/>\n        input_dim<span class=\"token punctuation\">:<\/span> <span class=\"token builtin\">int<\/span><span class=\"token punctuation\">,<\/span>         <span class=\"token comment\"># dimension of the input features, i.e. the channel count of the input tensor<\/span><br \/>\n        hidden_dim<span class=\"token punctuation\">:<\/span> <span class=\"token builtin\">int<\/span><span class=\"token punctuation\">,<\/span>        <span class=\"token comment\"># channel count of the hidden layers (their width)<\/span><br \/>\n        output_dim<span class=\"token punctuation\">:<\/span> <span class=\"token builtin\">int<\/span><span class=\"token punctuation\">,<\/span>        <span class=\"token comment\"># dimension of the output features, i.e. the channel count of the output tensor<\/span><br \/>\n        num_layers<span class=\"token punctuation\">:<\/span> <span class=\"token builtin\">int<\/span><span class=\"token punctuation\">,<\/span>        <span class=\"token comment\"># number of linear layers in the MLP, counting the first and the last<\/span><br \/>\n        sigmoid_output<span class=\"token punctuation\">:<\/span> <span class=\"token builtin\">bool<\/span> <span class=\"token operator\">&#061;<\/span> <span class=\"token boolean\">False<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token comment\">#  boolean: whether to apply a Sigmoid activation at the output layer, False by default<\/span><br \/>\n    <span class=\"token punctuation\">)<\/span> <span class=\"token operator\">-<\/span><span class=\"token operator\">&gt;<\/span> <span class=\"token boolean\">None<\/span><span class=\"token punctuation\">:<\/span><br \/>\n        <span class=\"token triple-quoted-string 
string\">&#039;&#039;&#039;<br \/>\n        Internal components<br \/>\n        &#039;&#039;&#039;<\/span><br \/>\n        <span class=\"token builtin\">super<\/span><span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span>__init__<span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># Store the number of layers passed in<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>num_layers <span class=\"token operator\">&#061;<\/span> num_layers<br \/>\n        <span class=\"token comment\"># A list holding num_layers - 1 copies of hidden_dim, used to build the linear transforms of the intermediate layers<\/span><br \/>\n        h <span class=\"token operator\">&#061;<\/span> <span class=\"token punctuation\">[<\/span>hidden_dim<span class=\"token punctuation\">]<\/span> <span class=\"token operator\">*<\/span> <span class=\"token punctuation\">(<\/span>num_layers <span class=\"token operator\">-<\/span> <span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># An nn.ModuleList of num_layers linear (fully connected) layers; the input and output sizes of each layer are determined by h, input_dim and output_dim<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>layers <span class=\"token operator\">&#061;<\/span> nn<span class=\"token punctuation\">.<\/span>ModuleList<span class=\"token punctuation\">(<\/span><br \/>\n            nn<span class=\"token punctuation\">.<\/span>Linear<span class=\"token punctuation\">(<\/span>n<span class=\"token punctuation\">,<\/span> k<span class=\"token punctuation\">)<\/span> <span class=\"token keyword\">for<\/span> 
n<span class=\"token punctuation\">,<\/span> k <span class=\"token keyword\">in<\/span> <span class=\"token builtin\">zip<\/span><span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">[<\/span>input_dim<span class=\"token punctuation\">]<\/span> <span class=\"token operator\">&#043;<\/span> h<span class=\"token punctuation\">,<\/span> h <span class=\"token operator\">&#043;<\/span> <span class=\"token punctuation\">[<\/span>output_dim<span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token punctuation\">)<\/span><br \/>\n        self<span class=\"token punctuation\">.<\/span>sigmoid_output <span class=\"token operator\">&#061;<\/span> sigmoid_output<\/p>\n<p>    <span class=\"token keyword\">def<\/span> <span class=\"token function\">forward<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> x<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n        <span class=\"token comment\"># Pass the input tensor x through each linear layer in self.layers in turn<\/span><br \/>\n        <span class=\"token keyword\">for<\/span> i<span class=\"token punctuation\">,<\/span> layer <span class=\"token keyword\">in<\/span> <span class=\"token builtin\">enumerate<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">.<\/span>layers<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">:<\/span><br \/>\n            <span class=\"token comment\"># If the current layer is not the last one, apply the ReLU activation (F.relu)<\/span><br \/>\n            x <span class=\"token operator\">&#061;<\/span> F<span class=\"token punctuation\">.<\/span>relu<span class=\"token punctuation\">(<\/span>layer<span class=\"token punctuation\">(<\/span>x<span class=\"token 
punctuation\">)<\/span><span class=\"token punctuation\">)<\/span> <span class=\"token keyword\">if<\/span> i <span class=\"token operator\">&lt;<\/span> self<span class=\"token punctuation\">.<\/span>num_layers <span class=\"token operator\">-<\/span> <span class=\"token number\">1<\/span> <span class=\"token keyword\">else<\/span> layer<span class=\"token punctuation\">(<\/span>x<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token comment\"># If sigmoid_output is True, apply the Sigmoid activation to the final output<\/span><br \/>\n        <span class=\"token keyword\">if<\/span> self<span class=\"token punctuation\">.<\/span>sigmoid_output<span class=\"token punctuation\">:<\/span><br \/>\n            x <span class=\"token operator\">&#061;<\/span> F<span class=\"token punctuation\">.<\/span>sigmoid<span class=\"token punctuation\">(<\/span>x<span class=\"token punctuation\">)<\/span><br \/>\n        <span class=\"token keyword\">return<\/span> x<\/p>\n","protected":false},"excerpt":{"rendered":"<p>\u6587\u7ae0\u6d4f\u89c8\u9605\u8bfb9.3k\u6b21\uff0c\u70b9\u8d5e71\u6b21\uff0c\u6536\u85cf299\u6b21\u3002\u8282\u524d\uff0c\u6211\u4eec\u661f\u7403\u7ec4\u7ec7\u4e86\u4e00\u573a\u7b97\u6cd5\u5c97\u6280\u672f&amp;\u9762\u8bd5\u8ba8\u8bba\u4f1a\uff0c\u9080\u8bf7\u4e86\u4e00\u4e9b\u4e92\u8054\u7f51\u5927\u5382\u670b\u53cb\u3001\u53c2\u52a0\u793e\u62db\u548c\u6821\u62db\u9762\u8bd5\u7684\u540c\u5b66\u3002\u9488\u5bf9\u7b97\u6cd5\u5c97\u6280\u672f\u8d8b\u52bf\u3001\u5927\u6a21\u578b\u843d\u5730\u9879\u76ee\u7ecf\u9a8c\u5206\u4eab\u3001\u65b0\u624b\u5982\u4f55\u5165\u95e8\u7b97\u6cd5\u5c97\u3001\u8be5\u5982\u4f55\u51c6\u5907\u3001\u9762\u8bd5\u5e38\u8003\u70b9\u5206\u4eab\u7b49\u70ed\u95e8\u8bdd\u9898\u8fdb\u884c\u4e86\u6df1\u5165\u7684\u8ba8\u8bba\u3002SAM(Segment Anything Model)\uff0c\u987e\u540d\u601d\u4e49\uff0c\u5373\u4e3a\u5206\u5272\u4e00\u5207\uff01\u8be5\u6a21\u578b\u7531Facebook\u7684Meta 
AI\u5b9e\u9a8c\u5ba4\uff0c\u80fd\u591f\u6839\u636e\u6587\u672c\u6307\u4ee4\u6216\u56fe\u50cf\u8bc6\u522b\uff0c\u5b9e\u73b0\u5bf9\u4efb\u610f\u7269\u4f53\u7684\u8bc6\u522b\u4e0e\u5206\u5272\u3002\u5b83\u7684\u8bde\u751f\uff0c\u65e0\u7591\u662fCV\u9886\u57df\u7684\u4e00\u6b21\u91cd\u8981\u91cc\u7a0b\u7891\u3002_sam\u6a21\u578b<\/p>\n","protected":false},"author":2,"featured_media":33788,"comment_status":"open","ping_status":"open","sticky":false,"template":"","format":"standard","meta":{"footnotes":""},"categories":[1],"tags":[50,2831,132,86,427,2830],"topic":[],"class_list":["post-33808","post","type-post","status-publish","format-standard","has-post-thumbnail","hentry","category-server","tag-50","tag-2831","tag-132","tag-86","tag-427","tag-2830"],"yoast_head":"<!-- This site is optimized with the Yoast SEO plugin v20.3 - https:\/\/yoast.com\/wordpress\/plugins\/seo\/ -->\n<title>\u6e90\u7801\u89e3\u6790\uff1a\u4ece\u96f6\u89e3\u8bfbSAM(Segment Anything Model)\u5927\u6a21\u578b\uff01 - \u7f51\u7855\u4e92\u8054\u5e2e\u52a9\u4e2d\u5fc3<\/title>\n<meta name=\"robots\" content=\"index, follow, max-snippet:-1, max-image-preview:large, max-video-preview:-1\" \/>\n<link rel=\"canonical\" href=\"https:\/\/www.wsisp.com\/helps\/33808.html\" \/>\n<meta property=\"og:locale\" content=\"zh_CN\" \/>\n<meta property=\"og:type\" content=\"article\" \/>\n<meta property=\"og:title\" content=\"\u6e90\u7801\u89e3\u6790\uff1a\u4ece\u96f6\u89e3\u8bfbSAM(Segment Anything Model)\u5927\u6a21\u578b\uff01 - \u7f51\u7855\u4e92\u8054\u5e2e\u52a9\u4e2d\u5fc3\" \/>\n<meta property=\"og:description\" 
content=\"\u6587\u7ae0\u6d4f\u89c8\u9605\u8bfb9.3k\u6b21\uff0c\u70b9\u8d5e71\u6b21\uff0c\u6536\u85cf299\u6b21\u3002\u8282\u524d\uff0c\u6211\u4eec\u661f\u7403\u7ec4\u7ec7\u4e86\u4e00\u573a\u7b97\u6cd5\u5c97\u6280\u672f&amp;\u9762\u8bd5\u8ba8\u8bba\u4f1a\uff0c\u9080\u8bf7\u4e86\u4e00\u4e9b\u4e92\u8054\u7f51\u5927\u5382\u670b\u53cb\u3001\u53c2\u52a0\u793e\u62db\u548c\u6821\u62db\u9762\u8bd5\u7684\u540c\u5b66\u3002\u9488\u5bf9\u7b97\u6cd5\u5c97\u6280\u672f\u8d8b\u52bf\u3001\u5927\u6a21\u578b\u843d\u5730\u9879\u76ee\u7ecf\u9a8c\u5206\u4eab\u3001\u65b0\u624b\u5982\u4f55\u5165\u95e8\u7b97\u6cd5\u5c97\u3001\u8be5\u5982\u4f55\u51c6\u5907\u3001\u9762\u8bd5\u5e38\u8003\u70b9\u5206\u4eab\u7b49\u70ed\u95e8\u8bdd\u9898\u8fdb\u884c\u4e86\u6df1\u5165\u7684\u8ba8\u8bba\u3002SAM(Segment Anything Model)\uff0c\u987e\u540d\u601d\u4e49\uff0c\u5373\u4e3a\u5206\u5272\u4e00\u5207\uff01\u8be5\u6a21\u578b\u7531Facebook\u7684Meta AI\u5b9e\u9a8c\u5ba4\uff0c\u80fd\u591f\u6839\u636e\u6587\u672c\u6307\u4ee4\u6216\u56fe\u50cf\u8bc6\u522b\uff0c\u5b9e\u73b0\u5bf9\u4efb\u610f\u7269\u4f53\u7684\u8bc6\u522b\u4e0e\u5206\u5272\u3002\u5b83\u7684\u8bde\u751f\uff0c\u65e0\u7591\u662fCV\u9886\u57df\u7684\u4e00\u6b21\u91cd\u8981\u91cc\u7a0b\u7891\u3002_sam\u6a21\u578b\" \/>\n<meta property=\"og:url\" content=\"https:\/\/www.wsisp.com\/helps\/33808.html\" \/>\n<meta property=\"og:site_name\" content=\"\u7f51\u7855\u4e92\u8054\u5e2e\u52a9\u4e2d\u5fc3\" \/>\n<meta property=\"article:published_time\" content=\"2025-04-28T07:49:45+00:00\" \/>\n<meta property=\"og:image\" content=\"https:\/\/www.wsisp.com\/helps\/wp-content\/uploads\/2025\/04\/20250428074938-680f3312d5fa7.png\" \/>\n<meta name=\"author\" content=\"admin\" \/>\n<meta name=\"twitter:card\" content=\"summary_large_image\" \/>\n<meta name=\"twitter:label1\" content=\"\u4f5c\u8005\" \/>\n\t<meta name=\"twitter:data1\" content=\"admin\" \/>\n\t<meta name=\"twitter:label2\" content=\"\u9884\u8ba1\u9605\u8bfb\u65f6\u95f4\" \/>\n\t<meta 
name=\"twitter:data2\" content=\"29 \u5206\" \/>\n<script type=\"application\/ld+json\" class=\"yoast-schema-graph\">{\"@context\":\"https:\/\/schema.org\",\"@graph\":[{\"@type\":\"WebPage\",\"@id\":\"https:\/\/www.wsisp.com\/helps\/33808.html\",\"url\":\"https:\/\/www.wsisp.com\/helps\/33808.html\",\"name\":\"\u6e90\u7801\u89e3\u6790\uff1a\u4ece\u96f6\u89e3\u8bfbSAM(Segment Anything Model)\u5927\u6a21\u578b\uff01 - \u7f51\u7855\u4e92\u8054\u5e2e\u52a9\u4e2d\u5fc3\",\"isPartOf\":{\"@id\":\"https:\/\/www.wsisp.com\/helps\/#website\"},\"datePublished\":\"2025-04-28T07:49:45+00:00\",\"dateModified\":\"2025-04-28T07:49:45+00:00\",\"author\":{\"@id\":\"https:\/\/www.wsisp.com\/helps\/#\/schema\/person\/358e386c577a3ab51c4493330a20ad41\"},\"breadcrumb\":{\"@id\":\"https:\/\/www.wsisp.com\/helps\/33808.html#breadcrumb\"},\"inLanguage\":\"zh-Hans\",\"potentialAction\":[{\"@type\":\"ReadAction\",\"target\":[\"https:\/\/www.wsisp.com\/helps\/33808.html\"]}]},{\"@type\":\"BreadcrumbList\",\"@id\":\"https:\/\/www.wsisp.com\/helps\/33808.html#breadcrumb\",\"itemListElement\":[{\"@type\":\"ListItem\",\"position\":1,\"name\":\"\u9996\u9875\",\"item\":\"https:\/\/www.wsisp.com\/helps\"},{\"@type\":\"ListItem\",\"position\":2,\"name\":\"\u6e90\u7801\u89e3\u6790\uff1a\u4ece\u96f6\u89e3\u8bfbSAM(Segment Anything Model)\u5927\u6a21\u578b\uff01\"}]},{\"@type\":\"WebSite\",\"@id\":\"https:\/\/www.wsisp.com\/helps\/#website\",\"url\":\"https:\/\/www.wsisp.com\/helps\/\",\"name\":\"\u7f51\u7855\u4e92\u8054\u5e2e\u52a9\u4e2d\u5fc3\",\"description\":\"\u9999\u6e2f\u670d\u52a1\u5668_\u9999\u6e2f\u4e91\u670d\u52a1\u5668\u8d44\u8baf_\u670d\u52a1\u5668\u5e2e\u52a9\u6587\u6863_\u670d\u52a1\u5668\u6559\u7a0b\",\"potentialAction\":[{\"@type\":\"SearchAction\",\"target\":{\"@type\":\"EntryPoint\",\"urlTemplate\":\"https:\/\/www.wsisp.com\/helps\/?s={search_term_string}\"},\"query-input\":\"required 
name=search_term_string\"}],\"inLanguage\":\"zh-Hans\"},{\"@type\":\"Person\",\"@id\":\"https:\/\/www.wsisp.com\/helps\/#\/schema\/person\/358e386c577a3ab51c4493330a20ad41\",\"name\":\"admin\",\"image\":{\"@type\":\"ImageObject\",\"inLanguage\":\"zh-Hans\",\"@id\":\"https:\/\/www.wsisp.com\/helps\/#\/schema\/person\/image\/\",\"url\":\"https:\/\/gravatar.wp-china-yes.net\/avatar\/?s=96&d=mystery\",\"contentUrl\":\"https:\/\/gravatar.wp-china-yes.net\/avatar\/?s=96&d=mystery\",\"caption\":\"admin\"},\"sameAs\":[\"http:\/\/wp.wsisp.com\"],\"url\":\"https:\/\/www.wsisp.com\/helps\/author\/admin\"}]}<\/script>\n<!-- \/ Yoast SEO plugin. -->","yoast_head_json":{"title":"\u6e90\u7801\u89e3\u6790\uff1a\u4ece\u96f6\u89e3\u8bfbSAM(Segment Anything Model)\u5927\u6a21\u578b\uff01 - \u7f51\u7855\u4e92\u8054\u5e2e\u52a9\u4e2d\u5fc3","robots":{"index":"index","follow":"follow","max-snippet":"max-snippet:-1","max-image-preview":"max-image-preview:large","max-video-preview":"max-video-preview:-1"},"canonical":"https:\/\/www.wsisp.com\/helps\/33808.html","og_locale":"zh_CN","og_type":"article","og_title":"\u6e90\u7801\u89e3\u6790\uff1a\u4ece\u96f6\u89e3\u8bfbSAM(Segment Anything Model)\u5927\u6a21\u578b\uff01 - 
\u7f51\u7855\u4e92\u8054\u5e2e\u52a9\u4e2d\u5fc3","og_description":"\u6587\u7ae0\u6d4f\u89c8\u9605\u8bfb9.3k\u6b21\uff0c\u70b9\u8d5e71\u6b21\uff0c\u6536\u85cf299\u6b21\u3002\u8282\u524d\uff0c\u6211\u4eec\u661f\u7403\u7ec4\u7ec7\u4e86\u4e00\u573a\u7b97\u6cd5\u5c97\u6280\u672f&amp;\u9762\u8bd5\u8ba8\u8bba\u4f1a\uff0c\u9080\u8bf7\u4e86\u4e00\u4e9b\u4e92\u8054\u7f51\u5927\u5382\u670b\u53cb\u3001\u53c2\u52a0\u793e\u62db\u548c\u6821\u62db\u9762\u8bd5\u7684\u540c\u5b66\u3002\u9488\u5bf9\u7b97\u6cd5\u5c97\u6280\u672f\u8d8b\u52bf\u3001\u5927\u6a21\u578b\u843d\u5730\u9879\u76ee\u7ecf\u9a8c\u5206\u4eab\u3001\u65b0\u624b\u5982\u4f55\u5165\u95e8\u7b97\u6cd5\u5c97\u3001\u8be5\u5982\u4f55\u51c6\u5907\u3001\u9762\u8bd5\u5e38\u8003\u70b9\u5206\u4eab\u7b49\u70ed\u95e8\u8bdd\u9898\u8fdb\u884c\u4e86\u6df1\u5165\u7684\u8ba8\u8bba\u3002SAM(Segment Anything Model)\uff0c\u987e\u540d\u601d\u4e49\uff0c\u5373\u4e3a\u5206\u5272\u4e00\u5207\uff01\u8be5\u6a21\u578b\u7531Facebook\u7684Meta AI\u5b9e\u9a8c\u5ba4\uff0c\u80fd\u591f\u6839\u636e\u6587\u672c\u6307\u4ee4\u6216\u56fe\u50cf\u8bc6\u522b\uff0c\u5b9e\u73b0\u5bf9\u4efb\u610f\u7269\u4f53\u7684\u8bc6\u522b\u4e0e\u5206\u5272\u3002\u5b83\u7684\u8bde\u751f\uff0c\u65e0\u7591\u662fCV\u9886\u57df\u7684\u4e00\u6b21\u91cd\u8981\u91cc\u7a0b\u7891\u3002_sam\u6a21\u578b","og_url":"https:\/\/www.wsisp.com\/helps\/33808.html","og_site_name":"\u7f51\u7855\u4e92\u8054\u5e2e\u52a9\u4e2d\u5fc3","article_published_time":"2025-04-28T07:49:45+00:00","og_image":[{"url":"https:\/\/www.wsisp.com\/helps\/wp-content\/uploads\/2025\/04\/20250428074938-680f3312d5fa7.png"}],"author":"admin","twitter_card":"summary_large_image","twitter_misc":{"\u4f5c\u8005":"admin","\u9884\u8ba1\u9605\u8bfb\u65f6\u95f4":"29 \u5206"},"schema":{"@context":"https:\/\/schema.org","@graph":[{"@type":"WebPage","@id":"https:\/\/www.wsisp.com\/helps\/33808.html","url":"https:\/\/www.wsisp.com\/helps\/33808.html","name":"\u6e90\u7801\u89e3\u6790\uff1a\u4ece\u96f6\u89e3\u8bfbSAM(Segment Anything 
Model)\u5927\u6a21\u578b\uff01 - \u7f51\u7855\u4e92\u8054\u5e2e\u52a9\u4e2d\u5fc3","isPartOf":{"@id":"https:\/\/www.wsisp.com\/helps\/#website"},"datePublished":"2025-04-28T07:49:45+00:00","dateModified":"2025-04-28T07:49:45+00:00","author":{"@id":"https:\/\/www.wsisp.com\/helps\/#\/schema\/person\/358e386c577a3ab51c4493330a20ad41"},"breadcrumb":{"@id":"https:\/\/www.wsisp.com\/helps\/33808.html#breadcrumb"},"inLanguage":"zh-Hans","potentialAction":[{"@type":"ReadAction","target":["https:\/\/www.wsisp.com\/helps\/33808.html"]}]},{"@type":"BreadcrumbList","@id":"https:\/\/www.wsisp.com\/helps\/33808.html#breadcrumb","itemListElement":[{"@type":"ListItem","position":1,"name":"\u9996\u9875","item":"https:\/\/www.wsisp.com\/helps"},{"@type":"ListItem","position":2,"name":"\u6e90\u7801\u89e3\u6790\uff1a\u4ece\u96f6\u89e3\u8bfbSAM(Segment Anything Model)\u5927\u6a21\u578b\uff01"}]},{"@type":"WebSite","@id":"https:\/\/www.wsisp.com\/helps\/#website","url":"https:\/\/www.wsisp.com\/helps\/","name":"\u7f51\u7855\u4e92\u8054\u5e2e\u52a9\u4e2d\u5fc3","description":"\u9999\u6e2f\u670d\u52a1\u5668_\u9999\u6e2f\u4e91\u670d\u52a1\u5668\u8d44\u8baf_\u670d\u52a1\u5668\u5e2e\u52a9\u6587\u6863_\u670d\u52a1\u5668\u6559\u7a0b","potentialAction":[{"@type":"SearchAction","target":{"@type":"EntryPoint","urlTemplate":"https:\/\/www.wsisp.com\/helps\/?s={search_term_string}"},"query-input":"required 
name=search_term_string"}],"inLanguage":"zh-Hans"},{"@type":"Person","@id":"https:\/\/www.wsisp.com\/helps\/#\/schema\/person\/358e386c577a3ab51c4493330a20ad41","name":"admin","image":{"@type":"ImageObject","inLanguage":"zh-Hans","@id":"https:\/\/www.wsisp.com\/helps\/#\/schema\/person\/image\/","url":"https:\/\/gravatar.wp-china-yes.net\/avatar\/?s=96&d=mystery","contentUrl":"https:\/\/gravatar.wp-china-yes.net\/avatar\/?s=96&d=mystery","caption":"admin"},"sameAs":["http:\/\/wp.wsisp.com"],"url":"https:\/\/www.wsisp.com\/helps\/author\/admin"}]}},"_links":{"self":[{"href":"https:\/\/www.wsisp.com\/helps\/wp-json\/wp\/v2\/posts\/33808","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/www.wsisp.com\/helps\/wp-json\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/www.wsisp.com\/helps\/wp-json\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/www.wsisp.com\/helps\/wp-json\/wp\/v2\/users\/2"}],"replies":[{"embeddable":true,"href":"https:\/\/www.wsisp.com\/helps\/wp-json\/wp\/v2\/comments?post=33808"}],"version-history":[{"count":0,"href":"https:\/\/www.wsisp.com\/helps\/wp-json\/wp\/v2\/posts\/33808\/revisions"}],"wp:featuredmedia":[{"embeddable":true,"href":"https:\/\/www.wsisp.com\/helps\/wp-json\/wp\/v2\/media\/33788"}],"wp:attachment":[{"href":"https:\/\/www.wsisp.com\/helps\/wp-json\/wp\/v2\/media?parent=33808"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/www.wsisp.com\/helps\/wp-json\/wp\/v2\/categories?post=33808"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/www.wsisp.com\/helps\/wp-json\/wp\/v2\/tags?post=33808"},{"taxonomy":"topic","embeddable":true,"href":"https:\/\/www.wsisp.com\/helps\/wp-json\/wp\/v2\/topic?post=33808"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}