CVRPDataset committed on
Commit aaaf972 · verified · 1 Parent(s): 9d29746

Upload app.py

Files changed (1)
  1. app.py +87 -0
app.py ADDED
@@ -0,0 +1,87 @@
+ import gradio as gr
+ import numpy as np
+ import cv2
+ from mmseg.apis import init_model, inference_model
+ import torch
+
+
+ def process_single_img(img_bgr, model_name):
+     print(type(img_bgr))
+     palette = [
+         ['background', [0, 0, 0]],
+         ['red', [255, 0, 0]]
+     ]
+
+     palette_dict = {}
+     for idx, each in enumerate(palette):
+         palette_dict[idx] = each[1]
+
+     if model_name == 'Mask2Former':
+         config_file = 'CVRP_configs/CVRP_mask2former.py'
+         checkpoint_file = 'checkpoint/Mask2Former.pth'
+     elif model_name == 'KNet':
+         config_file = 'CVRP_configs/CVRP_knet.py'
+         checkpoint_file = 'checkpoint/KNet.pth'
+     elif model_name == 'DeepLabV3+':
+         config_file = 'CVRP_configs/CVRP_deeplabv3plus.py'
+         checkpoint_file = 'checkpoint/DeepLabV3plus.pth'
+     elif model_name == 'Segformer':
+         config_file = 'CVRP_configs/CVRP_segformer.py'
+         checkpoint_file = 'checkpoint/Segformer.pth'
+     else:
+         return None, None
+
+     device = 'cuda:0'
+
+     model = init_model(config_file, checkpoint_file, device=device)
+
+     result = inference_model(model, img_bgr)
+     pred_mask = result.pred_sem_seg.data[0].cpu().numpy()
+
+     pred_mask_bgr = np.zeros((pred_mask.shape[0], pred_mask.shape[1], 3))
+     for idx in palette_dict.keys():
+         pred_mask_bgr[np.where(pred_mask == idx)] = palette_dict[idx]
+     pred_mask_bgr = pred_mask_bgr.astype('uint8')
+
+     pred_viz = cv2.addWeighted(img_bgr, 1, pred_mask_bgr, 1, 0)
+
+     torch.cuda.empty_cache()
+
+     return pred_viz, pred_mask_bgr
+ def run_segmentation(image_input, model_select):
+     if model_select not in ["Mask2Former", "KNet", "DeepLabV3+", "Segformer"]:
+         return None, None, [("Not implemented", "Error"), ("", "")]
+     else:
+         color_img, binary_img = process_single_img(image_input, model_select)
+         return color_img, binary_img, [("", ""), ("Segmentation Finished", "normal")]
+
+ title = """<p><h1 align="center">CVRP</h1></p>"""
+ # Set SAM parameters
+
+ with gr.Blocks() as iface:
+     gr.Markdown(title)
+     with gr.Row():
+         with gr.Column():
+             image_input = gr.Image(interactive=True, visible=True, label="Input Image", height=360)
+             with gr.Row():
+                 model_select = gr.Dropdown(choices=["Mask2Former", "KNet", "DeepLabV3+", "Segformer"], value="Mask2Former", label="Select model", visible=True)
+                 run_button = gr.Button(value="Run", interactive=True, visible=True)
+             with gr.Row():
+                 gr.Examples(
+                     examples=[['assets/T42_1220.jpg', 'Mask2Former'], ['assets/02604.jpg', 'Mask2Former'], ['assets/T92_323.jpg', 'Mask2Former']],
+                     inputs=[image_input, model_select])
+
+         with gr.Column():
+             color_output = gr.Image(interactive=False, visible=True, label="Color Image", height=360)
+             binary_output = gr.Image(interactive=False, visible=True, label="Binary Image", height=360)
+             run_status = gr.HighlightedText(
+                 value=[("Text", "Error"), ("to be", "Label 2"), ("highlighted", "Label 3")], visible=True)
+
+     run_button.click(
+         fn=run_segmentation,
+         inputs=[image_input, model_select],
+         outputs=[color_output, binary_output, run_status]
+     )
+
+
+ iface.launch(debug=True, server_port=6006, server_name="127.0.0.1")
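For a quick check of the same inference path outside the Gradio UI, the snippet below is a minimal sketch. It assumes the CVRP_configs/ and checkpoint/ files referenced in app.py are present locally, that one of the bundled assets/ example images exists, and that a CUDA device is available (app.py hard-codes 'cuda:0'); the output filename pred_mask.png is purely illustrative.

    import cv2
    import numpy as np
    from mmseg.apis import init_model, inference_model

    # Build the Mask2Former model from the same config/checkpoint paths used in app.py.
    model = init_model('CVRP_configs/CVRP_mask2former.py',
                       'checkpoint/Mask2Former.pth',
                       device='cuda:0')

    # Run inference on one of the example images listed in gr.Examples.
    img_bgr = cv2.imread('assets/T42_1220.jpg')
    result = inference_model(model, img_bgr)
    pred_mask = result.pred_sem_seg.data[0].cpu().numpy()  # (H, W) array of class indices

    print(np.unique(pred_mask))  # class indices present, e.g. 0 (background) and 1 (red)
    cv2.imwrite('pred_mask.png', (pred_mask * 255).astype('uint8'))  # illustrative output path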