remotewith's picture
Upload 3 files
ce4c34e verified
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader,Dataset
from PIL import Image
def double_convolution(in_channels, out_channels):
    """Build a block of two 3x3 convolutions, each followed by a ReLU.

    Both convolutions use ``padding=1``, so the spatial size of the input
    is preserved. The first convolution maps ``in_channels`` to
    ``out_channels``; the second keeps ``out_channels``.

    Args:
        in_channels: Channel count of the tensor fed to the block.
        out_channels: Channel count produced by both convolutions.

    Returns:
        An ``nn.Sequential`` with four layers: Conv2d, ReLU, Conv2d, ReLU.
    """
    layers = []
    # The first stage consumes the input channels, the second stage
    # consumes what the first one produced.
    for stage_in in (in_channels, out_channels):
        layers.append(
            nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1)
        )
        layers.append(nn.ReLU(inplace=True))
    return nn.Sequential(*layers)
class UNet(nn.Module):
    """U-Net style encoder/decoder producing a per-pixel output map.

    The encoder runs five double-convolution stages (64, 128, 256, 512 and
    1024 channels) separated by 2x2 max pooling. The decoder upsamples four
    times with stride-2 transposed convolutions; after each upsampling the
    result is concatenated with the encoder output of the same stage and
    passed through another double convolution. A final 1x1 convolution maps
    the 64 decoder channels to ``out_channels``. No activation is applied
    to the output.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the output tensor.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # One pooling layer is shared by all encoder stages (it has no
        # parameters).
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder.
        self.down_convolution_1 = double_convolution(in_channels, 64)
        self.down_convolution_2 = double_convolution(64, 128)
        self.down_convolution_3 = double_convolution(128, 256)
        self.down_convolution_4 = double_convolution(256, 512)
        self.down_convolution_5 = double_convolution(512, 1024)

        # Decoder. Each up_convolution_* takes twice the channels its
        # transposed convolution emits because the skip connection is
        # concatenated along the channel dimension first.
        self.up_transpose_1 = nn.ConvTranspose2d(
            in_channels=1024, out_channels=512, kernel_size=2, stride=2
        )
        self.up_convolution_1 = double_convolution(1024, 512)
        self.up_transpose_2 = nn.ConvTranspose2d(
            in_channels=512, out_channels=256, kernel_size=2, stride=2
        )
        self.up_convolution_2 = double_convolution(512, 256)
        self.up_transpose_3 = nn.ConvTranspose2d(
            in_channels=256, out_channels=128, kernel_size=2, stride=2
        )
        self.up_convolution_3 = double_convolution(256, 128)
        self.up_transpose_4 = nn.ConvTranspose2d(
            in_channels=128, out_channels=64, kernel_size=2, stride=2
        )
        self.up_convolution_4 = double_convolution(128, 64)

        self.out = nn.Conv2d(
            in_channels=64, out_channels=out_channels, kernel_size=1
        )

    @staticmethod
    def _match_size(upsampled, skip):
        """Zero-pad ``upsampled`` so its last two dims equal those of ``skip``.

        Max pooling floors odd sizes, so after the x2 upsampling a feature
        map can be one pixel smaller than the matching encoder output.
        Without padding, the channel-wise concatenation in ``forward``
        fails for such inputs. When the sizes already agree the tensor is
        returned unchanged.
        """
        pad_h = skip.size(-2) - upsampled.size(-2)
        pad_w = skip.size(-1) - upsampled.size(-1)
        if pad_h == 0 and pad_w == 0:
            return upsampled
        # F.pad takes the padding for the last dimension first:
        # (left, right, top, bottom).
        return nn.functional.pad(
            upsampled,
            [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2],
        )

    def forward(self, x):
        """Run the network on a batch.

        Args:
            x: Input tensor with ``in_channels`` channels in dimension 1.
                Height and width need not be multiples of 16, but sizes
                below 16 collapse to an empty map after the four poolings
                and are not supported.

        Returns:
            Tensor with ``out_channels`` channels and the same height and
            width as ``x``.
        """
        # Encoder: keep the output of every stage for the skip connections.
        skip_1 = self.down_convolution_1(x)
        skip_2 = self.down_convolution_2(self.max_pool2d(skip_1))
        skip_3 = self.down_convolution_3(self.max_pool2d(skip_2))
        skip_4 = self.down_convolution_4(self.max_pool2d(skip_3))
        x = self.down_convolution_5(self.max_pool2d(skip_4))

        # Decoder: upsample, align with the skip connection, concatenate
        # (skip first, as the trained weights expect) and refine.
        decoder_stages = (
            (self.up_transpose_1, self.up_convolution_1, skip_4),
            (self.up_transpose_2, self.up_convolution_2, skip_3),
            (self.up_transpose_3, self.up_convolution_3, skip_2),
            (self.up_transpose_4, self.up_convolution_4, skip_1),
        )
        for upsample, refine, skip in decoder_stages:
            upsampled = self._match_size(upsample(x), skip)
            x = refine(torch.cat([skip, upsampled], 1))

        return self.out(x)
class CustomDataset(Dataset):
    """Dataset of (image, mask) pairs read from two directories.

    The file names of each directory are sorted, and the i-th image is
    paired with the i-th mask. This assumes both directories hold the same
    number of files and that matching files sort to the same position
    (for example because they share a file name).

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the masks.
        transform: Optional callable applied to every loaded image. It is
            also applied to the mask unless ``mask_transform`` is given.
        mask_transform: Optional callable applied to every loaded mask
            instead of ``transform``. Defaults to ``None``, which keeps the
            behaviour of using ``transform`` for both.
    """

    def __init__(self, image_dir, mask_dir, transform=None, mask_transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.mask_transform = mask_transform
        # os.listdir() returns entries in arbitrary order; sort both lists
        # so that index-based pairing of image and mask is deterministic.
        self.image_filenames = sorted(os.listdir(image_dir))
        self.mask_filenames = sorted(os.listdir(mask_dir))

    def __len__(self):
        """Return the number of image files."""
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Load and return the ``(image, mask)`` pair at position ``idx``.

        The image is converted to RGB and the mask to single-channel
        grayscale ("L") before the transforms are applied.
        """
        image_path = os.path.join(self.image_dir, self.image_filenames[idx])
        mask_path = os.path.join(self.mask_dir, self.mask_filenames[idx])

        # convert() reads the pixel data and returns a new image, so the
        # underlying files can be closed right away instead of leaking a
        # handle per sample.
        with Image.open(image_path) as image_file:
            image = image_file.convert("RGB")
        with Image.open(mask_path) as mask_file:
            mask = mask_file.convert("L")

        if self.transform:
            image = self.transform(image)

        mask_transform = (
            self.transform if self.mask_transform is None else self.mask_transform
        )
        if mask_transform:
            mask = mask_transform(mask)

        return image, mask
