ChiKyi commited on
Commit
cf2db44
·
1 Parent(s): 868b5e1
Files changed (1) hide show
  1. app.py +90 -0
app.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ from torchvision import transforms
4
+ from matplotlib import pyplot as plt
5
+ import gradio as gr
6
+
7
+ from models import MainModel # Import class for your main model
8
+ from utils import lab_to_rgb, build_res_unet#, build_mobile_unet # Utility to convert LAB to RGB
9
+
10
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11
+
12
+
13
+ def load_model(generator_model_path, colorization_model_path): #, model_type='resnet')
14
+
15
+ #if model_type == 'resnet':
16
+ net_G = build_res_unet(n_input=1, n_output=2, size=256)
17
+ # elif model_type == 'mobilenet':
18
+ # net_G = build_mobile_unet(n_input=1, n_output=2, size=256)
19
+
20
+ net_G.load_state_dict(torch.load(generator_model_path, map_location=device))
21
+
22
+ # Create MainModel and load weights
23
+ model = MainModel(net_G=net_G)
24
+ model.load_state_dict(torch.load(colorization_model_path, map_location=device))
25
+
26
+ # Move model to device and set to eval mode
27
+ model.to(device)
28
+ model.eval()
29
+
30
+ return model
31
+
32
+ # Load pretrained models
33
+ resnet_model = load_model(
34
+ "weight/pascal_res18-unet.pt",
35
+ "weight/pascal_final_model_weights.pt"
36
+ # model_type='resnet'
37
+ )
38
+
39
+ # mobilenet_model = load_model(
40
+ # "weight/mobile-unet.pt",
41
+ # "weight/mobile_pascal_final_model_weights.pt",
42
+ # model_type='mobilenet'
43
+ # )
44
+
45
+ # Transformations
46
+ def preprocess_image(image):
47
+ image = image.resize((256, 256))
48
+ image = transforms.ToTensor()(image)[:1] * 2. - 1. # Normalize to [-1, 1]
49
+ return image
50
+
51
+ def postprocess_image(grayscale, prediction):
52
+ return lab_to_rgb(grayscale.unsqueeze(0), prediction.cpu())[0]
53
+
54
+ # Prediction function
55
+ def colorize_image(input_image):
56
+ # Convert input to grayscale
57
+ input_image = Image.fromarray(input_image).convert('L')
58
+ grayscale = preprocess_image(input_image).to(device)
59
+
60
+ # Generate predictions
61
+ with torch.no_grad():
62
+ resnet_output = resnet_model.net_G(grayscale.unsqueeze(0))
63
+ # mobilenet_output = mobilenet_model.net_G(grayscale.unsqueeze(0))
64
+
65
+ # Post-process results
66
+ resnet_colorized = postprocess_image(grayscale, resnet_output)
67
+ # mobilenet_colorized = postprocess_image(grayscale, mobilenet_output)
68
+
69
+ return (
70
+ input_image, # Grayscale image
71
+ resnet_colorized # ResNet18 colorized image
72
+ # mobilenet_colorized # MobileNet colorized image
73
+ )
74
+
75
+ # Gradio Interface
76
+ interface = gr.Interface(
77
+ fn=colorize_image,
78
+ inputs=gr.Image(type="numpy", label="Upload a Color Image"),
79
+ outputs=[
80
+ gr.Image(label="Grayscale Image"),
81
+ gr.Image(label="Colorized Image (ResNet18)")
82
+ # gr.Image(label="Colorized Image (MobileNet)")
83
+ ],
84
+ title="Image Colorization",
85
+ description="Upload a color image"
86
+ )
87
+
88
+ # Launch Gradio app
89
+ if __name__ == '__main__':
90
+ interface.launch()