AI_Text_to_Image / maindata
dghdgkl's picture
Create maindata
6a3f4a1 verified
# Setup (run in a shell, not inside Python):
#   pip install torch diffusers transformers datasets wandb
import torch
import torch.nn as nn
from torch.nn import functional as F
# Define a basic U-Net style model (you can scale this up for an XL model)
class UNetModel(nn.Module):
    """Small U-Net style encoder/decoder built from double-convolution blocks.

    The encoder runs three stages with a 2x2 max-pool between them, a middle
    block runs at 1/8 resolution, and the decoder upsamples back step by
    step, adding (not concatenating) the encoder feature map of matching
    resolution before each decoder block. The output has the same height and
    width as the input.

    Args:
        in_channels: Number of channels in the input tensor.
        out_channels: Number of channels in the output tensor.
        base_channels: Width of the first encoder stage; deeper stages use
            2x, 4x and 8x this value.

    Input is expected as ``(batch, in_channels, H, W)`` with ``H`` and ``W``
    of at least 8 (three successive halvings must leave at least one pixel).
    Sizes that are not multiples of 8 are supported because decoder inputs
    are resized to the exact size of the matching skip tensor.
    """

    def __init__(self, in_channels=3, out_channels=3, base_channels=64):
        super().__init__()
        # Downsample path. Blocks are built without their own pooling layer:
        # pooling happens in forward() so that the un-pooled activations can
        # be reused as skip connections at the right resolution.
        self.enc1 = self.conv_block(in_channels, base_channels, pool=False)
        self.enc2 = self.conv_block(base_channels, base_channels * 2, pool=False)
        self.enc3 = self.conv_block(base_channels * 2, base_channels * 4, pool=False)
        # Middle
        self.middle = self.conv_block(base_channels * 4, base_channels * 8, pool=False)
        # Upsample path. Channel counts mirror the encoder so the additive
        # skips in forward() line up.
        self.dec3 = self.conv_block(base_channels * 8, base_channels * 4, pool=False)
        self.dec2 = self.conv_block(base_channels * 4, base_channels * 2, pool=False)
        # NOTE(review): this block also ends in ReLU, so the model output is
        # non-negative. Confirm that matches the training target.
        self.dec1 = self.conv_block(base_channels * 2, out_channels, pool=False)

    def conv_block(self, in_channels, out_channels, *, pool=True):
        """Build a Conv3x3 -> ReLU -> Conv3x3 -> ReLU block.

        Args:
            in_channels: Channels expected by the first convolution.
            out_channels: Channels produced by both convolutions.
            pool: When true (the default), append ``nn.MaxPool2d(2)`` so the
                block halves the spatial size. The model itself passes
                ``pool=False`` and pools in ``forward``.

        Returns:
            An ``nn.Sequential`` containing the layers described above.
        """
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
        ]
        if pool:
            layers.append(nn.MaxPool2d(2))
        return nn.Sequential(*layers)

    @staticmethod
    def _resize_like(x: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
        """Nearest-neighbour resize of ``x`` to the height/width of ``ref``."""
        return F.interpolate(x, size=ref.shape[-2:], mode="nearest")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network.

        Args:
            x: Tensor of shape ``(batch, in_channels, H, W)``.

        Returns:
            Tensor of shape ``(batch, out_channels, H, W)``.
        """
        # Encode (Downsample): pool *between* stages, keep un-pooled outputs.
        x1 = self.enc1(x)
        x2 = self.enc2(F.max_pool2d(x1, 2))
        x3 = self.enc3(F.max_pool2d(x2, 2))
        # Middle block
        x_middle = self.middle(F.max_pool2d(x3, 2))
        # Decode (Upsample): resize to the skip tensor's size first, so the
        # additions below always see matching shapes.
        x3_dec = self.dec3(self._resize_like(x_middle, x3))
        x2_dec = self.dec2(self._resize_like(x3_dec + x3, x2))
        x1_dec = self.dec1(self._resize_like(x2_dec + x2, x))
        return x1_dec