nph4rd committed
Commit d3c1d14 · verified · 1 Parent(s): c86af64

Update app.py

Files changed (1): app.py (+88, −102)
app.py CHANGED
@@ -10,64 +10,50 @@ import re
 import numpy as np
 import spaces
 
-# Model IDs
-MODEL_IDS = {
-    "Model 1 (Widgetcap 448)": "agentsea/paligemma-3b-ft-widgetcap-waveui-448",
-    "Model 2 (WaveUI 896)": "agentsea/paligemma-3b-ft-waveui-896"
-}
-PROCESSOR_IDS = {
-    "Model 1 (Widgetcap 448)": "google/paligemma-3b-pt-448",
-    "Model 2 (WaveUI 896)": "google/paligemma-3b-pt-896"
-}
-
-# Device configuration
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-# Load models and processors
-models = {name: PaliGemmaForConditionalGeneration.from_pretrained(model_id).eval().to(device)
-          for name, model_id in MODEL_IDS.items()}
-processors = {name: PaliGemmaProcessor.from_pretrained(processor_id)
-              for name, processor_id in PROCESSOR_IDS.items()}
+model_id = "agentsea/paligemma-3b-ft-widgetcap-waveui-448"
+processor_id = "google/paligemma-3b-pt-448"
+COLORS = ['#4285f4', '#db4437', '#f4b400', '#0f9d58', '#e48ef1']
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).eval().to(device)
+processor = PaliGemmaProcessor.from_pretrained(processor_id)
 
 ###### Transformers Inference
 @spaces.GPU
 def infer(
     image: PIL.Image.Image,
     text: str,
-    max_new_tokens: int,
-    model_choice: str
+    max_new_tokens: int
 ) -> str:
-    model = models[model_choice]
-    processor = processors[model_choice]
     inputs = processor(text=text, images=image, return_tensors="pt").to(device)
     with torch.inference_mode():
-        generated_ids = model.generate(
-            **inputs,
-            max_new_tokens=max_new_tokens,
-            do_sample=False
-        )
+        generated_ids = model.generate(
+            **inputs,
+            max_new_tokens=max_new_tokens,
+            do_sample=False
+        )
     result = processor.batch_decode(generated_ids, skip_special_tokens=True)
     return result[0][len(text):].lstrip("\n")
 
-def parse_segmentation(input_image, input_text, model_choice):
-    out = infer(input_image, input_text, max_new_tokens=100, model_choice=model_choice)
-    objs = extract_objs(out.lstrip("\n"), input_image.size[0], input_image.size[1], unique_labels=True)
-    labels = set(obj.get('name') for obj in objs if obj.get('name'))
-    color_map = {l: COLORS[i % len(COLORS)] for i, l in enumerate(labels)}
-    highlighted_text = [(obj['content'], obj.get('name')) for obj in objs]
-    annotated_img = (
-        input_image,
-        [
-            (
-                obj['mask'] if obj.get('mask') is not None else obj['xyxy'],
-                obj['name'] or '',
-            )
-            for obj in objs
-            if 'mask' in obj or 'xyxy'
-        ],
-    )
-    has_annotations = bool(annotated_img[1])
-    return annotated_img
+def parse_segmentation(input_image, input_text):
+    out = infer(input_image, input_text, max_new_tokens=100)
+    objs = extract_objs(out.lstrip("\n"), input_image.size[0], input_image.size[1], unique_labels=True)
+    labels = set(obj.get('name') for obj in objs if obj.get('name'))
+    color_map = {l: COLORS[i % len(COLORS)] for i, l in enumerate(labels)}
+    highlighted_text = [(obj['content'], obj.get('name')) for obj in objs]
+    annotated_img = (
+        input_image,
+        [
+            (
+                obj['mask'] if obj.get('mask') is not None else obj['xyxy'],
+                obj['name'] or '',
+            )
+            for obj in objs
+            if 'mask' in obj or 'xyxy' in obj
+        ],
+    )
+    has_annotations = bool(annotated_img[1])
+    return annotated_img
 
 ######## Demo
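For context on the hunk above: the commit collapses the two-model dropdown into a single fixed checkpoint, so `infer` and `parse_segmentation` no longer take a `model_choice` argument. The new path can be exercised without the Gradio UI; a minimal sketch, assuming the globals defined in app.py (`model`, `processor`, `device`) are already loaded, and reusing the example screenshot the demo already ships:

```python
import PIL.Image

# Example screenshot already referenced by the demo's gr.Examples.
img = PIL.Image.open("./airbnb.jpg").convert("RGB")

# The fine-tuned checkpoint expects a "detect <description>" instruction.
raw = infer(img, "detect 'Amazing pools' button", max_new_tokens=100)
print(raw)  # typically four <locXXXX> tokens followed by the label

# parse_segmentation wraps infer + extract_objs and returns
# (image, [(xyxy_or_mask, label), ...]), the tuple gr.AnnotatedImage accepts.
image_out, annotations = parse_segmentation(img, "detect 'Amazing pools' button")
print(annotations)
```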
 
@@ -80,34 +66,34 @@ Note:\n\n
 - the task it was fine-tuned on was detection, so it may not generalize to other tasks.
 """
 
+
 with gr.Blocks(css="style.css") as demo:
-    gr.Markdown(INTRO_TEXT)
-    with gr.Tab("Detection"):
-        model_choice = gr.Dropdown(label="Select Model", choices=list(MODEL_IDS.keys()))
-        image = gr.Image(type="pil")
-        seg_input = gr.Text(label="Detect instruction (e.g. 'detect sign in button')")
-        seg_btn = gr.Button("Submit")
-        annotated_image = gr.AnnotatedImage(label="Output")
-
-        examples = [["./airbnb.jpg", "detect 'Amazing pools' button"]]
-        gr.Examples(
-            examples=examples,
-            inputs=[image, seg_input],
-        )
+    gr.Markdown(INTRO_TEXT)
+    with gr.Tab("Detection"):
+        image = gr.Image(type="pil")
+        seg_input = gr.Text(label="Detect instruction (e.g. 'detect sign in button')")
+        seg_btn = gr.Button("Submit")
+        annotated_image = gr.AnnotatedImage(label="Output")
+
+        examples = [["./airbnb.jpg", "detect 'Amazing pools' button"]]
+        gr.Examples(
+            examples=examples,
+            inputs=[image, seg_input],
+        )
 
-        seg_inputs = [
-            image,
-            seg_input,
-            model_choice
-        ]
-        seg_outputs = [
-            annotated_image
+        seg_inputs = [
+            image,
+            seg_input
         ]
-        seg_btn.click(
-            fn=parse_segmentation,
-            inputs=seg_inputs,
-            outputs=seg_outputs,
-        )
+        seg_outputs = [
+            annotated_image
+        ]
+        seg_btn.click(
+            fn=parse_segmentation,
+            inputs=seg_inputs,
+            outputs=seg_outputs,
+        )
+
 
 _SEGMENT_DETECT_RE = re.compile(
     r'(.*?)' +
@@ -117,39 +103,39 @@ _SEGMENT_DETECT_RE = re.compile(
 )
 
 def extract_objs(text, width, height, unique_labels=False):
-    """Returns objs for a string with "<loc>" and "<seg>" tokens."""
-    objs = []
-    seen = set()
-    while text:
-        m = _SEGMENT_DETECT_RE.match(text)
-        if not m:
-            break
-        print("m", m)
-        gs = list(m.groups())
-        before = gs.pop(0)
-        name = gs.pop()
-        y1, x1, y2, x2 = [int(x) / 1024 for x in gs[:4]]
-
-        y1, x1, y2, x2 = map(round, (y1*height, x1*width, y2*height, x2*width))
-        mask = None
-
-        content = m.group()
-        if before:
-            objs.append(dict(content=before))
-            content = content[len(before):]
-        while unique_labels and name in seen:
-            name = (name or '') + "'"
-        seen.add(name)
-        objs.append(dict(
-            content=content, xyxy=(x1, y1, x2, y2), mask=mask, name=name))
-        text = text[len(before) + len(content):]
-
-    if text:
-        objs.append(dict(content=text))
-
-    return objs
+    """Returns objs for a string with "<loc>" and "<seg>" tokens."""
+    objs = []
+    seen = set()
+    while text:
+        m = _SEGMENT_DETECT_RE.match(text)
+        if not m:
+            break
+        print("m", m)
+        gs = list(m.groups())
+        before = gs.pop(0)
+        name = gs.pop()
+        y1, x1, y2, x2 = [int(x) / 1024 for x in gs[:4]]
+
+        y1, x1, y2, x2 = map(round, (y1*height, x1*width, y2*height, x2*width))
+        mask = None
+
+        content = m.group()
+        if before:
+            objs.append(dict(content=before))
+            content = content[len(before):]
+        while unique_labels and name in seen:
+            name = (name or '') + "'"
+        seen.add(name)
+        objs.append(dict(
+            content=content, xyxy=(x1, y1, x2, y2), mask=mask, name=name))
+        text = text[len(before) + len(content):]
+
+    if text:
+        objs.append(dict(content=text))
+
+    return objs
 
 #########
 
 if __name__ == "__main__":
-    demo.queue(max_size=10).launch(debug=True)
+    demo.queue(max_size=10).launch(debug=True)
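For readers unfamiliar with the format `_SEGMENT_DETECT_RE` and `extract_objs` parse: following the big_vision PaliGemma demo convention, each detected box arrives as four `<locXXXX>` tokens (y_min, x_min, y_max, x_max on a 0 to 1023 grid) followed by the label, which `extract_objs` rescales to pixel coordinates. A self-contained sketch of just that conversion; the sample string and the 448x448 size are illustrative, not real model output:

```python
import re

# Four <locXXXX> tokens encode y1, x1, y2, x2 on a 0-1023 grid, then the label.
# This sample string is made up for illustration.
sample = "<loc0256><loc0128><loc0512><loc0768> 'Amazing pools' button"

m = re.match(r"<loc(\d{4})><loc(\d{4})><loc(\d{4})><loc(\d{4})>\s*(.*)", sample)
y1, x1, y2, x2 = (int(v) / 1024 for v in m.groups()[:4])
label = m.group(5)

# Rescale normalized coordinates to pixel space (hypothetical 448x448 image).
width = height = 448
xyxy = tuple(map(round, (x1 * width, y1 * height, x2 * width, y2 * height)))
print(label, xyxy)  # 'Amazing pools' button (56, 112, 336, 224)
```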
 