{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "id": "ClbDA89uqYc2",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 383
        },
        "outputId": "0dac9131-1d0c-416b-98fb-9dcc7e0dcf5b"
      },
      "outputs": [
        {
          "output_type": "error",
          "ename": "ModuleNotFoundError",
          "evalue": "No module named 'attention'",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mModuleNotFoundError\u001b[0m                       Traceback (most recent call last)",
            "\u001b[0;32m<ipython-input-8-58da732233b1>\u001b[0m in \u001b[0;36m<cell line: 4>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnn\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mfunctional\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mattention\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mSelfAttention\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      6\u001b[0m \u001b[0;32mclass\u001b[0m \u001b[0mVAE_AttentionBlock\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mModule\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'attention'",
            "",
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0;32m\nNOTE: If your import is failing due to a missing package, you can\nmanually install dependencies using either !pip or !apt.\n\nTo view examples of installing some common dependencies, click the\n\"Open Examples\" button below.\n\u001b[0;31m---------------------------------------------------------------------------\u001b[0m\n"
          ],
          "errorDetails": {
            "actions": [
              {
                "action": "open_url",
                "actionText": "Open Examples",
                "url": "/notebooks/snippets/importing_libraries.ipynb"
              }
            ]
          }
        }
      ],
      "source": [
        "import torch\n",
        "from torch import nn\n",
        "from torch.nn import functional as F\n",
        "from attention import SelfAttention\n",
        "\n",
        "class VAE_AttentionBlock(nn.Module):\n",
        "  def __init__(self, channels):\n",
        "    super.__init__()\n",
        "    self.groupnorm = nn.GroupNorm(32, channels)\n",
        "    self.attention = Attention(1, channels)\n",
        "  def forward(self, x):\n",
        "    residue=x\n",
        "    x=self.groupnorm(x)\n",
        "    n,c,h,w =x.shape\n",
        "    x=x.view(n,c,h*w)\n",
        "    x=x.transpose(-1,-2)\n",
        "    x=self.attention(x)\n",
        "    x=x.view((n,c,h,w))\n",
        "    x+=residue\n",
        "    return x\n",
        "class VAE_ResidualBlock(nn.Module):\n",
        "  def __init__(self, in_channels, out_channels):\n",
        "    super.__init__()\n",
        "    self.groupnorm_1= nn.GroupNorm(32, in_channels)\n",
        "    self.conv_1=nn.Conv2d(in_channels, out_channels , kernel_size=3, padding=1)\n",
        "\n",
        "    self.groupnorm_2= nn.GroupNorm(32, in_channels)\n",
        "    self.conv_2= nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)\n",
        "\n",
        "    if in_channels==out_channels:\n",
        "      self.residual_layer=nn.Identity()\n",
        "    else:\n",
        "      self.conv_2= nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=0)\n",
        "    def forward(self, x):\n",
        "      residue=x\n",
        "      x=self.groupnorm_1(x)\n",
        "      x=F.silu(x)\n",
        "      x=self.conv_2(x)\n",
        "      return x+self.residual_layer(residue)\n",
        "class VAE_Decoder(nn.Sequential):\n",
        "  def __init__(self):\n",
        "    super.__init__(\n",
        "        nn.Conv2d(4,4, kernel_size=1, padding=0),\n",
        "        nn.Conv2d(4, 512, kernel_size=3, padding=1),\n",
        "        VAE_ResidualBlock(512, 512),\n",
        "        VAE_AttentionBlock(512),\n",
        "        VAE_ResidualBlock(512, 512),\n",
        "        VAE_ResidualBlock(512, 512),\n",
        "        VAE_ResidualBlock(512, 512),\n",
        "        VAE_ResidualBlock(512, 512),\n",
        "        VAE_ResidualBlock(512, 512),\n",
        "\n",
        "        nn.Upsample(scale_factor=2),\n",
        "        nn.Conv2d(512, 512,kernel_size=3, padding=1),\n",
        "        VAE_ResidualBlock(512, 512),\n",
        "        VAE_ResidualBlock(512, 512),\n",
        "        VAE_ResidualBlock(512, 512),\n",
        "        nn.Upsample(scale_factor=2),\n",
        "        VAE_ResidualBlock(512, 256),\n",
        "        VAE_ResidualBlock(256, 256),\n",
        "        VAE_ResidualBlock(256, 256),\n",
        "        nn.Upsample(scale_factor=2),\n",
        "        nn.Conv2d(256, 256, kernel_size=3, padding=1),\n",
        "        VAE_ResidualBlock(256, 128),\n",
        "        VAE_ResidualBlock(128, 128),\n",
        "        VAE_ResidualBlock(128, 128),\n",
        "        nn.GroupNorm(32, 128),\n",
        "        nn.SiLU(),\n",
        "        nn.Conv2d(128, 3, kernel_size=3,padding=1)\n",
        "    )\n",
        "  def forward(self, x):\n",
        "    x/=0.18125\n",
        "    for module in self:\n",
        "      x=module(x)\n",
        "    return x\n",
        "\n",
        "\n"
      ]
    }
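    ,
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Quick shape check for VAE_Decoder -- a sketch added for this notebook, not\n",
        "# part of the original code. It assumes the classes above were defined\n",
        "# successfully (i.e. the SelfAttention import worked) and feeds a random\n",
        "# latent through the decoder. Each of the three nn.Upsample stages doubles the\n",
        "# spatial size, so a (1, 4, 32, 32) latent should decode to a (1, 3, 256, 256)\n",
        "# image tensor; a 64x64 latent would give 512x512.\n",
        "decoder = VAE_Decoder()\n",
        "latent = torch.randn(1, 4, 32, 32)\n",
        "with torch.no_grad():\n",
        "  decoded = decoder(latent)\n",
        "print('latent shape: ', tuple(latent.shape))\n",
        "print('decoded shape:', tuple(decoded.shape))\n"
      ]
    }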
  ]
}