akhaliq HF Staff committed on
Commit 3d887bb · 1 Parent(s): 0aec812

Create app.py

Files changed (1):
  app.py +132 -0
app.py ADDED
@@ -0,0 +1,132 @@
+ import mxnet as mx
+ import cv2 as cv
+ import numpy as np
+ import os
+ from PIL import Image
+ import math
+ from collections import namedtuple
+ from mxnet.contrib.onnx import import_model
+ import cityscapes_labels
+ import gradio as gr
+
+ def preprocess(im):
+     # convert to float32
+     test_img = im.astype(np.float32)
+     # pad the image with a small border so its dimensions are multiples of 8,
+     # in order to obtain an accurately reshaped image after the DUC layer
+     test_shape = [im.shape[0], im.shape[1]]
+     cell_shapes = [math.ceil(l / 8) * 8 for l in test_shape]
+     test_img = cv.copyMakeBorder(test_img, 0, max(0, int(cell_shapes[0]) - im.shape[0]), 0, max(0, int(cell_shapes[1]) - im.shape[1]), cv.BORDER_CONSTANT, value=rgb_mean)
+     test_img = np.transpose(test_img, (2, 0, 1))
+     # subtract the rgb mean
+     for i in range(3):
+         test_img[i] -= rgb_mean[i]
+     test_img = np.expand_dims(test_img, axis=0)
+     # convert to an MXNet ndarray
+     test_img = mx.ndarray.array(test_img)
+     return test_img
+
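+ # illustrative check of preprocess (assuming an 800x800 RGB input, as with
+ # the bundled test image): the dims are already multiples of 8, so
+ # preprocess(im).shape == (1, 3, 800, 800); a 799x801 input would instead be
+ # padded up to (1, 3, 800, 808)
+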
+ def get_palette():
+     # get train id to color mappings from file
+     trainId2colors = {label.trainId: label.color for label in cityscapes_labels.labels}
+     # prepare and return palette
+     palette = [0] * 256 * 3
+     for trainId in trainId2colors:
+         colors = trainId2colors[trainId]
+         if trainId == 255:
+             colors = (0, 0, 0)
+         for i in range(3):
+             palette[trainId * 3 + i] = colors[i]
+     return palette
+
+ def colorize(labels):
+     # generate a colorized image from the output labels and the color palette
+     result_img = Image.fromarray(labels).convert('P')
+     result_img.putpalette(get_palette())
+     return np.array(result_img.convert('RGB'))
+
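+ # palette layout (illustrative): palette[3*t : 3*t + 3] holds the RGB color
+ # for train id t, e.g. trainId 0 ("road") maps to (128, 64, 128)
+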
+ def predict(imgs, im):
+     # get input and output dimensions
+     result_height, result_width = result_shape
+     _, _, img_height, img_width = imgs.shape
+     # set downsampling rate
+     ds_rate = 8
+     # set cell width
+     cell_width = 2
+     # number of output label classes
+     label_num = 19
+
+     # perform forward pass
+     batch = namedtuple('Batch', ['data'])
+     mod.forward(batch([imgs]), is_train=False)
+     labels = mod.get_outputs()[0].asnumpy().squeeze()
+
+     # re-arrange output (use floor division; the Python 2 style
+     # int(x / ds_rate) * ds_rate was a no-op under Python 3)
+     test_width = (int(img_width) // ds_rate) * ds_rate
+     test_height = (int(img_height) // ds_rate) * ds_rate
+     feat_width = test_width // ds_rate
+     feat_height = test_height // ds_rate
+     labels = labels.reshape((label_num, 4, 4, feat_height, feat_width))
+     labels = np.transpose(labels, (0, 3, 1, 4, 2))
+     labels = labels.reshape((label_num, test_height // cell_width, test_width // cell_width))
+
+     labels = labels[:, :img_height // cell_width, :img_width // cell_width]
+     labels = np.transpose(labels, [1, 2, 0])
+     labels = cv.resize(labels, (result_width, result_height), interpolation=cv.INTER_LINEAR)
+     labels = np.transpose(labels, [2, 0, 1])
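+     # shape walk-through (illustrative, assuming an 800x800 padded input):
+     # the raw output reshapes to (19, 4, 4, 100, 100), transposes to
+     # (19, 100, 4, 100, 4) and flattens to (19, 400, 400); i.e. each 8x8
+     # pixel block is predicted as a 4x4 grid of 2x2-pixel cells, and
+     # cv.resize then upsamples the 19 score maps back to full resolution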
+
+     # raw per-class scores (note: no softmax is actually applied here)
+     softmax = labels
+
+     # get classification labels
+     results = np.argmax(labels, axis=0).astype(np.uint8)
+     raw_labels = results
+
+     # compute the confidence score as the mean of the per-pixel max scores
+     confidence = float(np.max(softmax, axis=0).mean())
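+     # a calibrated confidence would need an actual softmax over the class
+     # axis first, e.g. (sketch):
+     #   e = np.exp(labels - labels.max(axis=0, keepdims=True))
+     #   confidence = float((e / e.sum(axis=0, keepdims=True)).max(axis=0).mean())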
+
+     # generate the segmented image
+     result_img = Image.fromarray(colorize(raw_labels)).resize(result_shape[::-1])
+
+     # generate the blended image (both inputs are RGB, so no channel reversal)
+     blended_img = Image.fromarray(cv.addWeighted(im, 0.5, np.array(result_img), 0.5, 0))
+
+     return confidence, result_img, blended_img, raw_labels
+
+ def get_model(ctx, model_path):
+     # import the ONNX model into MXNet symbols and params
+     sym, arg, aux = import_model(model_path)
+     # define the network module
+     mod = mx.mod.Module(symbol=sym, data_names=['data'], context=ctx, label_names=None)
+     # bind parameters to the network; the data shape is fixed to the test
+     # image's shape, so inputs must be resized to it before the forward pass
+     mod.bind(for_training=False, data_shapes=[('data', (1, 3, im.shape[0], im.shape[1]))], label_shapes=mod._label_shapes)
+     mod.set_params(arg_params=arg, aux_params=aux, allow_missing=True, allow_extra=True)
+     return mod
+
+ # download the test image
+ mx.test_utils.download('https://s3.amazonaws.com/onnx-model-zoo/duc/city1.png')
+ # read the image as rgb
+ im = cv.imread('city1.png')[:, :, ::-1]
+ # set the output shape (same as the input shape)
+ result_shape = [im.shape[0], im.shape[1]]
+ # set the rgb mean of the input image (used in mean subtraction)
+ rgb_mean = cv.mean(im)
+
+ # download the ONNX model
+ mx.test_utils.download('https://s3.amazonaws.com/onnx-model-zoo/duc/ResNet101_DUC_HDC.onnx')
+
+ # determine and set the context
+ if len(mx.test_utils.list_gpus()) == 0:
+     ctx = mx.cpu()
+ else:
+     ctx = mx.gpu(0)
+
+ # load the ONNX model
+ mod = get_model(ctx, 'ResNet101_DUC_HDC.onnx')
+
+ def inference(im):
+     # resize the uploaded image to the shape the module was bound with
+     im = cv.resize(im, (result_shape[1], result_shape[0]), interpolation=cv.INTER_LINEAR)
+     pre = preprocess(im)
+     conf, result_img, blended_img, raw = predict(pre, im)
+     return blended_img
+
+ gr.Interface(inference, "image", gr.outputs.Image(type="pil")).launch()
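
Note: app.py imports cityscapes_labels, which is not added by this commit, so the Space presumably ships it as a separate file. A minimal stand-in for the interface that get_palette() relies on, modeled on the standard Cityscapes helper script, might look like the sketch below (an assumption, not the Space's actual file); the app also needs mxnet, onnx, opencv-python, Pillow, numpy, and gradio installed.

    # cityscapes_labels.py -- minimal stand-in (assumption, not part of this commit)
    from collections import namedtuple

    Label = namedtuple('Label', ['name', 'trainId', 'color'])

    labels = [
        Label('road',       0,   (128,  64, 128)),
        Label('sidewalk',   1,   (244,  35, 232)),
        Label('building',   2,   ( 70,  70,  70)),
        # ... the remaining Cityscapes classes ...
        Label('unlabeled',  255, (  0,   0,   0)),
    ]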