remotewith's picture
Upload 3 files
ce4c34e verified
raw
history blame
4.88 kB
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader,Dataset
from PIL import Image
def double_convolution(in_channels, out_channels):
    """Return the classic U-Net conv block: two 3x3 convolutions, each
    followed by an in-place ReLU. ``padding=1`` keeps H and W unchanged,
    so only the channel count moves from ``in_channels`` to ``out_channels``.
    """
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
    )
class UNet(nn.Module):
    """U-Net encoder/decoder for dense prediction.

    Four down-sampling stages (channels 64 -> 128 -> 256 -> 512), a 1024-channel
    bottleneck, then four up-sampling stages that each concatenate the matching
    encoder feature map (skip connection) before convolving. A final 1x1 conv
    maps to ``out_channels`` logits. Spatial dims are preserved end to end, so
    inputs whose H and W are divisible by 16 round-trip exactly.
    """

    def __init__(self, in_channels, out_channels):
        super(UNet, self).__init__()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        # Encoder: feature width doubles at every stage.
        self.down_convolution_1 = double_convolution(in_channels, 64)
        self.down_convolution_2 = double_convolution(64, 128)
        self.down_convolution_3 = double_convolution(128, 256)
        self.down_convolution_4 = double_convolution(256, 512)
        self.down_convolution_5 = double_convolution(512, 1024)
        # Decoder: each transpose conv doubles H/W and halves channels; the
        # following double_convolution sees 2x channels because the skip
        # tensor is concatenated in.
        self.up_transpose_1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.up_convolution_1 = double_convolution(1024, 512)
        self.up_transpose_2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.up_convolution_2 = double_convolution(512, 256)
        self.up_transpose_3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.up_convolution_3 = double_convolution(256, 128)
        self.up_transpose_4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.up_convolution_4 = double_convolution(128, 64)
        # 1x1 projection to the requested number of output channels (logits).
        self.out = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        # Encoder pass: remember each pre-pool activation for its skip link.
        skips = []
        for encode in (self.down_convolution_1, self.down_convolution_2,
                       self.down_convolution_3, self.down_convolution_4):
            x = encode(x)
            skips.append(x)
            x = self.max_pool2d(x)
        x = self.down_convolution_5(x)
        # Decoder pass: upsample, then fuse with the matching (deepest-first)
        # skip tensor along the channel dimension.
        decoder_stages = (
            (self.up_transpose_1, self.up_convolution_1),
            (self.up_transpose_2, self.up_convolution_2),
            (self.up_transpose_3, self.up_convolution_3),
            (self.up_transpose_4, self.up_convolution_4),
        )
        for (upsample, fuse), skip in zip(decoder_stages, reversed(skips)):
            x = upsample(x)
            x = fuse(torch.cat([skip, x], 1))
        return self.out(x)
class CustomDataset(Dataset):
    """Paired image/mask dataset read from two parallel directories.

    Images and masks are matched positionally after sorting both directory
    listings, so corresponding files must sort into the same position
    (e.g. identical filenames in both directories).

    Args:
        image_dir: directory of RGB input images.
        mask_dir: directory of single-channel mask images.
        transform: optional callable applied to BOTH image and mask.
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # BUG FIX: os.listdir returns entries in arbitrary, filesystem-
        # dependent order, so the previous unsorted listings could silently
        # pair an image with the wrong mask. Sorting both makes the pairing
        # deterministic.
        self.image_filenames = sorted(os.listdir(image_dir))
        self.mask_filenames = sorted(os.listdir(mask_dir))
        # Fail fast on mismatched directories instead of an IndexError later.
        if len(self.image_filenames) != len(self.mask_filenames):
            raise ValueError(
                f"image/mask count mismatch: {len(self.image_filenames)} images "
                f"vs {len(self.mask_filenames)} masks"
            )

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Return the (image, mask) pair at ``idx`` (RGB image, L-mode mask)."""
        image_path = os.path.join(self.image_dir, self.image_filenames[idx])
        mask_path = os.path.join(self.mask_dir, self.mask_filenames[idx])
        image = Image.open(image_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")
        if self.transform:
            # NOTE(review): the same transform object is applied to image and
            # mask independently — fine for deterministic ops (Resize/ToTensor),
            # but random augmentations would de-synchronize the pair. Confirm
            # before adding random transforms.
            image = self.transform(image)
            mask = self.transform(mask)
        return image, mask
def train_model(model, dataloader, criterion, optimizer, num_epochs=25, device=None):
    """Run a standard supervised training loop.

    Args:
        model: the network to optimize (set to train mode each epoch).
        dataloader: yields ``(images, masks)`` batches.
        criterion: loss comparing model output against the target batch.
        optimizer: optimizer already bound to ``model.parameters()``.
        num_epochs: number of passes over ``dataloader``.
        device: optional torch device; when given, the model and every batch
            are moved there. ``None`` (the default) preserves the previous
            behavior of training wherever the tensors already live.

    Returns:
        List of per-epoch average losses. (Previously returned ``None``;
        callers that ignore the return value are unaffected.)
    """
    if device is not None:
        model = model.to(device)
    epoch_losses = []
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for images, masks in dataloader:
            if device is not None:
                images, masks = images.to(device), masks.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        # Guard against an empty dataloader (would otherwise divide by zero).
        avg_loss = running_loss / max(len(dataloader), 1)
        epoch_losses.append(avg_loss)
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')
    return epoch_losses
if __name__ == "__main__":
    # Resize everything to 256x256 and convert PIL images to float tensors
    # in [0, 1]; the identical transform is applied to images and masks.
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])

    image_dir = "face-synthetics-glasses/train/images"
    mask_dir = "face-synthetics-glasses/train/masks"

    dataset = CustomDataset(image_dir, mask_dir, transform=transform)
    dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

    # Binary segmentation setup: 3-channel RGB in, 1-channel logit map out,
    # hence BCE-with-logits as the criterion.
    model = UNet(3, 1)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    print("moving ahead")
    # Training and checkpointing are currently disabled:
    # train_model(model,dataloader,criterion,optimizer,num_epochs=25)
    # torch.save(model.state_dict(),"base_bat_ball.pth")