# U-Net colorizer: loads a trained model that predicts RGB color for a
# grayscale input image.
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
# Preprocessing for RGB inputs: resize to 256x256, convert to a tensor, then
# reduce to a single grayscale channel (the Normalize here is an identity op).
image_transforms_rgb = torchvision.transforms.Compose([
    torchvision.transforms.Resize((256, 256)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
    torchvision.transforms.Grayscale()
])
# Preprocessing for images that are already single-channel grayscale.
image_transforms_gs = torchvision.transforms.Compose([
    torchvision.transforms.Resize((256, 256)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.0], std=[1.0]),
])
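
# Example (sketch): turning a PIL image into the 1-channel model input.
# "fruit.jpg" is a placeholder path, not part of this file.
#
#   from PIL import Image
#   img = Image.open("fruit.jpg").convert("RGB")
#   x = image_transforms_rgb(img)      # tensor of shape (1, 256, 256)
#   x = x.unsqueeze(0).to(device)      # add batch dim -> (1, 1, 256, 256)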
class ConvBlock(nn.Module):
    """Two 3x3 conv layers, each followed by batch norm and ReLU.

    Padding of 1 keeps the spatial size unchanged; only the channel
    count changes, from in_channel to out_channel.
    """
    def __init__(self, in_channel, out_channel):
        super(ConvBlock, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(True),
            nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(True)
        )

    def forward(self, x):
        return self.main(x)
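
# Quick shape check (sketch):
#   ConvBlock(1, 64)(torch.randn(1, 1, 256, 256)).shape
#   -> torch.Size([1, 64, 256, 256])  # spatial size preserved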
class UNETFruitColor(nn.Module):
    """U-Net that maps a 1-channel grayscale image to a 3-channel RGB image."""
    def __init__(self):
        super(UNETFruitColor, self).__init__()
        self.convs = [64, 128, 256, 512]
        # Encoder: four ConvBlocks, doubling the channel count at each level.
        self.convEncoder = nn.ModuleList()
        in_feature = 1
        for conv in self.convs:
            self.convEncoder.append(ConvBlock(in_feature, conv))
            in_feature = conv
        # Bottleneck at the lowest resolution: 512 -> 1024 channels.
        self.bottleNeck = ConvBlock(self.convs[-1], self.convs[-1] * 2)
        in_feature = self.convs[-1] * 2
        # Decoder: transposed convs double the spatial size; each ConvBlock
        # then fuses the upsampled features with the matching skip connection.
        self.convDecoder = nn.ModuleList()
        self.decoderUpConvs = nn.ModuleList()
        for conv in self.convs[::-1]:
            self.convDecoder.append(ConvBlock(in_feature, conv))
            self.decoderUpConvs.append(nn.ConvTranspose2d(in_feature, conv, kernel_size=2, stride=2, padding=0))
            in_feature = conv
        # Final 1x1 conv projects the 64 decoder channels down to 3 (RGB).
        self.finalUpConv = nn.Conv2d(in_feature, 3, (1, 1))
        self.sigmoid = nn.Sigmoid()  # defined but unused; see forward()
    def forward(self, x):
        skip_conns = []
        # Encoder: run each ConvBlock, save its output for the skip
        # connection, then halve the spatial size with 2x2 max pooling.
        for conv in self.convEncoder:
            x = conv(x)
            skip_conns.append(x)
            x = F.max_pool2d(x, (2, 2), stride=2)
        x = self.bottleNeck(x)
        # Decoder consumes the skip connections deepest-first.
        skip_conns = skip_conns[::-1]
        for idx in range(len(self.convDecoder)):
            upconv = self.decoderUpConvs[idx]
            deconv = self.convDecoder[idx]
            skp = skip_conns[idx]
            # Upsample with the transposed conv.
            x = upconv(x)
            # Resize the skip connection to match (a no-op for 256x256 inputs;
            # it guards against odd sizes), then concatenate along channels.
            x_cat = torchvision.transforms.Resize((x.shape[2], x.shape[3]))(skp)
            x = torch.cat([x_cat, x], dim=1)
            # Fuse the concatenated features with the decoder ConvBlock.
            x = deconv(x)
        # Project to 3 output channels with the final 1x1 conv.
        x = self.finalUpConv(x)
        # x = self.sigmoid(x)  # left disabled; outputs are unbounded
        return x
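
# Shape walkthrough for a (N, 1, 256, 256) input (sketch, channels@spatial):
#   encoder:    64@256 -> 128@128 -> 256@64 -> 512@32
#   bottleneck: 1024@16
#   decoder:    512@32 -> 256@64 -> 128@128 -> 64@256
#   output:     (N, 3, 256, 256) after the final 1x1 conv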
# The model is wrapped in DataParallel before loading, which suggests the
# checkpoint was saved from a DataParallel model (keys prefixed "module.").
model = UNETFruitColor()
model = nn.DataParallel(model).to(device)
model.load_state_dict(torch.load("unet_colorizer_flickr_5_93_Ploss_10_14K.pth", map_location=device), strict=True)
model.eval()
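
# Inference sketch, not part of the original file: "input.jpg" and
# "output.jpg" are placeholder paths. Values are clamped to [0, 1]
# because the sigmoid in forward() is disabled.
if __name__ == "__main__":
    from PIL import Image
    img = Image.open("input.jpg").convert("RGB")
    x = image_transforms_rgb(img).unsqueeze(0).to(device)  # (1, 1, 256, 256)
    with torch.no_grad():
        out = model(x)                                     # (1, 3, 256, 256)
    out = out.squeeze(0).clamp(0.0, 1.0).cpu()
    torchvision.transforms.ToPILImage()(out).save("output.jpg")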