# Setup (run once in a shell): pip install torch diffusers transformers datasets wandb
import torch
import torch.nn as nn
from torch.nn import functional as F


# Define a basic U-Net style model (you can scale this up for an XL model).
# Skip connections are additive, so encoder and decoder features must match
# in both channel count and spatial size.
class UNetModel(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, base_channels=64):
        super().__init__()

        # Downsampling is applied explicitly in forward() so the same conv
        # block can be reused on the decoder side without shrinking features.
        self.pool = nn.MaxPool2d(2)

        # Encoder (downsampling path)
        self.enc1 = self.conv_block(in_channels, base_channels)
        self.enc2 = self.conv_block(base_channels, base_channels * 2)
        self.enc3 = self.conv_block(base_channels * 2, base_channels * 4)

        # Middle (bottleneck)
        self.middle = self.conv_block(base_channels * 4, base_channels * 8)

        # Decoder (upsampling path)
        self.dec3 = self.conv_block(base_channels * 8, base_channels * 4)
        self.dec2 = self.conv_block(base_channels * 4, base_channels * 2)
        self.dec1 = self.conv_block(base_channels * 2, base_channels)

        # Final 1x1 projection to the output channels, with no activation so
        # the model can predict signed values (e.g. noise residuals).
        self.out = nn.Conv2d(base_channels, out_channels, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        # Two 3x3 convs with ReLU; resolution is unchanged (pooling/upsampling
        # happens in forward()).
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
        )

    def forward(self, x):
        # Encode (downsample): halve the resolution before each deeper block
        x1 = self.enc1(x)              # base_channels     @ H,   W
        x2 = self.enc2(self.pool(x1))  # base_channels * 2 @ H/2, W/2
        x3 = self.enc3(self.pool(x2))  # base_channels * 4 @ H/4, W/4

        # Middle block
        x_middle = self.middle(self.pool(x3))  # base_channels * 8 @ H/8, W/8

        # Decode (upsample) with additive skip connections from the encoder
        x3_dec = self.dec3(F.interpolate(x_middle, scale_factor=2))     # * 4 @ H/4
        x2_dec = self.dec2(F.interpolate(x3_dec + x3, scale_factor=2))  # * 2 @ H/2
        x1_dec = self.dec1(F.interpolate(x2_dec + x2, scale_factor=2))  # * 1 @ H

        return self.out(x1_dec)
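

# ---------------------------------------------------------------------------
# Illustrative usage (a minimal sketch, not part of the original snippet).
# First, a shape sanity check: assuming 3-channel 64x64 inputs, the U-Net
# should return a tensor of the same resolution.
# Second, one DDPM-style noise-prediction training step using diffusers'
# DDPMScheduler (the library installed above). Note that this toy U-Net does
# not condition on the timestep, which a real diffusion U-Net would; treat
# this purely as a structural example on random stand-in data.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from diffusers import DDPMScheduler

    model = UNetModel(in_channels=3, out_channels=3, base_channels=64)

    # Shape sanity check
    dummy = torch.randn(2, 3, 64, 64)      # (batch, channels, height, width)
    with torch.no_grad():
        print(model(dummy).shape)          # expected: torch.Size([2, 3, 64, 64])

    # One noise-prediction training step
    scheduler = DDPMScheduler(num_train_timesteps=1000)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    clean_images = torch.randn(2, 3, 64, 64)  # stand-in for a real image batch
    noise = torch.randn_like(clean_images)
    timesteps = torch.randint(
        0, scheduler.config.num_train_timesteps, (clean_images.shape[0],)
    )
    noisy_images = scheduler.add_noise(clean_images, noise, timesteps)

    pred_noise = model(noisy_images)          # predict the noise that was added
    loss = F.mse_loss(pred_noise, noise)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print(f"toy training-step loss: {loss.item():.4f}")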