nph4rd committed
Commit 286ea0b · verified · 1 Parent(s): bb19caf

Update app.py

Files changed (1)
  1. app.py +100 -86
app.py CHANGED
@@ -10,50 +10,63 @@ import re
 import numpy as np
 import spaces
 
-
-model_id = "agentsea/paligemma-3b-ft-widgetcap-waveui-448"
-processor_id = "google/paligemma-3b-pt-448"
-COLORS = ['#4285f4', '#db4437', '#f4b400', '#0f9d58', '#e48ef1']
+# Model IDs
+MODEL_IDS = {
+    "Model 1 (Widgetcap 448)": "agentsea/paligemma-3b-ft-widgetcap-waveui-448",
+    "Model 2 (WaveUI 896)": "agentsea/paligemma-3b-ft-waveui-896"
+}
+PROCESSOR_IDS = {
+    "Model 1 (Widgetcap 448)": "google/paligemma-3b-pt-448",
+    "Model 2 (WaveUI 896)": "google/paligemma-3b-pt-896"
+}
+
+# Load models and processors
+models = {name: PaliGemmaForConditionalGeneration.from_pretrained(model_id).eval().to(device)
+          for name, model_id in MODEL_IDS.items()}
+processors = {name: PaliGemmaProcessor.from_pretrained(processor_id)
+              for name, processor_id in PROCESSOR_IDS.items()}
+
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).eval().to(device)
-processor = PaliGemmaProcessor.from_pretrained(processor_id)
 
 ###### Transformers Inference
 @spaces.GPU
 def infer(
     image: PIL.Image.Image,
     text: str,
-    max_new_tokens: int
+    max_new_tokens: int,
+    model_choice: str
 ) -> str:
+    model = models[model_choice]
+    processor = processors[model_choice]
     inputs = processor(text=text, images=image, return_tensors="pt").to(device)
     with torch.inference_mode():
-        generated_ids = model.generate(
-            **inputs,
-            max_new_tokens=max_new_tokens,
-            do_sample=False
-        )
+        generated_ids = model.generate(
+            **inputs,
+            max_new_tokens=max_new_tokens,
+            do_sample=False
+        )
     result = processor.batch_decode(generated_ids, skip_special_tokens=True)
     return result[0][len(text):].lstrip("\n")
 
-def parse_segmentation(input_image, input_text):
-    out = infer(input_image, input_text, max_new_tokens=100)
-    objs = extract_objs(out.lstrip("\n"), input_image.size[0], input_image.size[1], unique_labels=True)
-    labels = set(obj.get('name') for obj in objs if obj.get('name'))
-    color_map = {l: COLORS[i % len(COLORS)] for i, l in enumerate(labels)}
-    highlighted_text = [(obj['content'], obj.get('name')) for obj in objs]
-    annotated_img = (
-        input_image,
-        [
-            (
-                obj['mask'] if obj.get('mask') is not None else obj['xyxy'],
-                obj['name'] or '',
-            )
-            for obj in objs
-            if 'mask' in obj or 'xyxy' in obj
-        ],
-    )
-    has_annotations = bool(annotated_img[1])
-    return annotated_img
+def parse_segmentation(input_image, input_text, model_choice):
+    out = infer(input_image, input_text, max_new_tokens=100, model_choice=model_choice)
+    objs = extract_objs(out.lstrip("\n"), input_image.size[0], input_image.size[1], unique_labels=True)
+    labels = set(obj.get('name') for obj in objs if obj.get('name'))
+    color_map = {l: COLORS[i % len(COLORS)] for i, l in enumerate(labels)}
+    highlighted_text = [(obj['content'], obj.get('name')) for obj in objs]
+    annotated_img = (
+        input_image,
+        [
+            (
+                obj['mask'] if obj.get('mask') is not None else obj['xyxy'],
+                obj['name'] or '',
+            )
+            for obj in objs
+            if 'mask' in obj or 'xyxy' in obj
+        ],
+    )
+    has_annotations = bool(annotated_img[1])
+    return annotated_img
 
 ######## Demo
 
@@ -66,33 +79,34 @@ Note:\n\n
 - the task it was fine-tuned on was detection, so it may not generalize to other tasks.
 """
 
-
 with gr.Blocks(css="style.css") as demo:
-    gr.Markdown(INTRO_TEXT)
-    with gr.Tab("Detection"):
-        image = gr.Image(type="pil")
-        seg_input = gr.Text(label="Detect instruction (e.g. 'detect sign in button')")
-        seg_btn = gr.Button("Submit")
-        annotated_image = gr.AnnotatedImage(label="Output")
-
-        examples = [["./airbnb.jpg", "detect 'Amazing pools' button"]]
-        gr.Examples(
-            examples=examples,
-            inputs=[image, seg_input],
-        )
+    gr.Markdown(INTRO_TEXT)
+    with gr.Tab("Detection"):
+        model_choice = gr.Dropdown(label="Select Model", choices=list(MODEL_IDS.keys()))
+        image = gr.Image(type="pil")
+        seg_input = gr.Text(label="Detect instruction (e.g. 'detect sign in button')")
+        seg_btn = gr.Button("Submit")
+        annotated_image = gr.AnnotatedImage(label="Output")
+
+        examples = [["./airbnb.jpg", "detect 'Amazing pools' button"]]
+        gr.Examples(
+            examples=examples,
+            inputs=[image, seg_input],
+        )
 
-    seg_inputs = [
-        image,
-        seg_input
+    seg_inputs = [
+        image,
+        seg_input,
+        model_choice
     ]
-    seg_outputs = [
-        annotated_image
-    ]
-    seg_btn.click(
-        fn=parse_segmentation,
-        inputs=seg_inputs,
-        outputs=seg_outputs,
-    )
+    seg_outputs = [
+        annotated_image
+    ]
+    seg_btn.click(
+        fn=parse_segmentation,
+        inputs=seg_inputs,
+        outputs=seg_outputs,
+    )
 
 
 _SEGMENT_DETECT_RE = re.compile(
@@ -103,39 +117,39 @@ _SEGMENT_DETECT_RE = re.compile(
 )
 
 def extract_objs(text, width, height, unique_labels=False):
-    """Returns objs for a string with "<loc>" and "<seg>" tokens."""
-    objs = []
-    seen = set()
-    while text:
-        m = _SEGMENT_DETECT_RE.match(text)
-        if not m:
-            break
-        print("m", m)
-        gs = list(m.groups())
-        before = gs.pop(0)
-        name = gs.pop()
-        y1, x1, y2, x2 = [int(x) / 1024 for x in gs[:4]]
-
-        y1, x1, y2, x2 = map(round, (y1*height, x1*width, y2*height, x2*width))
-        mask = None
-
-        content = m.group()
-        if before:
-            objs.append(dict(content=before))
-            content = content[len(before):]
-        while unique_labels and name in seen:
-            name = (name or '') + "'"
-        seen.add(name)
-        objs.append(dict(
-            content=content, xyxy=(x1, y1, x2, y2), mask=mask, name=name))
-        text = text[len(before) + len(content):]
-
-    if text:
-        objs.append(dict(content=text))
-
-    return objs
+    """Returns objs for a string with "<loc>" and "<seg>" tokens."""
+    objs = []
+    seen = set()
+    while text:
+        m = _SEGMENT_DETECT_RE.match(text)
+        if not m:
+            break
+        print("m", m)
+        gs = list(m.groups())
+        before = gs.pop(0)
+        name = gs.pop()
+        y1, x1, y2, x2 = [int(x) / 1024 for x in gs[:4]]
+
+        y1, x1, y2, x2 = map(round, (y1*height, x1*width, y2*height, x2*width))
+        mask = None
+
+        content = m.group()
+        if before:
+            objs.append(dict(content=before))
+            content = content[len(before):]
+        while unique_labels and name in seen:
+            name = (name or '') + "'"
+        seen.add(name)
+        objs.append(dict(
+            content=content, xyxy=(x1, y1, x2, y2), mask=mask, name=name))
+        text = text[len(before) + len(content):]
+
+    if text:
+        objs.append(dict(content=text))
+
+    return objs
 
 #########
 
 if __name__ == "__main__":
-    demo.queue(max_size=10).launch(debug=True)
+    demo.queue(max_size=10).launch(debug=True)
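
For reference, a minimal usage sketch of the new multi-model inference path outside the Gradio UI, assuming the MODEL_IDS, infer, and extract_objs definitions from the updated app.py above; the checkpoint key, image path, and instruction are placeholder values borrowed from the demo's own example:

    import PIL.Image

    # Placeholder inputs: any key of MODEL_IDS plus the example screenshot bundled with the Space.
    choice = "Model 1 (Widgetcap 448)"
    img = PIL.Image.open("./airbnb.jpg")

    # Greedy decoding against the selected checkpoint/processor pair (see infer above).
    out = infer(img, "detect 'Amazing pools' button", max_new_tokens=100, model_choice=choice)

    # extract_objs scales the "<loc>" tokens from the 0-1023 grid back to pixel coordinates.
    objs = extract_objs(out, img.size[0], img.size[1], unique_labels=True)
    print([obj["xyxy"] for obj in objs if "xyxy" in obj])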