nph4rd committed
Commit 1761ad9 · verified · 1 Parent(s): d3c1d14

Update app.py

Files changed (1):
  1. app.py +101 -87
app.py CHANGED
@@ -10,50 +10,64 @@ import re
 import numpy as np
 import spaces
 
-
-model_id = "agentsea/paligemma-3b-ft-widgetcap-waveui-448"
-processor_id = "google/paligemma-3b-pt-448"
-COLORS = ['#4285f4', '#db4437', '#f4b400', '#0f9d58', '#e48ef1']
+# Model IDs
+MODEL_IDS = {
+    "Model 1 (Widgetcap 448)": "agentsea/paligemma-3b-ft-widgetcap-waveui-448",
+    "Model 2 (WaveUI 896)": "agentsea/paligemma-3b-ft-waveui-896"
+}
+PROCESSOR_IDS = {
+    "Model 1 (Widgetcap 448)": "google/paligemma-3b-pt-448",
+    "Model 2 (WaveUI 896)": "google/paligemma-3b-pt-896"
+}
+
+# Device configuration
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).eval().to(device)
-processor = PaliGemmaProcessor.from_pretrained(processor_id)
+
+# Load models and processors
+models = {name: PaliGemmaForConditionalGeneration.from_pretrained(model_id).eval().to(device)
+          for name, model_id in MODEL_IDS.items()}
+processors = {name: PaliGemmaProcessor.from_pretrained(processor_id)
+              for name, processor_id in PROCESSOR_IDS.items()}
 
 ###### Transformers Inference
 @spaces.GPU
 def infer(
     image: PIL.Image.Image,
     text: str,
-    max_new_tokens: int
+    max_new_tokens: int,
+    model_choice: str
 ) -> str:
+    model = models[model_choice]
+    processor = processors[model_choice]
     inputs = processor(text=text, images=image, return_tensors="pt").to(device)
     with torch.inference_mode():
-        generated_ids = model.generate(
-            **inputs,
-            max_new_tokens=max_new_tokens,
-            do_sample=False
-        )
+        generated_ids = model.generate(
+            **inputs,
+            max_new_tokens=max_new_tokens,
+            do_sample=False
+        )
     result = processor.batch_decode(generated_ids, skip_special_tokens=True)
     return result[0][len(text):].lstrip("\n")
 
-def parse_segmentation(input_image, input_text):
-    out = infer(input_image, input_text, max_new_tokens=100)
-    objs = extract_objs(out.lstrip("\n"), input_image.size[0], input_image.size[1], unique_labels=True)
-    labels = set(obj.get('name') for obj in objs if obj.get('name'))
-    color_map = {l: COLORS[i % len(COLORS)] for i, l in enumerate(labels)}
-    highlighted_text = [(obj['content'], obj.get('name')) for obj in objs]
-    annotated_img = (
-        input_image,
-        [
-            (
-                obj['mask'] if obj.get('mask') is not None else obj['xyxy'],
-                obj['name'] or '',
-            )
-            for obj in objs
-            if 'mask' in obj or 'xyxy' in obj
-        ],
-    )
-    has_annotations = bool(annotated_img[1])
-    return annotated_img
+def parse_segmentation(input_image, input_text, model_choice):
+    out = infer(input_image, input_text, max_new_tokens=100, model_choice=model_choice)
+    objs = extract_objs(out.lstrip("\n"), input_image.size[0], input_image.size[1], unique_labels=True)
+    labels = set(obj.get('name') for obj in objs if obj.get('name'))
+    color_map = {l: COLORS[i % len(COLORS)] for i, l in enumerate(labels)}
+    highlighted_text = [(obj['content'], obj.get('name')) for obj in objs]
+    annotated_img = (
+        input_image,
+        [
+            (
+                obj['mask'] if obj.get('mask') is not None else obj['xyxy'],
+                obj['name'] or '',
+            )
+            for obj in objs
+            if 'mask' in obj or 'xyxy' in obj
+        ],
+    )
+    has_annotations = bool(annotated_img[1])
+    return annotated_img
 
 ######## Demo
 
@@ -66,34 +80,34 @@ Note:\n\n
 - the task it was fine-tuned on was detection, so it may not generalize to other tasks.
 """
 
-
 with gr.Blocks(css="style.css") as demo:
-    gr.Markdown(INTRO_TEXT)
-    with gr.Tab("Detection"):
-        image = gr.Image(type="pil")
-        seg_input = gr.Text(label="Detect instruction (e.g. 'detect sign in button')")
-        seg_btn = gr.Button("Submit")
-        annotated_image = gr.AnnotatedImage(label="Output")
-
-        examples = [["./airbnb.jpg", "detect 'Amazing pools' button"]]
-        gr.Examples(
-            examples=examples,
-            inputs=[image, seg_input],
-        )
+    gr.Markdown(INTRO_TEXT)
+    with gr.Tab("Detection"):
+        model_choice = gr.Dropdown(label="Select Model", choices=list(MODEL_IDS.keys()))
+        image = gr.Image(type="pil")
+        seg_input = gr.Text(label="Detect instruction (e.g. 'detect sign in button')")
+        seg_btn = gr.Button("Submit")
+        annotated_image = gr.AnnotatedImage(label="Output")
+
+        examples = [["./airbnb.jpg", "detect 'Amazing pools' button"]]
+        gr.Examples(
+            examples=examples,
+            inputs=[image, seg_input],
+        )
 
-    seg_inputs = [
-        image,
-        seg_input
+    seg_inputs = [
+        image,
+        seg_input,
+        model_choice
     ]
-    seg_outputs = [
-        annotated_image
-    ]
-    seg_btn.click(
-        fn=parse_segmentation,
-        inputs=seg_inputs,
-        outputs=seg_outputs,
-    )
-
+    seg_outputs = [
+        annotated_image
+    ]
+    seg_btn.click(
+        fn=parse_segmentation,
+        inputs=seg_inputs,
+        outputs=seg_outputs,
+    )
 
 _SEGMENT_DETECT_RE = re.compile(
     r'(.*?)' +
@@ -103,37 +117,37 @@ _SEGMENT_DETECT_RE = re.compile(
 )
 
 def extract_objs(text, width, height, unique_labels=False):
-    """Returns objs for a string with "<loc>" and "<seg>" tokens."""
-    objs = []
-    seen = set()
-    while text:
-        m = _SEGMENT_DETECT_RE.match(text)
-        if not m:
-            break
-        print("m", m)
-        gs = list(m.groups())
-        before = gs.pop(0)
-        name = gs.pop()
-        y1, x1, y2, x2 = [int(x) / 1024 for x in gs[:4]]
-
-        y1, x1, y2, x2 = map(round, (y1*height, x1*width, y2*height, x2*width))
-        mask = None
-
-        content = m.group()
-        if before:
-            objs.append(dict(content=before))
-            content = content[len(before):]
-        while unique_labels and name in seen:
-            name = (name or '') + "'"
-        seen.add(name)
-        objs.append(dict(
-            content=content, xyxy=(x1, y1, x2, y2), mask=mask, name=name))
-        text = text[len(before) + len(content):]
-
-    if text:
-        objs.append(dict(content=text))
-
-    return objs
+    """Returns objs for a string with "<loc>" and "<seg>" tokens."""
+    objs = []
+    seen = set()
+    while text:
+        m = _SEGMENT_DETECT_RE.match(text)
+        if not m:
+            break
+        print("m", m)
+        gs = list(m.groups())
+        before = gs.pop(0)
+        name = gs.pop()
+        y1, x1, y2, x2 = [int(x) / 1024 for x in gs[:4]]
+
+        y1, x1, y2, x2 = map(round, (y1*height, x1*width, y2*height, x2*width))
+        mask = None
+
+        content = m.group()
+        if before:
+            objs.append(dict(content=before))
+            content = content[len(before):]
+        while unique_labels and name in seen:
+            name = (name or '') + "'"
+        seen.add(name)
+        objs.append(dict(
+            content=content, xyxy=(x1, y1, x2, y2), mask=mask, name=name))
+        text = text[len(before) + len(content):]
+
+    if text:
+        objs.append(dict(content=text))
+
+    return objs
 
 #########
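For reference, the per-model inference path this commit introduces can be exercised outside the Gradio app. The sketch below is illustrative rather than part of the commit: it hard-codes the first entry of MODEL_IDS / PROCESSOR_IDS, and the screenshot path and prompt are placeholders.

    # Minimal sketch: greedy decoding with one of the two checkpoints, mirroring infer() above.
    import PIL.Image
    import torch
    from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor

    model_id = "agentsea/paligemma-3b-ft-widgetcap-waveui-448"   # "Model 1 (Widgetcap 448)"
    processor_id = "google/paligemma-3b-pt-448"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).eval().to(device)
    processor = PaliGemmaProcessor.from_pretrained(processor_id)

    image = PIL.Image.open("screenshot.png").convert("RGB")      # placeholder input image
    prompt = "detect sign in button"                             # placeholder instruction

    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
    with torch.inference_mode():
        generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False)

    decoded = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    print(decoded[len(prompt):].lstrip("\n"))

The detection output is a string of location tokens followed by the label; extract_objs divides each location value by 1024 and scales by the image height/width, so for a 448-pixel-wide image a value of 512 maps to round(512 / 1024 * 448) = 224.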