Simon Le Goff committed on
Commit 6ffeb01
1 Parent(s): 92d8e0c

Try with pollen-vision demo app now that the image builds properly.

Files changed (1)
  1. app.py +115 -4
app.py CHANGED
@@ -1,7 +1,118 @@
  import gradio as gr

- def greet(name):
-     return "Hello " + name + "!!"

- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
- iface.launch()
+ # import gradio as gr
+
+ # def greet(name):
+ # return "Hello " + name + "!!"
+
+ # iface = gr.Interface(fn=greet, inputs="text", outputs="text")
+ # iface.launch()
+
+ """
+ Gradio app for pollen-vision
+
+ This script creates a Gradio app for pollen-vision. The app allows users to perform object detection and object segmentation using the OWL-ViT and MobileSAM models.
+ """
+
+ from datasets import load_dataset
  import gradio as gr

+ import numpy as np
+ import numpy.typing as npt
+ from typing import Any, Dict, List
+
+ from pollen_vision.vision_models.object_detection import OwlVitWrapper
+ from pollen_vision.vision_models.object_segmentation import MobileSamWrapper
+ from pollen_vision.vision_models.utils import Annotator, get_bboxes
+
+
+ owl_vit = OwlVitWrapper()
+ mobile_sam = MobileSamWrapper()
+ annotator = Annotator()
+
+
+ def object_detection(
+     img: npt.NDArray[np.uint8], text_queries: List[str], score_threshold: float
+ ) -> List[Dict[str, Any]]:
+     predictions: List[Dict[str, Any]] = owl_vit.infer(
+         im=img, candidate_labels=text_queries, detection_threshold=score_threshold
+     )
+     return predictions
+
+
+ def object_segmentation(
+     img: npt.NDArray[np.uint8], object_detection_predictions: List[Dict[str, Any]]
+ ) -> List[npt.NDArray[np.uint8]]:
+     bboxes = get_bboxes(predictions=object_detection_predictions)
+     masks: List[npt.NDArray[np.uint8]] = mobile_sam.infer(im=img, bboxes=bboxes)
+     return masks
+
+
+ def query(
+     task: str,
+     img: npt.NDArray[np.uint8],
+     text_queries: List[str],
+     score_threshold: float,
+ ) -> npt.NDArray[np.uint8]:
+     object_detection_predictions = object_detection(
+         img=img, text_queries=text_queries, score_threshold=score_threshold
+     )
+
+     if task == "Object detection + segmentation (OWL-ViT + MobileSAM)":
+         masks = object_segmentation(
+             img=img, object_detection_predictions=object_detection_predictions
+         )
+         img = annotator.annotate(
+             im=img, detection_predictions=object_detection_predictions, masks=masks
+         )
+         return img
+
+     img = annotator.annotate(im=img, detection_predictions=object_detection_predictions)
+     return img
+
+
+ description = """
+ Lorem ipsum dolor sit amet, consectetur adipiscing elit. Nullam nec purus et nunc tincidunt tincidunt.
+ """
+
+ demo_inputs = [
+     gr.Dropdown(
+         [
+             "Object detection (OWL-ViT)",
+             "Object detection + segmentation (OWL-ViT + MobileSAM)",
+         ],
+         label="Choose a task",
+         value="Object detection (OWL-ViT)",
+     ),
+     gr.Image(),
+     "text",
+     gr.Slider(0, 1, value=0.1),
+ ]
+
+ rdt_dataset = load_dataset("pollen-robotics/reachy-doing-things", split="train")
+
+ img_kitchen_detection = rdt_dataset[11]["image"]
+ img_kitchen_segmentation = rdt_dataset[12]["image"]
+
+ demo_examples = [
+     [
+         "Object detection (OWL-ViT)",
+         img_kitchen_detection,
+         ["kettle", "black mug", "sink", "blue mug", "sponge", "bag of chips"],
+         0.15,
+     ],
+     [
+         "Object detection + segmentation (OWL-ViT + MobileSAM)",
+         img_kitchen_segmentation,
+         ["blue mug", "paper cup", "kettle", "sponge"],
+         0.12,
+     ],
+ ]

+ demo = gr.Interface(
+     fn=query,
+     inputs=demo_inputs,
+     outputs="image",
+     title="pollen-vision",
+     description=description,
+     examples=demo_examples,
+ )
+ demo.launch()
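
For reference, the detection-and-annotation path that query() wires together can also be exercised outside the Gradio UI. The sketch below is not part of the commit: it is a minimal local smoke test that reuses only calls already present in the diff (OwlVitWrapper.infer, Annotator.annotate) and mirrors the first demo example. It assumes pollen_vision and datasets are installed, that the OWL-ViT weights can be fetched, and that annotate returns a NumPy array, as the type hints in app.py suggest.

# Minimal sketch, not part of the commit: run the OWL-ViT detection and
# annotation used by query() for the "Object detection (OWL-ViT)" task,
# without starting the Gradio server.
import numpy as np
from datasets import load_dataset

from pollen_vision.vision_models.object_detection import OwlVitWrapper
from pollen_vision.vision_models.utils import Annotator

owl_vit = OwlVitWrapper()
annotator = Annotator()

# Same kitchen image the first demo example uses (index 11 of the dataset).
rdt_dataset = load_dataset("pollen-robotics/reachy-doing-things", split="train")
img = np.array(rdt_dataset[11]["image"])

predictions = owl_vit.infer(
    im=img,
    candidate_labels=["kettle", "black mug", "sink"],
    detection_threshold=0.15,
)
annotated = annotator.annotate(im=img, detection_predictions=predictions)
print(f"{len(predictions)} detections, annotated image shape: {annotated.shape}")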