taesiri committed on
Commit
8464e03
·
1 Parent(s): f9b1883

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -0
app.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Standard library.
import json

# Third-party.
import datasets
import gradio as gr
import numpy as np

# Load the validation split once at module import time; the Gradio callback
# below indexes into this dataset on every request.
bugs_ds = datasets.load_dataset("asgaardlab/SampleDataset", split="validation")
def generate_annotations(image_index):
    """Build the data for one dataset row to feed the Gradio outputs.

    Args:
        image_index: Index into ``bugs_ds``. May arrive as a float from the
            Gradio slider, so it is coerced to ``int`` first.

    Returns:
        A 5-tuple of:
          * ``(correct_image, annotations)`` for ``gr.AnnotatedImage``, where
            ``annotations`` is a list of ``(mask, label)`` pairs (one binary
            float32 mask per object),
          * the parsed objects list (shown in the JSON viewer),
          * the row's "Object Count" value,
          * the row's "Victim Name" value,
          * the row's "Tag" value (bug type).
    """
    image_index = int(image_index)
    # Fetch the row once. Indexing a datasets split repeatedly re-decodes the
    # example's features (including its images) on every access, so the
    # original's five separate bugs_ds[image_index] lookups were wasted work.
    row = bugs_ds[image_index]

    objects = json.loads(row["Objects JSON"])
    segmentation_image_rgb = np.array(row["Segmentation Image"])

    annotations = []
    for obj in objects:
        # NOTE(review): assumes obj["color"] is an ordered mapping whose last
        # entry is the alpha channel (e.g. {"r": .., "g": .., "b": .., "a": ..});
        # the slice drops it to keep RGB only — confirm against the dataset schema.
        color = tuple(obj["color"].values())[:-1]
        # Binary mask: 1.0 where a pixel's RGB exactly matches this object's color.
        mask = np.all(segmentation_image_rgb == np.array(color), axis=-1).astype(
            np.float32
        )
        annotations.append((mask, obj["name"]))

    return (
        (row["Correct Image"], annotations),
        objects,
        row["Object Count"],
        row["Victim Name"],
        row["Tag"],
    )
36
+
37
+
# Gradio Blocks UI: an index slider drives generate_annotations, whose outputs
# populate the annotated image, the raw-objects JSON view, and the metadata fields.
with gr.Blocks() as demo:
    gr.Markdown(
        "Enter the image index and click **Submit** to view the segmentation annotations."
    )

    with gr.Row():
        index_slider = gr.Slider(
            minimum=0, maximum=len(bugs_ds) - 1, step=1, label="Image Index"
        )
        submit_button = gr.Button("Submit")

    with gr.Row():
        # Metadata column on the left, annotated image on the right.
        with gr.Column():
            object_count_field = gr.Number(label="Object Count")
            victim_name_field = gr.Textbox(label="Victim Name")
            bug_type_field = gr.Textbox(label="Bug Type")

        annotated_image = gr.AnnotatedImage()

    with gr.Row():
        objects_json_view = gr.JSON()

    # Wire the button to the callback; output order must match the
    # callback's return tuple.
    submit_button.click(
        fn=generate_annotations,
        inputs=index_slider,
        outputs=[
            annotated_image,
            objects_json_view,
            object_count_field,
            victim_name_field,
            bug_type_field,
        ],
    )

demo.launch()