# U-Net image-segmentation training script.
# (Removed "Spaces: / Sleeping" page-scrape residue that was not valid Python.)
import os

import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
def double_convolution(in_channels, out_channels):
    """Two stacked 3x3 convolutions (padding 1), each followed by ReLU.

    Spatial size is preserved; channels go in_channels -> out_channels on
    the first conv and stay at out_channels on the second.
    """
    layers = [
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)
class UNet(nn.Module):
    """Classic U-Net: a 4-level encoder/decoder with skip connections.

    Args:
        in_channels: channels of the input image (e.g. 3 for RGB).
        out_channels: channels of the output map (e.g. 1 for a binary mask).

    NOTE(review): forward() assumes H and W are divisible by 16 so the four
    2x2 max-pools and the transposed convolutions line up exactly — confirm
    callers always resize inputs accordingly.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Shared 2x2 down-sampling used between encoder stages.
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: channel widths 64 -> 128 -> 256 -> 512 -> 1024.
        self.down_convolution_1 = double_convolution(in_channels, 64)
        self.down_convolution_2 = double_convolution(64, 128)
        self.down_convolution_3 = double_convolution(128, 256)
        self.down_convolution_4 = double_convolution(256, 512)
        self.down_convolution_5 = double_convolution(512, 1024)

        # Decoder: each stage upsamples 2x, then fuses the skip connection —
        # hence the doubled input width of every up-convolution.
        self.up_transpose_1 = nn.ConvTranspose2d(
            in_channels=1024, out_channels=512, kernel_size=2, stride=2)
        self.up_convolution_1 = double_convolution(1024, 512)
        self.up_transpose_2 = nn.ConvTranspose2d(
            in_channels=512, out_channels=256, kernel_size=2, stride=2)
        self.up_convolution_2 = double_convolution(512, 256)
        self.up_transpose_3 = nn.ConvTranspose2d(
            in_channels=256, out_channels=128, kernel_size=2, stride=2)
        self.up_convolution_3 = double_convolution(256, 128)
        self.up_transpose_4 = nn.ConvTranspose2d(
            in_channels=128, out_channels=64, kernel_size=2, stride=2)
        self.up_convolution_4 = double_convolution(128, 64)

        # 1x1 projection to the requested number of output channels (logits).
        self.out = nn.Conv2d(
            in_channels=64, out_channels=out_channels, kernel_size=1)

    def forward(self, x):
        """Return per-pixel logits with the same spatial size as ``x``."""
        # Encoder: keep each stage's pre-pool activation as a skip tensor.
        skip1 = self.down_convolution_1(x)
        skip2 = self.down_convolution_2(self.max_pool2d(skip1))
        skip3 = self.down_convolution_3(self.max_pool2d(skip2))
        skip4 = self.down_convolution_4(self.max_pool2d(skip3))
        bottleneck = self.down_convolution_5(self.max_pool2d(skip4))

        # Decoder: upsample, concatenate the matching skip along the channel
        # dimension, then refine with a double convolution.
        x = self.up_convolution_1(
            torch.cat([skip4, self.up_transpose_1(bottleneck)], 1))
        x = self.up_convolution_2(torch.cat([skip3, self.up_transpose_2(x)], 1))
        x = self.up_convolution_3(torch.cat([skip2, self.up_transpose_3(x)], 1))
        x = self.up_convolution_4(torch.cat([skip1, self.up_transpose_4(x)], 1))
        return self.out(x)
class CustomDataset(Dataset):
    """Paired image/mask dataset read from two parallel directories.

    Images are loaded as RGB, masks as single-channel ("L"); the same
    transform (if any) is applied to both.

    Fix: ``os.listdir()`` returns entries in arbitrary order, so the image
    and mask lists could previously pair mismatched files at the same index.
    Both lists are now sorted so index ``i`` pairs files of matching names
    (assumes image and mask files share filenames — TODO confirm).
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # Sorted so image_filenames[i] and mask_filenames[i] correspond.
        self.image_filenames = sorted(os.listdir(image_dir))
        self.mask_filenames = sorted(os.listdir(mask_dir))

    def __len__(self):
        # Dataset size is driven by the image directory.
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Return the (image, mask) pair at ``idx``, transformed if configured."""
        image_path = os.path.join(self.image_dir, self.image_filenames[idx])
        mask_path = os.path.join(self.mask_dir, self.mask_filenames[idx])
        image = Image.open(image_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")
        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)
        return image, mask
def train_model(model, dataloader, criterion, optimizer, num_epochs=25):
    """Run a basic training loop and return the per-epoch average losses.

    Args:
        model: module to train (put in train mode each epoch).
        dataloader: iterable yielding ``(images, masks)`` batches.
        criterion: loss applied as ``criterion(outputs, masks)``.
        optimizer: optimizer stepping over ``model.parameters()``.
        num_epochs: number of full passes over ``dataloader``.

    Returns:
        list[float]: average loss per epoch (also printed). Backward
        compatible: callers that ignore the return value (the original
        returned None) are unaffected.
    """
    epoch_losses = []
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for images, masks in dataloader:
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        avg_loss = running_loss / len(dataloader)
        epoch_losses.append(avg_loss)
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')
    return epoch_losses
if __name__ == "__main__":
    # Pre-processing shared by images and masks: resize, then scale to [0, 1].
    preprocess = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])

    training_set = CustomDataset(
        "face-synthetics-glasses/train/images",
        "face-synthetics-glasses/train/masks",
        transform=preprocess,
    )
    loader = DataLoader(training_set, batch_size=2, shuffle=True)

    # 3-channel RGB in, 1-channel mask logits out; BCEWithLogitsLoss pairs
    # with the raw (un-sigmoided) model output.
    model = UNet(3, 1)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    print("moving ahead")
    # Training and checkpointing were left disabled in the original:
    # train_model(model, loader, criterion, optimizer, num_epochs=25)
    # torch.save(model.state_dict(), "base_bat_ball.pth")