def train_model(model, dataloader, criterion, optimizer, num_epochs=25):
    """Train ``model`` and print the mean batch loss after every epoch.

    Args:
        model: Module to train; it is put in training mode at the start of
            each epoch.
        dataloader: Iterable yielding ``(images, masks)`` batches.
        criterion: Callable computing the loss from ``(outputs, masks)``.
        optimizer: Optimizer updating the parameters of ``model``.
        num_epochs: Number of passes over ``dataloader``.

    Returns:
        A list with the mean batch loss of each epoch. An epoch in which
        the dataloader yields no batches contributes ``nan``.
    """
    # Move every batch to wherever the model lives so that a model placed
    # on an accelerator does not fail on CPU tensors from the dataloader.
    # For a CPU model with CPU batches this is a no-op.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    epoch_losses = []
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        # Count batches instead of calling len(dataloader): this also works
        # for iterables without __len__ and avoids dividing by zero.
        num_batches = 0
        for images, masks in dataloader:
            if device is not None:
                images = images.to(device)
                masks = masks.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1

        epoch_loss = running_loss / num_batches if num_batches else float("nan")
        epoch_losses.append(epoch_loss)
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')
    return epoch_losses
if __name__ == "__main__":
    # Dataset locations, relative to the current working directory.
    image_dir = "face-synthetics-glasses/train/images"
    mask_dir = "face-synthetics-glasses/train/masks"

    # Preprocessing handed to the dataset: resize to 256x256, then convert
    # to a tensor.
    transform = transforms.Compose(
        [
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
        ]
    )

    dataset = CustomDataset(image_dir, mask_dir, transform=transform)
    dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

    # Three input channels, one output channel; the loss takes the raw
    # network output (BCEWithLogitsLoss applies the sigmoid itself).
    model = UNet(in_channels=3, out_channels=1)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    print("moving ahead")

    # Training and checkpoint saving are currently disabled:
    # train_model(model, dataloader, criterion, optimizer, num_epochs=25)
    # torch.save(model.state_dict(), "base_bat_ball.pth")