|
<!DOCTYPE html> |
|
<html lang="en"> |
|
<head> |
|
<meta http-equiv="content-type" content="text/html;charset=utf-8"/> |
|
<meta name="viewport" content="width=device-width, initial-scale=1.0"/> |
|
<meta name="description" content="UNet model for Denoising Diffusion Probabilistic Models (DDPM)"/> |
|
|
|
<meta name="twitter:card" content="summary"/> |
|
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/> |
|
<meta name="twitter:title" content="U-Net model for Denoising Diffusion Probabilistic Models (DDPM)"/> |
|
<meta name="twitter:description" content="UNet model for Denoising Diffusion Probabilistic Models (DDPM)"/> |
|
<meta name="twitter:site" content="@labmlai"/> |
|
<meta name="twitter:creator" content="@labmlai"/> |
|
|
|
<meta property="og:url" content="https://nn.labml.ai/diffusion/ddpm/unet.html"/> |
|
<meta property="og:title" content="U-Net model for Denoising Diffusion Probabilistic Models (DDPM)"/> |
|
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/> |
|
<meta property="og:site_name" content="U-Net model for Denoising Diffusion Probabilistic Models (DDPM)"/> |
|
<meta property="og:type" content="object"/> |
|
|
<meta property="og:description" content="UNet model for Denoising Diffusion Probabilistic Models (DDPM)"/> |
|
|
|
<title>U-Net model for Denoising Diffusion Probabilistic Models (DDPM)</title> |
|
<link rel="shortcut icon" href="/icon.png"/> |
|
<link rel="stylesheet" href="../../pylit.css?v=1"> |
|
<link rel="canonical" href="https://nn.labml.ai/diffusion/ddpm/unet.html"/> |
|
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous"> |
|
|
|
|
|
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script> |
|
<script> |
|
window.dataLayer = window.dataLayer || []; |
|
|
|
function gtag() { |
|
dataLayer.push(arguments); |
|
} |
|
|
|
gtag('js', new Date()); |
|
|
|
gtag('config', 'G-4V3HC8HBLH'); |
|
</script> |
|
</head> |
|
<body> |
|
<div id='container'> |
|
<div id="background"></div> |
|
<div class='section'> |
|
<div class='docs'> |
|
|
|
<p> |
|
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/diffusion/ddpm/unet.py" target="_blank"> |
|
View code on Github</a> |
|
</p> |
|
</div> |
|
</div> |
|
<div class='section' id='section-0'> |
|
<div class='docs doc-strings'> |
|
<div class='section-link'> |
|
<a href='#section-0'>#</a> |
|
</div> |
|
<h1>U-Net model for <a href="index.html">Denoising Diffusion Probabilistic Models (DDPM)</a></h1> |
|
<p>This is a <a href="../../unet/index.html">U-Net</a> based model to predict noise <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord" style="color:lightgreen"><span class="mord mathnormal" style="">ϵ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.02778em">θ</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqj" style=""><span class="mord mathnormal mtight" style="">t</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqj" style=""><span class="mord mathnormal" style="">t</span></span><span class="mclose">)</span></span></span></span></span>.</p> |
|
<p>U-Net gets its name from the U shape of its model diagram. It processes an image by progressively lowering (halving) the feature-map resolution and then increasing it back. There are pass-through connections at each resolution, linking the downsampling path to the upsampling path.</p> |
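<p>As a rough illustration of the U shape, the hypothetical helper below (not part of this implementation) lists the feature-map sizes along the contracting and expanding paths, assuming one halving per level on the way down and one doubling per level on the way up. The 32×32 input and the four levels are arbitrary example values.</p>

```python
from typing import List, Tuple


def unet_resolutions(image_size: int, n_levels: int) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: feature-map sizes on the down and up paths."""
    # Contracting path: halve the resolution between consecutive levels
    down = [image_size]
    for _ in range(n_levels - 1):
        if down[-1] % 2 != 0:
            raise ValueError("resolution must stay divisible by 2 at every level")
        down.append(down[-1] // 2)
    # Expanding path: double the resolution back, level by level
    up = [down[-1]]
    for _ in range(n_levels - 1):
        up.append(up[-1] * 2)
    return down, up


down, up = unet_resolutions(32, 4)
print(down)  # [32, 16, 8, 4]
print(up)    # [4, 8, 16, 32]
# Each pass-through connection joins a down level to the up level of equal size
print([(d, u) for d, u in zip(down, reversed(up))])
```

<p>Pairing the two lists in reverse order shows why the pass-through connections line up: each one joins two feature maps of the same resolution.</p>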
|
<p><img alt="U-Net diagram from paper" src="../../unet/unet.png"></p> |
|
<p>This implementation contains several modifications to the original U-Net (residual blocks, multi-head attention) and also adds time-step embeddings <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.61508em;vertical-align:0em;"></span><span class="mord coloredeq eqj" style=""><span class="mord mathnormal" style="">t</span></span></span></span></span></span>.</p> |
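<p>This excerpt does not show how a block consumes the time-step embedding. One common pattern, assumed here purely for illustration, is to project the per-sample embedding to the feature map's channel count and add it at every spatial position by broadcasting. The sizes and the <code>proj</code> layer are hypothetical.</p>

```python
import torch
from torch import nn

batch, channels, height, width = 2, 64, 16, 16
time_channels = 256  # hypothetical size of the time-step embedding

x = torch.randn(batch, channels, height, width)  # feature map
t_emb = torch.randn(batch, time_channels)        # one embedding per sample

# Project the embedding to the feature map's channel count ...
proj = nn.Linear(time_channels, channels)
# ... then add it at every (row, column) by broadcasting over H and W
h = x + proj(t_emb)[:, :, None, None]
print(h.shape)  # torch.Size([2, 64, 16, 16])
```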
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">24</span><span></span><span class="kn">import</span> <span class="nn">math</span> |
|
<span class="lineno">25</span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span><span class="p">,</span> <span class="n">List</span> |
|
<span class="lineno">26</span> |
|
<span class="lineno">27</span><span class="kn">import</span> <span class="nn">torch</span> |
|
<span class="lineno">28</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span> |
|
<span class="lineno">29</span> |
|
<span class="lineno">30</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-1'> |
|
<div class='docs doc-strings'> |
|
<div class='section-link'> |
|
<a href='#section-1'>#</a> |
|
</div> |
|
<h3>Swish activation function</h3> |
|
<p><span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.44445em;vertical-align:0em;"></span><span class="mord mathnormal">x</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">⋅</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="margin-right:0.03588em;">σ</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span></span></span></span></span></span></p> |
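<p>A quick numerical check of the formula above, assuming PyTorch 1.7 or later: Swish, written as x·σ(x), is the same function PyTorch ships as SiLU, and unlike ReLU it takes small negative values for negative inputs.</p>

```python
import torch
import torch.nn.functional as F


def swish(x: torch.Tensor) -> torch.Tensor:
    # Swish: x * sigmoid(x), smooth everywhere
    return x * torch.sigmoid(x)


x = torch.linspace(-6.0, 6.0, steps=13)
y = swish(x)
# Matches PyTorch's built-in SiLU
print(torch.allclose(y, F.silu(x)))  # True
# Dips slightly below zero for negative inputs, unlike ReLU
print(y.min().item() < 0)  # True
```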
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">33</span><span class="k">class</span> <span class="nc">Swish</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-2'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-2'>#</a> |
|
</div> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">40</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> |
|
<span class="lineno">41</span> <span class="k">return</span> <span class="n">x</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-3'> |
|
<div class='docs doc-strings'> |
|
<div class='section-link'> |
|
<a href='#section-3'>#</a> |
|
</div> |
|
<h3>Embeddings for <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.61508em;vertical-align:0em;"></span><span class="mord coloredeq eqj" style=""><span class="mord mathnormal" style="">t</span></span></span></span></span></span></h3> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">44</span><span class="k">class</span> <span class="nc">TimeEmbedding</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-4'> |
|
<div class='docs doc-strings'> |
|
<div class='section-link'> |
|
<a href='#section-4'>#</a> |
|
</div> |
|
<ul><li><code class="highlight"><span></span><span class="n">n_channels</span></code> |
|
is the number of dimensions in the embedding</li></ul> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">49</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-5'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-5'>#</a> |
|
</div> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">53</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> |
|
<span class="lineno">54</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_channels</span> <span class="o">=</span> <span class="n">n_channels</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-6'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-6'>#</a> |
|
</div> |
|
<p>First linear layer </p> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">56</span> <span class="bp">self</span><span class="o">.</span><span class="n">lin1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_channels</span> <span class="o">//</span> <span class="mi">4</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_channels</span><span class="p">)</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-7'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-7'>#</a> |
|
</div> |
|
<p>Activation </p> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">58</span> <span class="bp">self</span><span class="o">.</span><span class="n">act</span> <span class="o">=</span> <span class="n">Swish</span><span class="p">()</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-8'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-8'>#</a> |
|
</div> |
|
<p>Second linear layer </p> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">60</span> <span class="bp">self</span><span class="o">.</span><span class="n">lin2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_channels</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_channels</span><span class="p">)</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-9'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-9'>#</a> |
|
</div> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">62</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">t</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-10'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-10'>#</a> |
|
</div> |
|
<p>Create sinusoidal position embeddings, the <a href="../../transformers/positional_encoding.html">same as those from the transformer</a>:</p> |
|
<span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:6.600059999999999em;vertical-align:-3.0500299999999996em;"></span><span class="mord"><span class="mtable"><span class="col-align-r"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:3.5500299999999996em;"><span style="top:-5.55003em;"><span class="pstrut" style="height:3.75em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.13889em;">P</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05764em;">E</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.0448em;"><span style="top:-2.4231360000000004em;margin-left:-0.05764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqj" style=""><span class="mord mathnormal mtight" style="">t</span></span><span class="mpunct mtight">,</span><span class="mord mathnormal mtight">i</span></span></span></span><span style="top:-3.2198em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mopen mtight">(</span><span class="mord mtight">1</span><span class="mclose mtight">)</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.412972em;"><span></span></span></span></span></span></span></span></span><span style="top:-2.2500000000000004em;"><span class="pstrut" style="height:3.75em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.13889em;">P</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05764em;">E</span><span class="msupsub"><span class="vlist-t vlist-t2"><span 
class="vlist-r"><span class="vlist" style="height:1.0448em;"><span style="top:-2.4231360000000004em;margin-left:-0.05764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqj" style=""><span class="mord mathnormal mtight" style="">t</span></span><span class="mpunct mtight">,</span><span class="mord mathnormal mtight">i</span></span></span></span><span style="top:-3.2198em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mopen mtight">(</span><span class="mord mtight">2</span><span class="mclose mtight">)</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.412972em;"><span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:3.0500299999999996em;"><span></span></span></span></span></span><span class="col-align-l"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:3.5500299999999996em;"><span style="top:-5.55003em;"><span class="pstrut" style="height:3.75em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mord mathnormal">s</span><span class="mord mathnormal">in</span><span class="mord"><span class="delimsizing size4">(</span></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.29208em;"><span style="top:-2.121225em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord">1000</span><span 
class="mord"><span class="mord">0</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9887749999999998em;"><span style="top:-3.3902150000000004em;margin-right:0.05em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight"><span class="mopen nulldelimiter sizing reset-size3 size6"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8550857142857142em;"><span style="top:-2.656em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqi" style=""><span class="mord mathnormal mtight" style="">d</span></span><span class="mbin mtight">−</span><span class="mord mtight">1</span></span></span></span><span style="top:-3.2255em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line mtight" style="border-bottom-width:0.049em;"></span></span><span style="top:-3.384em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">i</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.40352142857142853em;"><span></span></span></span></span></span><span class="mclose nulldelimiter sizing reset-size3 size6"></span></span></span></span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord coloredeq eqj" style=""><span class="mord mathnormal" style="">t</span></span></span></span></span><span class="vlist-s"></span></span><span 
class="vlist-r"><span class="vlist" style="height:0.8787749999999999em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mord"><span class="delimsizing size4">)</span></span></span></span><span style="top:-2.2500000000000004em;"><span class="pstrut" style="height:3.75em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mord mathnormal">cos</span><span class="mord"><span class="delimsizing size4">(</span></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.29208em;"><span style="top:-2.121225em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord">1000</span><span class="mord"><span class="mord">0</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9887749999999998em;"><span style="top:-3.3902150000000004em;margin-right:0.05em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight"><span class="mopen nulldelimiter sizing reset-size3 size6"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8550857142857142em;"><span style="top:-2.656em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqi" style=""><span class="mord mathnormal mtight" style="">d</span></span><span class="mbin mtight">−</span><span class="mord mtight">1</span></span></span></span><span style="top:-3.2255em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line mtight" 
style="border-bottom-width:0.049em;"></span></span><span style="top:-3.384em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">i</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.40352142857142853em;"><span></span></span></span></span></span><span class="mclose nulldelimiter sizing reset-size3 size6"></span></span></span></span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord coloredeq eqj" style=""><span class="mord mathnormal" style="">t</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.8787749999999999em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mord"><span class="delimsizing size4">)</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:3.0500299999999996em;"><span></span></span></span></span></span></span></span></span></span></span></span></span><p>where <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord coloredeq eqi" style=""><span class="mord mathnormal" style="">d</span></span></span></span></span></span> is <code class="highlight"><span></span><span class="n">half_dim</span></code> |
|
</p> |
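<p>A plain-Python sketch of the same formula, handy for checking values by hand. The function name and the example <code>n_channels = 256</code> are assumptions; the frequency schedule and the sine-then-cosine order follow the code in this section. It needs <code>half_dim</code> of at least 2, because the exponent divides by d − 1.</p>

```python
import math
from typing import List


def sinusoidal_embedding(t: float, half_dim: int) -> List[float]:
    """Hypothetical helper: sin/cos embedding of one time step t."""
    if half_dim < 2:
        raise ValueError("half_dim must be at least 2")
    # Frequencies 10000^(-i / (d - 1)) for i = 0 .. d - 1, with d = half_dim
    scale = math.log(10_000) / (half_dim - 1)
    freqs = [math.exp(-i * scale) for i in range(half_dim)]
    args = [t * f for f in freqs]
    # Sines first, then cosines (same order as the torch.cat call)
    return [math.sin(a) for a in args] + [math.cos(a) for a in args]


n_channels = 256            # hypothetical embedding size
half_dim = n_channels // 8  # 32, computed as in the code
emb = sinusoidal_embedding(t=10, half_dim=half_dim)
print(len(emb))  # 64, i.e. n_channels // 4
```

<p>At t = 0 every sine is 0 and every cosine is 1. The first frequency is 1 and the last is 1/10000, so the embedding mixes fast and slow oscillations of t.</p>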
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">72</span> <span class="n">half_dim</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_channels</span> <span class="o">//</span> <span class="mi">8</span> |
|
<span class="lineno">73</span> <span class="n">emb</span> <span class="o">=</span> <span class="n">math</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="mi">10_000</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="n">half_dim</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> |
|
<span class="lineno">74</span> <span class="n">emb</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">half_dim</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">t</span><span class="o">.</span><span class="n">device</span><span class="p">)</span> <span class="o">*</span> <span class="o">-</span><span class="n">emb</span><span class="p">)</span> |
|
<span class="lineno">75</span> <span class="n">emb</span> <span class="o">=</span> <span class="n">t</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">emb</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> |
|
<span class="lineno">76</span> <span class="n">emb</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">emb</span><span class="o">.</span><span class="n">sin</span><span class="p">(),</span> <span class="n">emb</span><span class="o">.</span><span class="n">cos</span><span class="p">()),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-11'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-11'>#</a> |
|
</div> |
|
<p>Transform with the MLP </p> |
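<p>The MLP maps the <code>n_channels // 4</code> sinusoidal features up to <code>n_channels</code>. The sketch below traces the shapes with a hypothetical <code>n_channels = 256</code>, random stand-in features, and <code>nn.SiLU</code> in place of the <code>Swish</code> module (both compute x·σ(x)).</p>

```python
import torch
from torch import nn

n_channels = 256  # hypothetical embedding size
batch = 4

# Stand-in for the concatenated sin/cos features: n_channels // 4 values per step
features = torch.randn(batch, n_channels // 4)

mlp = nn.Sequential(
    nn.Linear(n_channels // 4, n_channels),  # plays the role of lin1
    nn.SiLU(),                               # same function as Swish
    nn.Linear(n_channels, n_channels),       # plays the role of lin2
)
emb = mlp(features)
print(features.shape, emb.shape)  # torch.Size([4, 64]) torch.Size([4, 256])
```

<p>With this choice, <code>half_dim</code> = 256 // 8 = 32, so the sine and cosine halves together give 64 = 256 // 4 features, exactly the input width of <code>lin1</code>.</p>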
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">79</span> <span class="n">emb</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">act</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">lin1</span><span class="p">(</span><span class="n">emb</span><span class="p">))</span> |
|
<span class="lineno">80</span> <span class="n">emb</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">lin2</span><span class="p">(</span><span class="n">emb</span><span class="p">)</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-12'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-12'>#</a> |
|
</div> |
|
<p> </p> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">83</span> <span class="k">return</span> <span class="n">emb</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-13'> |
|
<div class='docs doc-strings'> |
|
<div class='section-link'> |
|
<a href='#section-13'>#</a> |
|
</div> |
|
<h3>Residual block</h3> |
|
<p>A residual block has two convolution layers with group normalization. Each resolution is processed with two residual blocks.</p> |
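<p>A simplified sketch of this structure, not the full <code>ResidualBlock</code>: it leaves out the time-step embedding and dropout, assumes <code>in_channels == out_channels</code> so the shortcut is the identity, and assumes the norm → activation → convolution order suggested by how the layers are defined below. <code>TinyResidualBlock</code> is a hypothetical name.</p>

```python
import torch
from torch import nn


class TinyResidualBlock(nn.Module):
    """Hypothetical simplified block: no time embedding, no dropout,
    identity shortcut (input and output channel counts are equal)."""

    def __init__(self, channels: int, n_groups: int = 32):
        super().__init__()
        self.norm1 = nn.GroupNorm(n_groups, channels)
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(n_groups, channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.act = nn.SiLU()  # same function as Swish

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Two norm -> activation -> 3x3 conv stages; padding keeps H and W
        h = self.conv1(self.act(self.norm1(x)))
        h = self.conv2(self.act(self.norm2(h)))
        return h + x  # residual (shortcut) connection


# Two blocks per resolution, as described above
blocks = nn.Sequential(TinyResidualBlock(64), TinyResidualBlock(64))
x = torch.randn(2, 64, 16, 16)
print(blocks(x).shape)  # torch.Size([2, 64, 16, 16])
```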
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">86</span><span class="k">class</span> <span class="nc">ResidualBlock</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-14'> |
|
<div class='docs doc-strings'> |
|
<div class='section-link'> |
|
<a href='#section-14'>#</a> |
|
</div> |
|
<ul><li><code class="highlight"><span></span><span class="n">in_channels</span></code> |
|
is the number of input channels </li> |
|
<li><code class="highlight"><span></span><span class="n">out_channels</span></code> |
|
is the number of output channels </li> |
|
<li><code class="highlight"><span></span><span class="n">time_channels</span></code> |
|
is the number of channels in the time-step (<span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.61508em;vertical-align:0em;"></span><span class="mord coloredeq eqj" style=""><span class="mord mathnormal" style="">t</span></span></span></span></span></span>) embeddings </li> |
|
<li><code class="highlight"><span></span><span class="n">n_groups</span></code> |
|
is the number of groups for <a href="../../normalization/group_norm/index.html">group normalization</a> </li> |
|
<li><code class="highlight"><span></span><span class="n">dropout</span></code> |
|
is the dropout rate</li></ul> |
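<p>Group normalization splits each sample's channels into <code>n_groups</code> groups and normalizes each group over its channels and spatial positions, so the channel count must be divisible by <code>n_groups</code>; with the default of 32, both <code>in_channels</code> and <code>out_channels</code> should be multiples of 32. The check below, with arbitrary example shapes, confirms the per-group statistics of <code>nn.GroupNorm</code> at its default initialization (weight 1, bias 0).</p>

```python
import torch
from torch import nn

n_groups, channels = 32, 64  # channels must be a multiple of n_groups
norm = nn.GroupNorm(n_groups, channels)  # affine weight starts at 1, bias at 0

torch.manual_seed(0)
x = torch.randn(8, channels, 16, 16) * 3 + 5  # deliberately shifted and scaled
y = norm(x)

# View each sample as n_groups groups of (channels // n_groups) * H * W values
g = y.reshape(8, n_groups, -1)
mean = g.mean(dim=-1)
var = ((g - mean[..., None]) ** 2).mean(dim=-1)  # population variance per group

mean_err = mean.abs().max().item()
var_err = (var - 1).abs().max().item()
print(mean_err < 1e-4, var_err < 1e-3)  # True True
```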
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">94</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">time_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> |
|
<span class="lineno">95</span> <span class="n">n_groups</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span><span class="p">,</span> <span class="n">dropout</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">):</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-15'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-15'>#</a> |
|
</div> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">103</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-16'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-16'>#</a> |
|
</div> |
|
<p>Group normalization and the first convolution layer </p> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">105</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">GroupNorm</span><span class="p">(</span><span class="n">n_groups</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">)</span> |
|
<span class="lineno">106</span> <span class="bp">self</span><span class="o">.</span><span class="n">act1</span> <span class="o">=</span> <span class="n">Swish</span><span class="p">()</span> |
|
<span class="lineno">107</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">padding</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-17'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-17'>#</a> |
|
</div> |
|
<p>Group normalization, Swish activation, and the second 3×3 convolution layer </p>
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">110</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">GroupNorm</span><span class="p">(</span><span class="n">n_groups</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">)</span> |
|
<span class="lineno">111</span> <span class="bp">self</span><span class="o">.</span><span class="n">act2</span> <span class="o">=</span> <span class="n">Swish</span><span class="p">()</span> |
|
<span class="lineno">112</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">out_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">padding</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span></pre></div> |
|
</div> |
|
</div> |
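<p>The two normalization steps above apply <code>nn.GroupNorm</code> followed by <code>Swish</code>. As a rough illustration of what they compute, here is a pure-Python sketch for a single sample. It assumes channels are split into equal, contiguous groups, uses the biased variance, and omits the learnable per-channel weight and bias that <code>nn.GroupNorm</code> also applies; <code>group_norm</code> and <code>swish</code> are hypothetical helper names, not part of this module.</p>

```python
import math

def group_norm(x, n_groups, eps=1e-5):
    """Normalize one sample x (a list of channels, each a flat list of pixels)
    over contiguous groups of channels. The learnable affine step is omitted."""
    n_channels = len(x)
    assert n_channels % n_groups == 0, "channels must divide evenly into groups"
    size = n_channels // n_groups
    out = []
    for g in range(n_groups):
        group = x[g * size:(g + 1) * size]
        # Mean and (biased) variance over every pixel of every channel in the group
        values = [v for channel in group for v in channel]
        mean = sum(values) / len(values)
        var = sum((v - mean) ** 2 for v in values) / len(values)
        for channel in group:
            out.append([(v - mean) / math.sqrt(var + eps) for v in channel])
    return out

def swish(v):
    """Swish (SiLU): v * sigmoid(v)."""
    return v / (1.0 + math.exp(-v))

# 4 channels with 2 pixels each, normalized in 2 groups of 2 channels
x = [[1.0, 2.0], [3.0, 4.0], [10.0, 10.0], [20.0, 20.0]]
h = [[swish(v) for v in channel] for channel in group_norm(x, n_groups=2)]
```

<p>Each group is normalized with its own mean and variance, so the statistics do not depend on the batch size.</p>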
|
<div class='section' id='section-18'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-18'>#</a> |
|
</div> |
|
<p>If the number of input channels differs from the number of output channels, the shortcut connection has to be projected with a 1×1 convolution; otherwise it is the identity. </p>
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">116</span> <span class="k">if</span> <span class="n">in_channels</span> <span class="o">!=</span> <span class="n">out_channels</span><span class="p">:</span> |
|
<span class="lineno">117</span> <span class="bp">self</span><span class="o">.</span><span class="n">shortcut</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> |
|
<span class="lineno">118</span> <span class="k">else</span><span class="p">:</span> |
|
<span class="lineno">119</span> <span class="bp">self</span><span class="o">.</span><span class="n">shortcut</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Identity</span><span class="p">()</span></pre></div> |
|
</div> |
|
</div> |
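<p>The projection is needed because <code>h + self.shortcut(x)</code> can only be added when both sides have <code>out_channels</code> channels. A 1×1 convolution changes the channel count without mixing neighboring pixels: it applies the same <code>out_channels × in_channels</code> matrix to the channel vector at every pixel. A pure-Python sketch for one sample, ignoring the bias term; <code>conv1x1</code> is a hypothetical helper name.</p>

```python
def conv1x1(x, weight):
    """x: [in_channels][height][width]; weight: [out_channels][in_channels].
    Returns [out_channels][height][width]. Each output pixel depends only on
    the input channels at the same pixel position."""
    in_channels, height, width = len(x), len(x[0]), len(x[0][0])
    return [
        [
            [sum(weight[o][c] * x[c][i][j] for c in range(in_channels))
             for j in range(width)]
            for i in range(height)
        ]
        for o in range(len(weight))
    ]

# 2 input channels on a 1x2 image, projected to 3 output channels
x = [[[1.0, 2.0]], [[3.0, 4.0]]]
weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
y = conv1x1(x, weight)  # [[[1.0, 2.0]], [[3.0, 4.0]], [[4.0, 6.0]]]
```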
|
<div class='section' id='section-19'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-19'>#</a> |
|
</div> |
|
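<p>Linear layer, applied after a Swish activation, that projects the time embeddings to <code class="highlight"><span></span><span class="n">out_channels</span></code>, and the dropout layer used before the second convolution </p>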
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">122</span> <span class="bp">self</span><span class="o">.</span><span class="n">time_emb</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">time_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">)</span> |
|
<span class="lineno">123</span> <span class="bp">self</span><span class="o">.</span><span class="n">time_act</span> <span class="o">=</span> <span class="n">Swish</span><span class="p">()</span> |
|
<span class="lineno">124</span> |
|
<span class="lineno">125</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout</span><span class="p">)</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-20'> |
|
<div class='docs doc-strings'> |
|
<div class='section-link'> |
|
<a href='#section-20'>#</a> |
|
</div> |
|
<ul><li><code class="highlight"><span></span><span class="n">x</span></code> |
|
has shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">,</span> <span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">]</span></code> |
|
</li> |
|
<li><code class="highlight"><span></span><span class="n">t</span></code> |
|
has shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">time_channels</span><span class="p">]</span></code> |
|
</li></ul> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">127</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">t</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-21'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-21'>#</a> |
|
</div> |
|
<p>First convolution layer </p> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">133</span> <span class="n">h</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv1</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">act1</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">norm1</span><span class="p">(</span><span class="n">x</span><span class="p">)))</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-22'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-22'>#</a> |
|
</div> |
|
<p>Add the time embeddings, broadcast over the height and width dimensions </p>
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">135</span> <span class="n">h</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">time_emb</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">time_act</span><span class="p">(</span><span class="n">t</span><span class="p">))[:,</span> <span class="p">:,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">]</span></pre></div> |
|
</div> |
|
</div> |
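<p>Indexing with <code>[:, :, None, None]</code> turns the <code>[batch_size, out_channels]</code> time projection into <code>[batch_size, out_channels, 1, 1]</code>, so broadcasting adds one value per channel to every pixel of <code>h</code>. This is also why <code>time_emb</code> projects to <code>out_channels</code>: it has to match the channel count of <code>h</code> after <code>conv1</code>. A pure-Python sketch of the same addition for one sample, with a hypothetical <code>add_channel_bias</code> helper:</p>

```python
def add_channel_bias(h, emb):
    """h: [channels][height][width]; emb: [channels].
    Adds emb[c] to every pixel of channel c, mirroring how a
    [channels, 1, 1] tensor broadcasts against [channels, height, width]."""
    assert len(h) == len(emb), "embedding size must equal the number of channels"
    return [[[v + emb[c] for v in row] for row in h[c]] for c in range(len(h))]

h = [[[0.0, 1.0], [2.0, 3.0]], [[0.0, 0.0], [0.0, 0.0]]]  # 2 channels, 2x2 pixels
emb = [10.0, -1.0]
out = add_channel_bias(h, emb)  # [[[10.0, 11.0], [12.0, 13.0]], [[-1.0, -1.0], [-1.0, -1.0]]]
```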
|
<div class='section' id='section-23'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-23'>#</a> |
|
</div> |
|
<p>Second convolution layer, with dropout applied between the activation and the convolution </p>
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">137</span> <span class="n">h</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv2</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">act2</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">norm2</span><span class="p">(</span><span class="n">h</span><span class="p">))))</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-24'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-24'>#</a> |
|
</div> |
|
<p>Add the shortcut connection and return </p> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">140</span> <span class="k">return</span> <span class="n">h</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">shortcut</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div> |
|
</div> |
|
</div> |
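<p>The residual sum <code>h + self.shortcut(x)</code> also relies on neither branch changing the spatial size. With the standard output-size formula for a convolution with dilation 1, ⌊(size + 2·padding − kernel) / stride⌋ + 1, and the default stride of 1, a 3×3 kernel with padding 1 and a 1×1 kernel without padding both preserve height and width. A quick check with a hypothetical <code>conv_out_size</code> helper:</p>

```python
def conv_out_size(size, kernel, padding=0, stride=1):
    """Output size along one spatial axis for a convolution with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1

for size in (8, 16, 32):
    # conv1 / conv2: 3x3 kernel with padding 1 keeps the size
    print(size, conv_out_size(size, kernel=3, padding=1))
    # shortcut: 1x1 kernel with no padding keeps the size
    print(size, conv_out_size(size, kernel=1))

# Without padding, a 3x3 kernel would shrink 32 to 30 and the residual sum would fail
print(conv_out_size(32, kernel=3))  # 30
```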
|
<div class='section' id='section-25'> |
|
<div class='docs doc-strings'> |
|
<div class='section-link'> |
|
<a href='#section-25'>#</a> |
|
</div> |
|
<h3>Attention block</h3> |
|
<p>This is similar to <a href="../../transformers/mha.html">transformer multi-head attention</a>.</p> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">143</span><span class="k">class</span> <span class="nc">AttentionBlock</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-26'> |
|
<div class='docs doc-strings'> |
|
<div class='section-link'> |
|
<a href='#section-26'>#</a> |
|
</div> |
|
<ul><li><code class="highlight"><span></span><span class="n">n_channels</span></code> |
|
is the number of channels in the input </li> |
|
<li><code class="highlight"><span></span><span class="n">n_heads</span></code> |
|
is the number of heads in multi-head attention </li> |
|
<li><code class="highlight"><span></span><span class="n">d_k</span></code> |
|
is the number of dimensions in each head </li> |
|
<li><code class="highlight"><span></span><span class="n">n_groups</span></code> |
|
is the number of groups for <a href="../../normalization/group_norm/index.html">group normalization</a></li></ul> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">150</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span> <span class="n">d_k</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">n_groups</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span><span class="p">):</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-27'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-27'>#</a> |
|
</div> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">157</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-28'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-28'>#</a> |
|
</div> |
|
<p>Default <code class="highlight"><span></span><span class="n">d_k</span></code> to <code class="highlight"><span></span><span class="n">n_channels</span></code>

</p>
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">160</span> <span class="k">if</span> <span class="n">d_k</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> |
|
<span class="lineno">161</span> <span class="n">d_k</span> <span class="o">=</span> <span class="n">n_channels</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-29'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-29'>#</a> |
|
</div> |
|
<p>Normalization layer </p> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">163</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">GroupNorm</span><span class="p">(</span><span class="n">n_groups</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">)</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-30'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-30'>#</a> |
|
</div> |
|
<p>Projections for query, key and values </p> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">165</span> <span class="bp">self</span><span class="o">.</span><span class="n">projection</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">n_channels</span><span class="p">,</span> <span class="n">n_heads</span> <span class="o">*</span> <span class="n">d_k</span> <span class="o">*</span> <span class="mi">3</span><span class="p">)</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-31'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-31'>#</a> |
|
</div> |
|
<p>Linear layer for the final transformation </p>
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">167</span> <span class="bp">self</span><span class="o">.</span><span class="n">output</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">n_heads</span> <span class="o">*</span> <span class="n">d_k</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">)</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-32'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-32'>#</a> |
|
</div> |
|
<p>Scale factor 1/√<i>d<sub>k</sub></i> for scaled dot-product attention </p>
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">169</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="n">d_k</span> <span class="o">**</span> <span class="o">-</span><span class="mf">0.5</span></pre></div> |
|
</div> |
|
</div> |
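<p>The factor <code>d_k ** -0.5</code> keeps the attention logits at roughly unit scale: if the entries of a query and a key are independent with zero mean and unit variance, their dot product has variance <i>d<sub>k</sub></i>, and multiplying by 1/√<i>d<sub>k</sub></i> brings it back to about 1. A small Monte Carlo check of that argument in pure Python; the Gaussian inputs are an assumption made for illustration, and <code>scaled_dot_variance</code> is a hypothetical helper.</p>

```python
import random

def scaled_dot_variance(d_k, n_samples=5000, seed=0):
    """Empirical variance of (q . k) * d_k ** -0.5 when q and k have
    independent standard normal entries (an illustrative assumption)."""
    rng = random.Random(seed)
    scale = d_k ** -0.5
    values = []
    for _ in range(n_samples):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        values.append(sum(a * b for a, b in zip(q, k)) * scale)
    mean = sum(values) / n_samples
    return sum((v - mean) ** 2 for v in values) / n_samples

# Close to 1.0; the unscaled dot product would have variance close to d_k = 64
print(scaled_dot_variance(64))
```

<p>Without this factor, large logits would push the softmax toward nearly one-hot weights with very small gradients.</p>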
|
<div class='section' id='section-33'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-33'>#</a> |
|
</div> |
|
<p>Store the number of heads and the number of dimensions per head </p>
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">171</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span> <span class="o">=</span> <span class="n">n_heads</span> |
|
<span class="lineno">172</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span> <span class="o">=</span> <span class="n">d_k</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-34'> |
|
<div class='docs doc-strings'> |
|
<div class='section-link'> |
|
<a href='#section-34'>#</a> |
|
</div> |
|
<ul><li><code class="highlight"><span></span><span class="n">x</span></code> |
|
has shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">,</span> <span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">]</span></code>
|
</li> |
|
<li><code class="highlight"><span></span><span class="n">t</span></code> |
|
has shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">time_channels</span><span class="p">]</span></code> |
|
</li></ul> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">174</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">t</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-35'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-35'>#</a> |
|
</div> |
|
<p><code class="highlight"><span></span><span class="n">t</span></code>
 is not used, but it is kept in the arguments so that the attention layer's function signature matches that of <code class="highlight"><span></span><span class="n">ResidualBlock</span></code>. </p>
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">181</span> <span class="n">_</span> <span class="o">=</span> <span class="n">t</span></pre></div> |
|
</div> |
|
</div> |
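<p>Keeping the unused <code>t</code> argument lets the surrounding U-Net call every block the same way, <code>block(x, t)</code>, without checking its type. A minimal pure-Python sketch of that pattern; <code>ToyResidual</code>, <code>ToyAttention</code>, and <code>run_blocks</code> are hypothetical stand-ins, not the classes on this page.</p>

```python
class ToyResidual:
    """Stand-in for a block that uses the time embedding t."""
    def __call__(self, x, t):
        return x + t

class ToyAttention:
    """Stand-in for a block that accepts t but ignores it."""
    def __call__(self, x, t=None):
        _ = t  # unused; kept only so the call signature matches ToyResidual
        return x * 2

def run_blocks(blocks, x, t):
    # Every block is called the same way, whatever its type
    for block in blocks:
        x = block(x, t)
    return x

print(run_blocks([ToyResidual(), ToyAttention(), ToyResidual()], x=1, t=3))  # 11
```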
|
<div class='section' id='section-36'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-36'>#</a> |
|
</div> |
|
<p>Get shape </p> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">183</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">,</span> <span class="n">height</span><span class="p">,</span> <span class="n">width</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-37'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-37'>#</a> |
|
</div> |
|
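<p>Change <code class="highlight"><span></span><span class="n">x</span></code>
 to shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">]</span></code>, where <code class="highlight"><span></span><span class="n">seq</span> <span class="o">=</span> <span class="n">height</span> <span class="o">*</span> <span class="n">width</span></code>
</p>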
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">185</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">permute</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div> |
|
</div> |
|
</div> |
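<p><code>x.view(batch_size, n_channels, -1)</code> flattens the <code>height × width</code> grid into <code>seq = height * width</code> positions in row-major order (position <code>i * width + j</code> holds pixel <code>(i, j)</code>), and <code>.permute(0, 2, 1)</code> then swaps the channel and position axes so that each pixel becomes one token with an <code>n_channels</code>-dimensional feature vector. A pure-Python sketch for one sample, using a hypothetical <code>to_sequence</code> helper:</p>

```python
def to_sequence(x):
    """x: [channels][height][width] -> [height * width][channels].
    Row-major flattening, then a swap of the channel and position axes."""
    channels, height, width = len(x), len(x[0]), len(x[0][0])
    # Like view(channels, -1): each channel becomes a flat row-major list
    flat = [[x[c][i][j] for i in range(height) for j in range(width)]
            for c in range(channels)]
    # Like permute(1, 0): one feature vector per position
    return [[flat[c][p] for c in range(channels)]
            for p in range(height * width)]

x = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]  # 2 channels, 2x2 pixels
seq = to_sequence(x)  # [[1, 5], [2, 6], [3, 7], [4, 8]]
```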
|
<div class='section' id='section-38'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-38'>#</a> |
|
</div> |
|
<p>Get the concatenated query, key, and values, and reshape them to <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="mi">3</span> <span class="o">*</span> <span class="n">d_k</span><span class="p">]</span></code>
|
</p> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">187</span> <span class="n">qkv</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">projection</span><span class="p">(</span><span class="n">x</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="mi">3</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span><span class="p">)</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-39'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-39'>#</a> |
|
</div> |
|
<p>Split query, key, and values. Each of them will have shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">d_k</span><span class="p">]</span></code> |
|
</p> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">189</span> <span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">chunk</span><span class="p">(</span><span class="n">qkv</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div> |
|
</div> |
|
</div> |
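<p>Because the projection output is viewed as <code>[batch_size, seq, n_heads, 3 * d_k]</code> before <code>torch.chunk(qkv, 3, dim=-1)</code>, each head reads its query, key, and value from one contiguous block of <code>3 * d_k</code> features: for head <code>h</code>, the query comes from the first <code>d_k</code> features of that block, the key from the next <code>d_k</code>, and the value from the last <code>d_k</code>. A pure-Python sketch of this index bookkeeping for one token; <code>split_qkv</code> is a hypothetical helper.</p>

```python
def split_qkv(features, n_heads, d_k):
    """features: flat list of length n_heads * 3 * d_k for one token.
    Returns per-head (q, k, v) lists, matching view(n_heads, 3 * d_k)
    followed by chunk(3) along the last dimension."""
    assert len(features) == n_heads * 3 * d_k
    heads = []
    for h in range(n_heads):
        block = features[h * 3 * d_k:(h + 1) * 3 * d_k]
        q, k, v = block[:d_k], block[d_k:2 * d_k], block[2 * d_k:]
        heads.append((q, k, v))
    return heads

heads = split_qkv(list(range(12)), n_heads=2, d_k=2)
# head 0: q=[0, 1], k=[2, 3], v=[4, 5]; head 1: q=[6, 7], k=[8, 9], v=[10, 11]
```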
|
<div class='section' id='section-40'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-40'>#</a> |
|
</div> |
|
<p>Calculate scaled dot-product <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.633028em;vertical-align:-0.538em;"></span><span class="mord coloredeq eqd" style=""><span class="mord" style=""><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.095028em;"><span style="top:-2.5864385em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord sqrt mtight" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8622307142857143em;"><span class="svg-align" style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mtight" style="padding-left:0.833em"><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqi" style="">d</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span style="top:-2.3487714285714287em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.03148em">k</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15122857142857138em;"><span></span></span></span></span></span></span></span></span><span style="top:-2.8222307142857144em;"><span class="pstrut" style="height:3em;"></span><span class="hide-tail mtight" style="min-width:0.853em;height:1.08em"><svg height="1.08em" preserveaspectratio="xMinYMin slice" viewbox="0 0 400000 1080" width="400em" xmlns="http://www.w3.org/2000/svg"><path d="M95,702 |
|
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14 |
|
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54 |
|
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10 |
|
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429 |
|
c69,-144,104.5,-217.7,106.5,-221 |
|
l0 -0 |
|
c5.3,-9.3,12,-14,20,-14 |
|
H400000v40H845.2724 |
|
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7 |
|
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z |
|
M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.17776928571428574em;"><span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em"></span></span><span style="top:-3.446108em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">Q</span><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.07153em">K</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9270285714285713em;"><span style="top:-2.931em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style="">⊤</span></span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.538em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span></span></span> </p> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">191</span> <span class="n">attn</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">'bihd,bjhd->bijh'</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span></pre></div> |
|
</div> |
|
</div> |
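<p>The einsum <code>'bihd,bjhd-&gt;bijh'</code> computes, for every batch element <code>b</code>, head <code>h</code>, query position <code>i</code>, and key position <code>j</code>, the dot product of <code>q[b, i, h, :]</code> and <code>k[b, j, h, :]</code>, which is then multiplied by the scale. Written as explicit loops for a single batch element and head, as a pure-Python sketch with the hypothetical name <code>attention_scores</code>:</p>

```python
def attention_scores(q, k):
    """q, k: [seq][d_k] for one batch element and one head.
    Returns scores[i][j] = (q[i] . k[j]) * d_k ** -0.5."""
    d_k = len(q[0])
    scale = d_k ** -0.5
    return [[sum(a * b for a, b in zip(q_i, k_j)) * scale for k_j in k]
            for q_i in q]

q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 1.0], [2.0, 0.0]]
scores = attention_scores(q, k)  # approximately [[0.7071, 1.4142], [0.7071, 0.0]]
```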
|
<div class='section' id='section-41'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-41'>#</a> |
|
</div> |
|
<p>Softmax along the sequence dimension <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:3.0000299999999998em;vertical-align:-1.25003em;"></span><span class="mord"><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.6944399999999998em;"><span style="top:-2.20556em;margin-left:0em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">se</span><span class="mord mathnormal mtight" style="margin-right:0.03588em;">q</span></span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span><span class="mop"><span class="mord mathnormal">so</span><span class="mord mathnormal" style="margin-right:0.10764em;">f</span><span class="mord coloredeq eqj" style=""><span class="mord mathnormal" style="">t</span></span><span class="mord mathnormal">ma</span><span class="mord mathnormal">x</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:1.030548em;"><span></span></span></span></span></span></span><span class="mord"><span class="delimsizing size4">(</span></span><span class="mord coloredeq eqd" style=""><span class="mord" style=""><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.095028em;"><span style="top:-2.5864385em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord sqrt mtight" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8622307142857143em;"><span class="svg-align" style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mtight" 
style="padding-left:0.833em"><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqi" style="">d</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span style="top:-2.3487714285714287em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.03148em">k</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15122857142857138em;"><span></span></span></span></span></span></span></span></span><span style="top:-2.8222307142857144em;"><span class="pstrut" style="height:3em;"></span><span class="hide-tail mtight" style="min-width:0.853em;height:1.08em"><svg height="1.08em" preserveaspectratio="xMinYMin slice" viewbox="0 0 400000 1080" width="400em" xmlns="http://www.w3.org/2000/svg"><path d="M95,702 |
|
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14 |
|
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54 |
|
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10 |
|
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429 |
|
c69,-144,104.5,-217.7,106.5,-221 |
|
l0 -0 |
|
c5.3,-9.3,12,-14,20,-14 |
|
H400000v40H845.2724 |
|
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7 |
|
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z |
|
M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.17776928571428574em;"><span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em"></span></span><span style="top:-3.446108em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">Q</span><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.07153em">K</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9270285714285713em;"><span style="top:-2.931em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style="">⊤</span></span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.538em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span><span class="mord"><span class="delimsizing size4">)</span></span></span></span></span></span> </p> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">193</span> <span class="n">attn</span> <span class="o">=</span> <span class="n">attn</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span></pre></div> |
|
</div> |
|
</div> |
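<p><code>attn</code> has shape <code>[batch_size, seq, seq, n_heads]</code> with indices <code>b, i, j, h</code> from the einsum, so <code>softmax(dim=2)</code> normalizes over the key positions <code>j</code>: for each query <code>i</code> and head <code>h</code>, the weights over all keys sum to 1. A pure-Python sketch for one batch element and head, using the usual max-subtraction for numerical stability; <code>softmax_rows</code> is a hypothetical name.</p>

```python
import math

def softmax_rows(scores):
    """scores: [seq_q][seq_k]. Softmax over the key axis of each row,
    subtracting the row maximum first so exp() cannot overflow."""
    out = []
    for row in scores:
        m = max(row)
        exps = [math.exp(s - m) for s in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

weights = softmax_rows([[0.0, 0.0], [1000.0, 0.0]])  # [[0.5, 0.5], [1.0, 0.0]]
```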
|
<div class='section' id='section-42'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-42'>#</a> |
|
</div> |
|
<p>Multiply by the values, summing over the key positions </p>
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">195</span> <span class="n">res</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">'bijh,bjhd->bihd'</span><span class="p">,</span> <span class="n">attn</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span></pre></div> |
|
</div> |
|
</div> |
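<p>The einsum <code>'bijh,bjhd-&gt;bihd'</code> sums over the key positions <code>j</code>: the output for query <code>i</code> is the attention-weighted average of the value vectors, <code>res[i] = sum over j of attn[i, j] * v[j]</code>. A pure-Python sketch for one batch element and head, with the hypothetical name <code>weighted_values</code>:</p>

```python
def weighted_values(weights, v):
    """weights: [seq_q][seq_k] (each row sums to 1); v: [seq_k][d_k].
    Returns res[i][d] = sum over j of weights[i][j] * v[j][d]."""
    d_k = len(v[0])
    return [[sum(w_ij * v_j[d] for w_ij, v_j in zip(w_i, v)) for d in range(d_k)]
            for w_i in weights]

weights = [[0.5, 0.5], [1.0, 0.0]]
v = [[2.0, 0.0], [4.0, 8.0]]
res = weighted_values(weights, v)  # [[3.0, 4.0], [2.0, 0.0]]
```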
|
<div class='section' id='section-43'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-43'>#</a> |
|
</div> |
|
<p>Reshape to <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq</span><span class="p">,</span> <span class="n">n_heads</span> <span class="o">*</span> <span class="n">d_k</span><span class="p">]</span></code> |
|
</p> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">197</span> <span class="n">res</span> <span class="o">=</span> <span class="n">res</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span><span class="p">)</span></pre></div> |
|
</div> |
|
</div> |
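<p>Merging the head and feature axes places the heads side by side along the last dimension: head <code>h</code> ends up in the slice <code>[h * d_k, (h + 1) * d_k)</code>. This small sketch uses illustrative sizes. Note that <code>view</code> only works when the memory layout permits it; <code>reshape</code> falls back to a copy when it does not.</p>

```python
import torch

batch_size, seq, n_heads, d_k = 2, 6, 3, 8
res = torch.randn(batch_size, seq, n_heads, d_k)  # [b, seq, h, d_k]

# Merge the head and feature axes: [b, seq, h * d_k]
merged = res.view(batch_size, -1, n_heads * d_k)

# Each head occupies its own contiguous slice of the last axis
for h in range(n_heads):
    assert torch.equal(merged[:, :, h * d_k:(h + 1) * d_k], res[:, :, h, :])

print(merged.shape)  # torch.Size([2, 6, 24])
```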
|
<div class='section' id='section-44'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-44'>#</a> |
|
</div> |
|
<p>Transform to <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">]</span></code> |
|
</p> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">199</span> <span class="n">res</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">output</span><span class="p">(</span><span class="n">res</span><span class="p">)</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-45'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-45'>#</a> |
|
</div> |
|
<p>Add skip connection </p> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">202</span> <span class="n">res</span> <span class="o">+=</span> <span class="n">x</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-46'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-46'>#</a> |
|
</div> |
|
<p>Change to shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">,</span> <span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">]</span></code>
|
</p> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">205</span> <span class="n">res</span> <span class="o">=</span> <span class="n">res</span><span class="o">.</span><span class="n">permute</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">,</span> <span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">)</span></pre></div> |
|
</div> |
|
</div> |
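<p>The final <code>permute(0, 2, 1).view(...)</code> undoes the flattening of the spatial grid into a sequence. The sketch below assumes the forward pass began with <code>x.view(batch_size, n_channels, -1).permute(0, 2, 1)</code>; that step is not shown in this excerpt. It checks the round trip and shows why the permute cannot be skipped. Deterministic values make any reordering visible.</p>

```python
import torch

batch_size, n_channels, height, width = 2, 4, 3, 5
# Deterministic values so that any reordering is visible
x = torch.arange(batch_size * n_channels * height * width, dtype=torch.float32)
x = x.view(batch_size, n_channels, height, width)

# Assumed forward step: [b, c, h, w] -> [b, c, h*w] -> [b, h*w, c]
seq = x.view(batch_size, n_channels, -1).permute(0, 2, 1)

# The step from this section: [b, h*w, c] -> [b, c, h*w] -> [b, c, h, w]
back = seq.permute(0, 2, 1).view(batch_size, n_channels, height, width)
print(torch.equal(back, x))   # True

# Reinterpreting the sequence layout without the permute scrambles the data
wrong = seq.reshape(batch_size, n_channels, height, width)
print(torch.equal(wrong, x))  # False
```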
|
<div class='section' id='section-47'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-47'>#</a> |
|
</div> |
|
<p> </p> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">208</span> <span class="k">return</span> <span class="n">res</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-48'> |
|
<div class='docs doc-strings'> |
|
<div class='section-link'> |
|
<a href='#section-48'>#</a> |
|
</div> |
|
<h3>Down block</h3> |
|
<p>This combines <code class="highlight"><span></span><span class="n">ResidualBlock</span></code> |
|
and <code class="highlight"><span></span><span class="n">AttentionBlock</span></code> |
|
. These are used in the first half of U-Net at each resolution.</p> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">211</span><span class="k">class</span> <span class="nc">DownBlock</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-49'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-49'>#</a> |
|
</div> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">218</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">time_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">has_attn</span><span class="p">:</span> <span class="nb">bool</span><span class="p">):</span> |
|
<span class="lineno">219</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> |
|
<span class="lineno">220</span> <span class="bp">self</span><span class="o">.</span><span class="n">res</span> <span class="o">=</span> <span class="n">ResidualBlock</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">time_channels</span><span class="p">)</span> |
|
<span class="lineno">221</span> <span class="k">if</span> <span class="n">has_attn</span><span class="p">:</span> |
|
<span class="lineno">222</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span> <span class="o">=</span> <span class="n">AttentionBlock</span><span class="p">(</span><span class="n">out_channels</span><span class="p">)</span> |
|
<span class="lineno">223</span> <span class="k">else</span><span class="p">:</span> |
|
<span class="lineno">224</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Identity</span><span class="p">()</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-50'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-50'>#</a> |
|
</div> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">226</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">t</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span> |
|
<span class="lineno">227</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">res</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span> |
|
<span class="lineno">228</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> |
|
<span class="lineno">229</span> <span class="k">return</span> <span class="n">x</span></pre></div> |
|
</div> |
|
</div> |
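<p>When <code>has_attn</code> is <code>False</code>, <code>self.attn</code> is an <code>nn.Identity</code>. <code>forward</code> can then call <code>self.attn(x)</code> unconditionally, without a branch. A tiny check with an arbitrary tensor shows that <code>nn.Identity</code> returns its input unchanged and adds no parameters:</p>

```python
import torch
from torch import nn

attn = nn.Identity()
x = torch.randn(2, 64, 8, 8)  # arbitrary feature map

y = attn(x)
print(y is x)                                     # True: the same tensor object is returned
print(sum(p.numel() for p in attn.parameters()))  # 0: nothing to train
```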
|
<div class='section' id='section-51'> |
|
<div class='docs doc-strings'> |
|
<div class='section-link'> |
|
<a href='#section-51'>#</a> |
|
</div> |
|
<h3>Up block</h3> |
|
<p>This combines <code class="highlight"><span></span><span class="n">ResidualBlock</span></code> |
|
and <code class="highlight"><span></span><span class="n">AttentionBlock</span></code> |
|
. These are used in the second half of U-Net at each resolution.</p> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">232</span><span class="k">class</span> <span class="nc">UpBlock</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-52'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-52'>#</a> |
|
</div> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">239</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">time_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">has_attn</span><span class="p">:</span> <span class="nb">bool</span><span class="p">):</span> |
|
<span class="lineno">240</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-53'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-53'>#</a> |
|
</div> |
|
<p>The input has <code class="highlight"><span></span><span class="n">in_channels</span> <span class="o">+</span> <span class="n">out_channels</span></code> |
|
channels because we concatenate it with the output of the same resolution from the first half of the U-Net (the skip connection). </p>
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">243</span> <span class="bp">self</span><span class="o">.</span><span class="n">res</span> <span class="o">=</span> <span class="n">ResidualBlock</span><span class="p">(</span><span class="n">in_channels</span> <span class="o">+</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">time_channels</span><span class="p">)</span> |
|
<span class="lineno">244</span> <span class="k">if</span> <span class="n">has_attn</span><span class="p">:</span> |
|
<span class="lineno">245</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span> <span class="o">=</span> <span class="n">AttentionBlock</span><span class="p">(</span><span class="n">out_channels</span><span class="p">)</span> |
|
<span class="lineno">246</span> <span class="k">else</span><span class="p">:</span> |
|
<span class="lineno">247</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Identity</span><span class="p">()</span></pre></div> |
|
</div> |
|
</div> |
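<p>A sketch of the channel arithmetic behind <code>in_channels + out_channels</code>, with hypothetical sizes. It assumes the skip tensor saved from the first half has <code>out_channels</code> channels. The concatenation itself happens in the U-Net forward pass, which is not part of this excerpt.</p>

```python
import torch

# Hypothetical channel counts for one resolution level
batch_size, in_channels, out_channels, size = 2, 128, 64, 16

x = torch.randn(batch_size, in_channels, size, size)      # coming up from the lower level
skip = torch.randn(batch_size, out_channels, size, size)  # saved from the first half (assumed)

# Concatenating along the channel axis is why the ResidualBlock in UpBlock
# is built with in_channels + out_channels input channels
h = torch.cat((x, skip), dim=1)
print(h.shape)  # torch.Size([2, 192, 16, 16])
```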
|
<div class='section' id='section-54'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-54'>#</a> |
|
</div> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">249</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">t</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span> |
|
<span class="lineno">250</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">res</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span> |
|
<span class="lineno">251</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> |
|
<span class="lineno">252</span> <span class="k">return</span> <span class="n">x</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-55'> |
|
<div class='docs doc-strings'> |
|
<div class='section-link'> |
|
<a href='#section-55'>#</a> |
|
</div> |
|
<h3>Middle block</h3> |
|
<p>It combines a <code class="highlight"><span></span><span class="n">ResidualBlock</span></code> |
|
 and an <code class="highlight"><span></span><span class="n">AttentionBlock</span></code>
|
, followed by another <code class="highlight"><span></span><span class="n">ResidualBlock</span></code> |
|
. This block is applied at the lowest resolution of the U-Net.</p> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">255</span><span class="k">class</span> <span class="nc">MiddleBlock</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-56'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-56'>#</a> |
|
</div> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">263</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">time_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span> |
|
<span class="lineno">264</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> |
|
<span class="lineno">265</span> <span class="bp">self</span><span class="o">.</span><span class="n">res1</span> <span class="o">=</span> <span class="n">ResidualBlock</span><span class="p">(</span><span class="n">n_channels</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">,</span> <span class="n">time_channels</span><span class="p">)</span> |
|
<span class="lineno">266</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span> <span class="o">=</span> <span class="n">AttentionBlock</span><span class="p">(</span><span class="n">n_channels</span><span class="p">)</span> |
|
<span class="lineno">267</span> <span class="bp">self</span><span class="o">.</span><span class="n">res2</span> <span class="o">=</span> <span class="n">ResidualBlock</span><span class="p">(</span><span class="n">n_channels</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">,</span> <span class="n">time_channels</span><span class="p">)</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-57'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-57'>#</a> |
|
</div> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">269</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">t</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span> |
|
<span class="lineno">270</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">res1</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span> |
|
<span class="lineno">271</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> |
|
<span class="lineno">272</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">res2</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span> |
|
<span class="lineno">273</span> <span class="k">return</span> <span class="n">x</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-58'> |
|
<div class='docs doc-strings'> |
|
<div class='section-link'> |
|
<a href='#section-58'>#</a> |
|
</div> |
|
<h3>Scale up the feature map by <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord">2</span><span class="mord">×</span></span></span></span></span></h3> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">276</span><span class="k">class</span> <span class="nc">Upsample</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-59'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-59'>#</a> |
|
</div> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">281</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">):</span> |
|
<span class="lineno">282</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> |
|
<span class="lineno">283</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ConvTranspose2d</span><span class="p">(</span><span class="n">n_channels</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">,</span> <span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">),</span> <span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-60'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-60'>#</a> |
|
</div> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">285</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">t</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-61'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-61'>#</a> |
|
</div> |
|
<p><code class="highlight"><span></span><span class="n">t</span></code> |
|
is not used, but it is kept in the arguments so that the function signature matches that of <code class="highlight"><span></span><span class="n">ResidualBlock</span></code>
|
. </p> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">288</span> <span class="n">_</span> <span class="o">=</span> <span class="n">t</span> |
|
<span class="lineno">289</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div> |
|
</div> |
|
</div> |
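<p>PyTorch gives the output size of <code>nn.ConvTranspose2d</code> as <code>(H_in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1</code>. With <code>kernel_size=4</code>, <code>stride=2</code> and <code>padding=1</code> this simplifies to <code>2 * H_in</code> for any input size. The following sketch checks this for a few sizes, using an arbitrary channel count:</p>

```python
import torch
from torch import nn

n_channels = 8  # hypothetical channel count
conv = nn.ConvTranspose2d(n_channels, n_channels, (4, 4), (2, 2), (1, 1))

out_sizes = {}
for size in (4, 7, 16):
    x = torch.randn(1, n_channels, size, size)
    # (size - 1) * 2 - 2 * 1 + (4 - 1) + 1 == 2 * size
    out_sizes[size] = tuple(conv(x).shape[-2:])

print(out_sizes)  # {4: (8, 8), 7: (14, 14), 16: (32, 32)}
```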
|
<div class='section' id='section-62'> |
|
<div class='docs doc-strings'> |
|
<div class='section-link'> |
|
<a href='#section-62'>#</a> |
|
</div> |
|
<h3>Scale down the feature map by <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.190108em;vertical-align:-0.345em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.845108em;"><span style="top:-2.6550000000000002em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">2</span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.394em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.345em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mord">×</span></span></span></span></span></h3> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">292</span><span class="k">class</span> <span class="nc">Downsample</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-63'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-63'>#</a> |
|
</div> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">297</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">):</span> |
|
<span class="lineno">298</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> |
|
<span class="lineno">299</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">n_channels</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">,</span> <span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-64'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-64'>#</a> |
|
</div> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">301</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">t</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-65'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-65'>#</a> |
|
</div> |
|
<p><code class="highlight"><span></span><span class="n">t</span></code> |
|
is not used, but it is kept in the arguments so that the function signature matches that of <code class="highlight"><span></span><span class="n">ResidualBlock</span></code>
|
. </p> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">304</span> <span class="n">_</span> <span class="o">=</span> <span class="n">t</span> |
|
<span class="lineno">305</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div> |
|
</div> |
|
</div> |
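<p>For <code>nn.Conv2d</code> the output size is <code>floor((H_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride) + 1</code>. With <code>kernel_size=3</code>, <code>stride=2</code> and <code>padding=1</code> this becomes <code>floor((H_in - 1) / 2) + 1</code>. Even sizes are halved exactly, but odd sizes round up, so a following <code>Upsample</code> does not restore the original size. The sketch below uses an arbitrary channel count and pairs the two layers to show the effect:</p>

```python
import torch
from torch import nn

n_channels = 8  # hypothetical channel count
down = nn.Conv2d(n_channels, n_channels, (3, 3), (2, 2), (1, 1))
up = nn.ConvTranspose2d(n_channels, n_channels, (4, 4), (2, 2), (1, 1))

results = {}
for size in (16, 7):
    x = torch.randn(1, n_channels, size, size)
    d = down(x)  # floor((size - 1) / 2) + 1
    u = up(d)    # exactly twice the downsampled size
    results[size] = (d.shape[-1], u.shape[-1])

print(results)  # {16: (8, 16), 7: (4, 8)}
```

<p>The mismatch for odd sizes (7 becomes 4, then 8) suggests that image sides should stay even at every downsampled level. Otherwise the skip connections at that level could not be concatenated. The excerpt does not state this requirement, so treat it as an inference.</p>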
|
<div class='section' id='section-66'> |
|
<div class='docs doc-strings'> |
|
<div class='section-link'> |
|
<a href='#section-66'>#</a> |
|
</div> |
|
<h2>U-Net</h2> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">308</span><span class="k">class</span> <span class="nc">UNet</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-67'> |
|
<div class='docs doc-strings'> |
|
<div class='section-link'> |
|
<a href='#section-67'>#</a> |
|
</div> |
|
<ul><li><code class="highlight"><span></span><span class="n">image_channels</span></code> |
|
is the number of channels in the image. <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">3</span></span></span></span></span> for RGB. </li> |
|
<li><code class="highlight"><span></span><span class="n">n_channels</span></code> |
|
is the number of channels in the initial feature map that the image is projected into </li>
|
<li><code class="highlight"><span></span><span class="n">ch_mults</span></code> |
|
is the list of channel multipliers at each resolution. The number of channels is <code class="highlight"><span></span><span class="n">ch_mults</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">*</span> <span class="n">n_channels</span></code>
|
</li> |
|
<li><code class="highlight"><span></span><span class="n">is_attn</span></code> |
|
is a list of booleans that indicate whether to use attention at each resolution </li> |
|
<li><code class="highlight"><span></span><span class="n">n_blocks</span></code> |
|
is the number of blocks (<code class="highlight"><span></span><span class="n">DownBlock</span></code> / <code class="highlight"><span></span><span class="n">UpBlock</span></code>)
|
at each resolution</li></ul> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">313</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">image_channels</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">3</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">64</span><span class="p">,</span> |
|
<span class="lineno">314</span> <span class="n">ch_mults</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="o">...</span><span class="p">],</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">4</span><span class="p">),</span> |
|
<span class="lineno">315</span> <span class="n">is_attn</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="nb">bool</span><span class="p">,</span> <span class="o">...</span><span class="p">],</span> <span class="n">List</span><span class="p">[</span><span class="nb">bool</span><span class="p">]]</span> <span class="o">=</span> <span class="p">(</span><span class="kc">False</span><span class="p">,</span> <span class="kc">False</span><span class="p">,</span> <span class="kc">True</span><span class="p">,</span> <span class="kc">True</span><span class="p">),</span>
|
<span class="lineno">316</span> <span class="n">n_blocks</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">2</span><span class="p">):</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-68'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-68'>#</a> |
|
</div> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">324</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-69'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-69'>#</a> |
|
</div> |
|
<p>Number of resolutions </p> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">327</span> <span class="n">n_resolutions</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">ch_mults</span><span class="p">)</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-70'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-70'>#</a> |
|
</div> |
|
<p>Project image into feature map </p> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">330</span> <span class="bp">self</span><span class="o">.</span><span class="n">image_proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">image_channels</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">padding</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span></pre></div> |
|
</div> |
|
</div> |
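<p>A 3×3 convolution with padding 1 and the default stride of 1 keeps the spatial size. The projection only changes the channel count from <code>image_channels</code> to <code>n_channels</code>. The sketch below checks this with the default values and a hypothetical batch of 32×32 RGB images:</p>

```python
import torch
from torch import nn

image_channels, n_channels = 3, 64
image_proj = nn.Conv2d(image_channels, n_channels, kernel_size=(3, 3), padding=(1, 1))

x = torch.randn(2, image_channels, 32, 32)  # hypothetical batch of 32x32 RGB images
print(image_proj(x).shape)  # torch.Size([2, 64, 32, 32])
```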
|
<div class='section' id='section-71'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-71'>#</a> |
|
</div> |
|
<p>Time embedding layer. Time embedding has <code class="highlight"><span></span><span class="n">n_channels</span> <span class="o">*</span> <span class="mi">4</span></code> |
|
channels </p> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">333</span> <span class="bp">self</span><span class="o">.</span><span class="n">time_emb</span> <span class="o">=</span> <span class="n">TimeEmbedding</span><span class="p">(</span><span class="n">n_channels</span> <span class="o">*</span> <span class="mi">4</span><span class="p">)</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-72'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-72'>#</a> |
|
</div> |
|
<h4>First half of U-Net - decreasing resolution</h4> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">336</span> <span class="n">down</span> <span class="o">=</span> <span class="p">[]</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-73'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-73'>#</a> |
|
</div> |
|
<p>Number of channels </p> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">338</span> <span class="n">out_channels</span> <span class="o">=</span> <span class="n">in_channels</span> <span class="o">=</span> <span class="n">n_channels</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-74'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-74'>#</a> |
|
</div> |
|
<p>For each resolution </p> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">340</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_resolutions</span><span class="p">):</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-75'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-75'>#</a> |
|
</div> |
|
<p>Number of output channels at this resolution </p> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">342</span> <span class="n">out_channels</span> <span class="o">=</span> <span class="n">in_channels</span> <span class="o">*</span> <span class="n">ch_mults</span><span class="p">[</span><span class="n">i</span><span class="p">]</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-76'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-76'>#</a> |
|
</div> |
|
<p>Add <code class="highlight"><span></span><span class="n">n_blocks</span></code>

down blocks at this resolution </p>
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">344</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_blocks</span><span class="p">):</span> |
|
<span class="lineno">345</span> <span class="n">down</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">DownBlock</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">n_channels</span> <span class="o">*</span> <span class="mi">4</span><span class="p">,</span> <span class="n">is_attn</span><span class="p">[</span><span class="n">i</span><span class="p">]))</span> |
|
<span class="lineno">346</span> <span class="n">in_channels</span> <span class="o">=</span> <span class="n">out_channels</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-77'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-77'>#</a> |
|
</div> |
|
<p>Downsample at all resolutions except the last </p>
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">348</span> <span class="k">if</span> <span class="n">i</span> <span class="o"><</span> <span class="n">n_resolutions</span> <span class="o">-</span> <span class="mi">1</span><span class="p">:</span> |
|
<span class="lineno">349</span> <span class="n">down</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">Downsample</span><span class="p">(</span><span class="n">in_channels</span><span class="p">))</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-78'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-78'>#</a> |
|
</div> |
|
<p>Combine the set of modules </p> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">352</span> <span class="bp">self</span><span class="o">.</span><span class="n">down</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ModuleList</span><span class="p">(</span><span class="n">down</span><span class="p">)</span></pre></div> |
|
</div> |
|
</div> |
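<p>The loop above only tracks channel counts, so it can be replayed without PyTorch. The hypothetical helper below records each module as a <code>(kind, in_channels, out_channels)</code> tuple instead of constructing it. The example values <code>n_channels=64</code>, <code>ch_mults=(1, 2, 2, 4)</code> and <code>n_blocks=2</code> are chosen for illustration. <code>Downsample</code> is assumed to keep the channel count because it takes a single channel argument.</p>

```python
from typing import List, Sequence, Tuple

Module = Tuple[str, int, int]  # (kind, in_channels, out_channels)


def down_path(n_channels: int, ch_mults: Sequence[int], n_blocks: int) -> Tuple[List[Module], int]:
    """Replay the loop that builds self.down; return the modules and the bottom width."""
    modules: List[Module] = []
    out_channels = in_channels = n_channels
    n_resolutions = len(ch_mults)
    for i in range(n_resolutions):
        # Number of output channels at this resolution
        out_channels = in_channels * ch_mults[i]
        for _ in range(n_blocks):
            modules.append(("DownBlock", in_channels, out_channels))
            in_channels = out_channels
        # Downsample at every resolution except the last (channels assumed unchanged)
        if i < n_resolutions - 1:
            modules.append(("Downsample", in_channels, in_channels))
    return modules, out_channels


modules, bottom = down_path(n_channels=64, ch_mults=(1, 2, 2, 4), n_blocks=2)
for m in modules:
    print(m)
print("middle block channels:", bottom)  # 1024
```

With these values the first half has eight down blocks and three downsampling layers, and the middle block receives 64 * 1 * 2 * 2 * 4 = 1024 channels.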
|
<div class='section' id='section-79'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-79'>#</a> |
|
</div> |
|
<p>Middle block </p> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">355</span> <span class="bp">self</span><span class="o">.</span><span class="n">middle</span> <span class="o">=</span> <span class="n">MiddleBlock</span><span class="p">(</span><span class="n">out_channels</span><span class="p">,</span> <span class="n">n_channels</span> <span class="o">*</span> <span class="mi">4</span><span class="p">)</span></pre></div>
|
</div> |
|
</div> |
|
<div class='section' id='section-80'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-80'>#</a> |
|
</div> |
|
<h4>Second half of U-Net - increasing resolution</h4> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">358</span> <span class="n">up</span> <span class="o">=</span> <span class="p">[]</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-81'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-81'>#</a> |
|
</div> |
|
<p>Number of channels </p> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">360</span> <span class="n">in_channels</span> <span class="o">=</span> <span class="n">out_channels</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-82'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-82'>#</a> |
|
</div> |
|
<p>For each resolution </p> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">362</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">reversed</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">n_resolutions</span><span class="p">)):</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-83'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-83'>#</a> |
|
</div> |
|
<p>Add <code class="highlight"><span></span><span class="n">n_blocks</span></code>

up blocks at the same resolution </p>
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">364</span> <span class="n">out_channels</span> <span class="o">=</span> <span class="n">in_channels</span> |
|
<span class="lineno">365</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_blocks</span><span class="p">):</span> |
|
<span class="lineno">366</span> <span class="n">up</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">UpBlock</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">n_channels</span> <span class="o">*</span> <span class="mi">4</span><span class="p">,</span> <span class="n">is_attn</span><span class="p">[</span><span class="n">i</span><span class="p">]))</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-84'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-84'>#</a> |
|
</div> |
|
<p>Final block to reduce the number of channels </p> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">368</span> <span class="n">out_channels</span> <span class="o">=</span> <span class="n">in_channels</span> <span class="o">//</span> <span class="n">ch_mults</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> |
|
<span class="lineno">369</span> <span class="n">up</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">UpBlock</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">n_channels</span> <span class="o">*</span> <span class="mi">4</span><span class="p">,</span> <span class="n">is_attn</span><span class="p">[</span><span class="n">i</span><span class="p">]))</span> |
|
<span class="lineno">370</span> <span class="n">in_channels</span> <span class="o">=</span> <span class="n">out_channels</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-85'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-85'>#</a> |
|
</div> |
|
<p>Upsample at all resolutions except the last </p>
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">372</span> <span class="k">if</span> <span class="n">i</span> <span class="o">></span> <span class="mi">0</span><span class="p">:</span> |
|
<span class="lineno">373</span> <span class="n">up</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">Upsample</span><span class="p">(</span><span class="n">in_channels</span><span class="p">))</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-86'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-86'>#</a> |
|
</div> |
|
<p>Combine the set of modules </p> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">376</span> <span class="bp">self</span><span class="o">.</span><span class="n">up</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ModuleList</span><span class="p">(</span><span class="n">up</span><span class="p">)</span></pre></div> |
|
</div> |
|
</div> |
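<p>The second half can be replayed the same way. This hypothetical helper starts from the width the middle block produces and mirrors the loop above. It uses the same illustrative settings (a bottom width of 1024, <code>ch_mults=(1, 2, 2, 4)</code>, <code>n_blocks=2</code>) and records modules as tuples rather than building them.</p>

```python
from typing import List, Sequence, Tuple

Module = Tuple[str, int, int]  # (kind, in_channels, out_channels)


def up_path(bottom_channels: int, ch_mults: Sequence[int], n_blocks: int) -> Tuple[List[Module], int]:
    """Replay the loop that builds self.up; return the modules and the final width."""
    modules: List[Module] = []
    in_channels = bottom_channels
    n_resolutions = len(ch_mults)
    for i in reversed(range(n_resolutions)):
        # n_blocks at the same resolution keep the channel count
        out_channels = in_channels
        for _ in range(n_blocks):
            modules.append(("UpBlock", in_channels, out_channels))
        # Final block at this resolution divides the channels by ch_mults[i]
        out_channels = in_channels // ch_mults[i]
        modules.append(("UpBlock", in_channels, out_channels))
        in_channels = out_channels
        # Upsample at every resolution except the last
        if i > 0:
            modules.append(("Upsample", in_channels, in_channels))
    return modules, in_channels


modules, final_channels = up_path(bottom_channels=1024, ch_mults=(1, 2, 2, 4), n_blocks=2)
print(len(modules), final_channels)  # 15 64
```

Each resolution ends with an integer division by the same <code>ch_mults[i]</code> that the first half multiplied by. The second half therefore ends at exactly <code>n_channels</code> channels, which is why <code>self.norm</code> is built with <code>n_channels</code> while <code>self.final</code> uses <code>in_channels</code> as its input width.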
|
<div class='section' id='section-87'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-87'>#</a> |
|
</div> |
|
<p>Final normalization and convolution layer </p> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">379</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">GroupNorm</span><span class="p">(</span><span class="mi">8</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">)</span> |
|
<span class="lineno">380</span> <span class="bp">self</span><span class="o">.</span><span class="n">act</span> <span class="o">=</span> <span class="n">Swish</span><span class="p">()</span> |
|
<span class="lineno">381</span> <span class="bp">self</span><span class="o">.</span><span class="n">final</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">image_channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">padding</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span></pre></div> |
|
</div> |
|
</div> |
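<p><code>nn.GroupNorm(8, n_channels)</code> splits the channels into 8 groups. It normalizes each group using a mean and variance shared across the group's channels and spatial positions. The plain-Python sketch below does this for one sample. It omits the per-channel affine scale and shift, which PyTorch initializes to 1 and 0, and the name <code>group_norm</code> is illustrative only.</p>

```python
import math
from typing import List


def group_norm(x: List[List[float]], num_groups: int, eps: float = 1e-5) -> List[List[float]]:
    """Group normalization of a single sample, without the learnable affine part.

    x[c][p] is the value of channel c at flattened spatial position p.
    """
    num_channels = len(x)
    if num_channels % num_groups != 0:
        raise ValueError("num_channels must be divisible by num_groups")
    per_group = num_channels // num_groups
    out: List[List[float]] = []
    for g in range(num_groups):
        group = x[g * per_group:(g + 1) * per_group]
        # Statistics are shared across all channels and positions in the group
        values = [v for channel in group for v in channel]
        mean = sum(values) / len(values)
        var = sum((v - mean) ** 2 for v in values) / len(values)  # biased variance
        inv_std = 1.0 / math.sqrt(var + eps)
        out.extend([[(v - mean) * inv_std for v in channel] for channel in group])
    return out
```

With 8 groups, <code>n_channels</code> has to be a multiple of 8. PyTorch raises an error otherwise.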
|
<div class='section' id='section-88'> |
|
<div class='docs doc-strings'> |
|
<div class='section-link'> |
|
<a href='#section-88'>#</a> |
|
</div> |
|
<ul><li><code class="highlight"><span></span><span class="n">x</span></code> |
|
has shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">image_channels</span><span class="p">,</span> <span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">]</span></code>
|
</li> |
|
<li><code class="highlight"><span></span><span class="n">t</span></code> |
|
has shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">]</span></code> |
|
</li></ul> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">383</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">t</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-89'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-89'>#</a> |
|
</div> |
|
<p>Get time-step embeddings </p> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">390</span> <span class="n">t</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">time_emb</span><span class="p">(</span><span class="n">t</span><span class="p">)</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-90'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-90'>#</a> |
|
</div> |
|
<p>Get image projection </p> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">393</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">image_proj</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-91'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-91'>#</a> |
|
</div> |
|
<p><code class="highlight"><span></span><span class="n">h</span></code>

stores the image projection and the output of every module in the first half, for use as skip connections </p>
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">396</span> <span class="n">h</span> <span class="o">=</span> <span class="p">[</span><span class="n">x</span><span class="p">]</span></pre></div> |
|
</div> |
|
</div> |
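<p>Every module in <code>self.down</code> pushes one entry onto <code>h</code>, and every module in <code>self.up</code> that is not an <code>Upsample</code> pops one. With <code>R</code> resolutions and <code>B</code> blocks, <code>h</code> therefore holds <code>1 + R * B + (R - 1) = R * (B + 1)</code> entries, exactly one per up block. The hypothetical, PyTorch-free check below confirms this using only module kinds.</p>

```python
from typing import List, Tuple


def module_kinds(n_resolutions: int, n_blocks: int) -> Tuple[List[str], List[str]]:
    """Kinds of modules in self.down and self.up, in the order __init__ appends them."""
    down: List[str] = []
    for i in range(n_resolutions):
        down += ["DownBlock"] * n_blocks
        if i < n_resolutions - 1:
            down.append("Downsample")
    up: List[str] = []
    for i in reversed(range(n_resolutions)):
        up += ["UpBlock"] * (n_blocks + 1)  # n_blocks blocks plus the channel-reducing one
        if i > 0:
            up.append("Upsample")
    return down, up


def skips_balance(n_resolutions: int, n_blocks: int) -> bool:
    """True if every UpBlock gets one skip connection and none are left over."""
    down, up = module_kinds(n_resolutions, n_blocks)
    h = ["image_proj"] + down  # h = [x], then one push per module in self.down
    for m in up:
        if m != "Upsample":
            h.pop()  # raises IndexError if the stack ran out
    return not h


print(skips_balance(4, 2))  # True
```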
|
<div class='section' id='section-92'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-92'>#</a> |
|
</div> |
|
<p>First half of U-Net </p> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">398</span> <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">down</span><span class="p">:</span> |
|
<span class="lineno">399</span> <span class="n">x</span> <span class="o">=</span> <span class="n">m</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span> |
|
<span class="lineno">400</span> <span class="n">h</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-93'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-93'>#</a> |
|
</div> |
|
<p>Middle (bottom) </p> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">403</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">middle</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-94'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-94'>#</a> |
|
</div> |
|
<p>Second half of U-Net </p> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">406</span> <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">up</span><span class="p">:</span> |
|
<span class="lineno">407</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">Upsample</span><span class="p">):</span> |
|
<span class="lineno">408</span> <span class="n">x</span> <span class="o">=</span> <span class="n">m</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span> |
|
<span class="lineno">409</span> <span class="k">else</span><span class="p">:</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-95'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-95'>#</a> |
|
</div> |
|
<p>Pop the matching skip connection from the first half of the U-Net and concatenate it along the channel dimension </p>
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">411</span> <span class="n">s</span> <span class="o">=</span> <span class="n">h</span><span class="o">.</span><span class="n">pop</span><span class="p">()</span> |
|
<span class="lineno">412</span> <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">x</span><span class="p">,</span> <span class="n">s</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-96'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-96'>#</a> |
|
</div> |
|
<p>Pass the concatenated tensor through the up block </p>
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">414</span> <span class="n">x</span> <span class="o">=</span> <span class="n">m</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span></pre></div> |
|
</div> |
|
</div> |
|
<div class='section' id='section-97'> |
|
<div class='docs'> |
|
<div class='section-link'> |
|
<a href='#section-97'>#</a> |
|
</div> |
|
<p>Final normalization and convolution </p> |
|
|
|
</div> |
|
<div class='code'> |
|
<div class="highlight"><pre><span class="lineno">417</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">final</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">act</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">x</span><span class="p">)))</span></pre></div> |
|
</div> |
|
</div> |
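<p>To see how the pieces of <code>forward</code> fit together, the hypothetical sketch below traces <code>(channels, height, width)</code> for one sample through the whole network. It also records how many channels each <code>torch.cat((x, s), dim=1)</code> produces. The sketch assumes the following, none of which is shown in this section:</p>
<ul>
<li>down, up and middle blocks keep height and width;</li>
<li><code>Downsample</code> halves height and width;</li>
<li><code>Upsample</code> doubles height and width.</li>
</ul>

```python
from typing import List, Sequence, Tuple

Shape = Tuple[int, int, int]  # (channels, height, width) of one sample


def trace_forward(image_channels: int, height: int, width: int, n_channels: int,
                  ch_mults: Sequence[int], n_blocks: int) -> Tuple[Shape, List[int]]:
    """Follow shapes through UNet.forward; return (output shape, concat widths)."""
    n_resolutions = len(ch_mults)
    # image_proj: 3x3 conv with padding 1 keeps height and width
    c, hgt, wid = n_channels, height, width
    h: List[Shape] = [(c, hgt, wid)]
    # First half: every module's output is pushed onto h
    for i in range(n_resolutions):
        out_channels = c * ch_mults[i]
        for _ in range(n_blocks):
            c = out_channels  # DownBlock (assumed to keep height and width)
            h.append((c, hgt, wid))
        if i < n_resolutions - 1:
            hgt, wid = hgt // 2, wid // 2  # Downsample (assumed to halve)
            h.append((c, hgt, wid))
    # Middle block: assumed to keep the shape unchanged
    concat_widths: List[int] = []
    for i in reversed(range(n_resolutions)):
        block_outs = [c] * n_blocks + [c // ch_mults[i]]
        for out_channels in block_outs:
            skip = h.pop()
            if skip[1:] != (hgt, wid):
                raise ValueError("skip connection has a different spatial size")
            concat_widths.append(c + skip[0])  # torch.cat((x, s), dim=1)
            c = out_channels
        if i > 0:
            hgt, wid = hgt * 2, wid * 2  # Upsample (assumed to double)
    if h:
        raise ValueError("unused skip connections left in h")
    # GroupNorm and Swish keep the shape; self.final maps to image_channels
    return (image_channels, hgt, wid), concat_widths


out_shape, widths = trace_forward(3, 32, 32, n_channels=64, ch_mults=(1, 2, 2, 4), n_blocks=2)
print(out_shape)  # (3, 32, 32)
print(widths[:3])  # [2048, 2048, 1280]
```

Under these assumptions, every popped skip tensor has the same height and width as <code>x</code> at that point, and the output has the same shape as the input image. The concatenated widths are the sum of the current and skip channel counts.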
|
<div class='footer'> |
|
<a href="https://papers.labml.ai">Trending Research Papers</a> |
|
<a href="https://labml.ai">labml.ai</a> |
|
</div> |
|
</div> |
|
<script src="../../interactive.js?v=1"></script>
|
<script> |
|
function handleImages() { |
|
var images = document.querySelectorAll('p>img') |
|
|
|
for (var i = 0; i < images.length; ++i) { |
|
handleImage(images[i]) |
|
} |
|
} |
|
|
|
function handleImage(img) { |
|
img.parentElement.style.textAlign = 'center' |
|
|
|
var modal = document.createElement('div') |
|
modal.id = 'modal' |
|
|
|
var modalContent = document.createElement('div') |
|
modal.appendChild(modalContent) |
|
|
|
var modalImage = document.createElement('img') |
|
modalContent.appendChild(modalImage) |
|
|
|
var span = document.createElement('span') |
|
span.classList.add('close') |
|
span.textContent = 'x' |
|
modal.appendChild(span) |
|
|
|
img.onclick = function () { |
|
|
document.body.appendChild(modal) |
|
modalImage.src = img.src |
|
} |
|
|
|
span.onclick = function () { |
|
document.body.removeChild(modal) |
|
} |
|
} |
|
|
|
handleImages() |
|
</script> |
|
</body> |
|
</html> |