Heisenberg08 commited on
Commit
fe70fd4
1 Parent(s): 722e7f6

added model and code

Browse files
Files changed (6) hide show
  1. Unet_acc_94.pth +3 -0
  2. app.py +52 -0
  3. dataset.py +36 -0
  4. model.py +68 -0
  5. predict.py +38 -0
  6. tempCodeRunnerFile.py +38 -0
Unet_acc_94.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4ddd029a071966365f75ea31ac19ce929dc888c9df801dcd6831888fd15444ec
3
+ size 124247189
app.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import cv2
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torchvision.transforms.functional as TF
11
+ from torchvision import transforms
12
+
13
+ from model import DoubleConv,UNET
14
+
15
+ convert_tensor = transforms.ToTensor()
16
+ device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
+ model = UNET(in_channels=3, out_channels=1).to(device)
18
+ # model=torch.load("Unet_acc_94.pth",map_location=torch.device('cpu'))
19
+
20
+ model=torch.load("Unet_acc_94.pth",map_location=device)
21
+
22
+ def predict(img):
23
+ img=cv2.resize(img,(240,160))
24
+ test_img=convert_tensor(img).unsqueeze(0)
25
+ print(test_img.shape)
26
+ preds=model(test_img.float())
27
+ preds=torch.sigmoid(preds)
28
+ preds=(preds > 0.5).float()
29
+ print(preds.shape)
30
+ im=preds.squeeze(0).permute(1,2,0).detach()
31
+ print(im.shape)
32
+ im=im.numpy()
33
+ return im
34
+
35
+ import streamlit as st
36
+ st.title("Image Colorizer")
37
+
38
+ file=st.file_uploader("Please upload the B/W image",type=["jpg","jpeg","png"])
39
+ print(file)
40
+ if file is None:
41
+ st.text("Please Upload an image")
42
+ else:
43
+ file_bytes = np.asarray(bytearray(file.read()), dtype=np.uint8)
44
+ opencv_image = cv2.imdecode(file_bytes, 1)
45
+ im=predict(opencv_image)
46
+ st.text("Original")
47
+ st.image(file)
48
+ st.text("Colorized!!")
49
+ st.image(im)
50
+
51
+
52
+
dataset.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data.dataloader import DataLoader,Dataset
3
+ import torch.optim as optim
4
+ import albumentations as A
5
+ from albumentations.pytorch import ToTensorV2
6
+
7
+ import numpy as np
8
+ import matplotlib.pyplot as plt
9
+ import os
10
+ from PIL import Image
11
+
12
+ class Segmentation_Dataset(Dataset):
13
+ def __init__(self,img_dir,mask_dir,transform=None):
14
+ self.img_dir=img_dir
15
+ self.mask_dir=mask_dir
16
+ self.transform=transform
17
+ self.images=os.listdir(img_dir)
18
+ self.images=[im for im in self.images if ".jpg" in im]
19
+ def __len__(self):
20
+ return len(self.images)
21
+
22
+ def __getitem__(self,idx):
23
+ img_path=os.path.join(self.img_dir,self.images[idx])
24
+ mask_path=os.path.join(self.mask_dir,self.images[idx].replace(".jpg",".png"))
25
+
26
+ image=np.array(Image.open(img_path).convert("RGB"))
27
+ mask=np.array(Image.open(mask_path).convert("L"),dtype=np.float32)
28
+ mask[mask==255]=1.0
29
+
30
+ if self.transform is not None:
31
+ augmentations=self.transform(image=image,mask=mask)
32
+ image=augmentations["image"]
33
+ mask=augmentations["mask"]
34
+
35
+ return image, mask
36
+
model.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchvision.transforms.functional as TF
4
+
5
+ class DoubleConv(nn.Module):
6
+ def __init__(self, in_channels, out_channels):
7
+ super(DoubleConv, self).__init__()
8
+ self.conv = nn.Sequential(
9
+ nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
10
+ nn.BatchNorm2d(out_channels),
11
+ nn.ReLU(inplace=True),
12
+ nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
13
+ nn.BatchNorm2d(out_channels),
14
+ nn.ReLU(inplace=True),
15
+ )
16
+
17
+ def forward(self, x):
18
+ return self.conv(x)
19
+
20
+ class UNET(nn.Module):
21
+ def __init__(
22
+ self, in_channels=3, out_channels=1, features=[64, 128, 256, 512],
23
+ ):
24
+ super(UNET, self).__init__()
25
+ self.ups = nn.ModuleList()
26
+ self.downs = nn.ModuleList()
27
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
28
+
29
+ # Down part of UNET
30
+ for feature in features:
31
+ self.downs.append(DoubleConv(in_channels, feature))
32
+ in_channels = feature
33
+
34
+ # Up part of UNET
35
+ for feature in reversed(features):
36
+ self.ups.append(
37
+ nn.ConvTranspose2d(
38
+ feature*2, feature, kernel_size=2, stride=2,
39
+ )
40
+ )
41
+ self.ups.append(DoubleConv(feature*2, feature))
42
+
43
+ self.bottleneck = DoubleConv(features[-1], features[-1]*2)
44
+ self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)
45
+
46
+ def forward(self, x):
47
+ skip_connections = []
48
+
49
+ for down in self.downs:
50
+ x = down(x)
51
+ skip_connections.append(x)
52
+ x = self.pool(x)
53
+
54
+ x = self.bottleneck(x)
55
+ skip_connections = skip_connections[::-1]
56
+
57
+ for idx in range(0, len(self.ups), 2):
58
+ x = self.ups[idx](x)
59
+ skip_connection = skip_connections[idx//2]
60
+
61
+ if x.shape != skip_connection.shape:
62
+ x = TF.resize(x, size=skip_connection.shape[2:])
63
+
64
+ concat_skip = torch.cat((skip_connection, x), dim=1)
65
+ x = self.ups[idx+1](concat_skip)
66
+
67
+ return self.final_conv(x)
68
+
predict.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision
3
+ from torchvision import transforms
4
+
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+ from PIL import Image
8
+ from model import DoubleConv,UNET
9
+
10
+ import os
11
+ os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
12
+
13
+
14
+ convert_tensor = transforms.ToTensor()
15
+ device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+ # print(device)
17
+
18
+ model = UNET(in_channels=3, out_channels=1).to(device)
19
+ model=torch.load("Unet_acc_94.pth",map_location=torch.device('cpu'))
20
+
21
+ # test_img=np.array(Image.open("profilepic - Copy.jpeg").resize((160,240)))
22
+ test_img=Image.open("104.jpg").resize((240,160))
23
+
24
+ # test_img=torch.tensor(test_img).permute(2,1,0)
25
+ # test_img=test_img.unsqueeze(0)
26
+ test_img=convert_tensor(test_img).unsqueeze(0)
27
+ print(test_img.shape)
28
+ preds=model(test_img.float())
29
+ preds=torch.sigmoid(preds)
30
+ preds=(preds > 0.5).float()
31
+ print(preds.shape)
32
+ im=preds.squeeze(0).permute(1,2,0).detach()
33
+ print(im.shape)
34
+ fig,axs=plt.subplots(1,2)
35
+
36
+ axs[0].imshow(im)
37
+ axs[1].imshow(test_img.squeeze(0).permute(1,2,0).detach())
38
+ plt.show()
tempCodeRunnerFile.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision
3
+ from torchvision import transforms
4
+
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+ from PIL import Image
8
+ from model import DoubleConv,UNET
9
+
10
+ import os
11
+ os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
12
+
13
+
14
+ convert_tensor = transforms.ToTensor()
15
+ device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+ # print(device)
17
+
18
+ model = UNET(in_channels=3, out_channels=1).to(device)
19
+ model=torch.load("Unet_acc_94.pth",map_location=torch.device('cpu'))
20
+
21
+ # test_img=np.array(Image.open("profilepic - Copy.jpeg").resize((160,240)))
22
+ test_img=Image.open("104.jpg").resize((240,160))
23
+
24
+ # test_img=torch.tensor(test_img).permute(2,1,0)
25
+ # test_img=test_img.unsqueeze(0)
26
+ test_img=convert_tensor(test_img).unsqueeze(0)
27
+ print(test_img.shape)
28
+ preds=model(test_img.float())
29
+ preds=torch.sigmoid(preds)
30
+ preds=(preds > 0.5).float()
31
+ print(preds.shape)
32
+ im=preds.squeeze(0).permute(1,2,0).detach()
33
+ print(im.shape)
34
+ fig,axs=plt.subplots(1,2)
35
+
36
+ axs[0].imshow(im)
37
+ axs[1].imshow(test_img.squeeze(0).permute(1,2,0).detach())
38
+ plt.show()