# Copied from a Hugging Face Space file view: app.py by nph4rd,
# commit 1761ad9 ("Update app.py", verified), 4.98 kB.
import gradio as gr
import PIL.Image
import transformers
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
import torch
import os
import string
import functools
import re
import numpy as np
import spaces
# Model IDs
# One row per selectable model: (dropdown label, fine-tuned checkpoint id,
# processor checkpoint id). Both lookup tables below are derived from this
# single table, so their keys and key order always agree.
_MODEL_TABLE = (
    (
        "Model 1 (Widgetcap 448)",
        "agentsea/paligemma-3b-ft-widgetcap-waveui-448",
        "google/paligemma-3b-pt-448",
    ),
    (
        "Model 2 (WaveUI 896)",
        "agentsea/paligemma-3b-ft-waveui-896",
        "google/paligemma-3b-pt-896",
    ),
)
# Dropdown label -> Hub id of the fine-tuned model weights.
MODEL_IDS = {label: model_id for label, model_id, _ in _MODEL_TABLE}
# Dropdown label -> Hub id that the PaliGemmaProcessor is loaded from.
PROCESSOR_IDS = {label: processor_id for label, _, processor_id in _MODEL_TABLE}
# Device configuration: CUDA when it is available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Load models and processors.
# Every model is loaded eagerly at import time, put into eval mode and moved to
# `device`; processors are loaded afterwards, keyed by the same dropdown labels.
models = {
    label: PaliGemmaForConditionalGeneration.from_pretrained(checkpoint).eval().to(device)
    for label, checkpoint in MODEL_IDS.items()
}
processors = {
    label: PaliGemmaProcessor.from_pretrained(checkpoint)
    for label, checkpoint in PROCESSOR_IDS.items()
}
###### Transformers Inference
@spaces.GPU
def infer(
    image: PIL.Image.Image,
    text: str,
    max_new_tokens: int,
    model_choice: str
) -> str:
    """Run greedy PaliGemma generation for one image/prompt pair.

    Args:
        image: Input image handed to the selected processor.
        text: Prompt, e.g. "detect 'Amazing pools' button".
        max_new_tokens: Upper bound on the number of generated tokens.
        model_choice: Key into ``models`` / ``processors`` (a ``MODEL_IDS`` label).

    Returns:
        The decoded continuation only (the prompt is not included), with
        leading newlines stripped.
    """
    model = models[model_choice]
    processor = processors[model_choice]
    inputs = processor(text=text, images=image, return_tensors="pt").to(device)
    # Length of the prompt in tokens; generate() returns the prompt ids
    # followed by the newly generated ids.
    prompt_len = inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )
    # Drop the prompt by token count rather than slicing the decoded text by
    # len(text): the decoded prompt is not guaranteed to reproduce `text`
    # character-for-character, which would shift the cut point.
    result = processor.batch_decode(generated_ids[:, prompt_len:], skip_special_tokens=True)
    return result[0].lstrip("\n")
def parse_segmentation(input_image, input_text, model_choice):
    """Detect UI elements and return them in ``gr.AnnotatedImage`` format.

    Args:
        input_image: PIL image from the ``gr.Image`` component (None if empty).
        input_text: Detection prompt, e.g. "detect 'Amazing pools' button".
        model_choice: Dropdown label selecting the model/processor pair.

    Returns:
        ``(input_image, annotations)`` where each annotation is
        ``(box_or_mask, label)``; boxes are pixel ``(x1, y1, x2, y2)`` tuples
        and missing labels become ``''``.

    Raises:
        gr.Error: if no image was provided or the model choice is unknown.
    """
    if input_image is None:
        raise gr.Error("Please upload an image.")
    if model_choice not in models:
        raise gr.Error("Please select a model.")
    out = infer(input_image, input_text, max_new_tokens=100, model_choice=model_choice)
    width, height = input_image.size
    objs = extract_objs(out.lstrip("\n"), width, height, unique_labels=True)
    # Only detection entries carry geometry; plain-text spans are skipped.
    annotations = [
        (
            obj['mask'] if obj.get('mask') is not None else obj['xyxy'],
            obj['name'] or '',
        )
        for obj in objs
        if 'mask' in obj or 'xyxy' in obj
    ]
    return input_image, annotations
######## Demo
# Markdown shown at the top of the demo; describes every model in the dropdown.
INTRO_TEXT = """## PaliGemma WaveUI

PaliGemma 3B checkpoints fine-tuned on the [WaveUI-25k](https://huggingface.co/datasets/agentsea/wave-ui-25k) dataset for UI element detection:

- **Model 1 (Widgetcap 448)**: [agentsea/paligemma-3b-ft-widgetcap-waveui-448](https://huggingface.co/agentsea/paligemma-3b-ft-widgetcap-waveui-448), built on [PaliGemma Widgetcap 448](https://huggingface.co/google/paligemma-3b-ft-widgetcap-448).
- **Model 2 (WaveUI 896)**: [agentsea/paligemma-3b-ft-waveui-896](https://huggingface.co/agentsea/paligemma-3b-ft-waveui-896), using 896 px input resolution.

Note:

- these models were fine-tuned on WaveUI data and may not generalize to all UI elements.
- the task they were fine-tuned on was detection, so they may not generalize to other tasks.
"""
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(INTRO_TEXT)
    with gr.Tab("Detection"):
        # Preselect the first model: without a default the dropdown value is
        # None until the user picks one, which is not a valid model key.
        model_choice = gr.Dropdown(
            label="Select Model",
            choices=list(MODEL_IDS.keys()),
            value=next(iter(MODEL_IDS)),
        )
        image = gr.Image(type="pil")
        seg_input = gr.Text(label="Detect instruction (e.g. 'detect sign in button')")
        seg_btn = gr.Button("Submit")
        annotated_image = gr.AnnotatedImage(label="Output")
        examples = [["./airbnb.jpg", "detect 'Amazing pools' button"]]
        # Examples fill only the image and prompt; the dropdown keeps its value.
        gr.Examples(
            examples=examples,
            inputs=[image, seg_input],
        )
    seg_inputs = [
        image,
        seg_input,
        model_choice,
    ]
    seg_outputs = [
        annotated_image,
    ]
    seg_btn.click(
        fn=parse_segmentation,
        inputs=seg_inputs,
        outputs=seg_outputs,
    )
# Matches one detection/segmentation chunk at the start of a string.
# Capture groups:
#   1      -> any text preceding the chunk (non-greedy; '.' does not cross newlines)
#   2-5    -> the four <locDDDD> values, unpacked below as y1, x1, y2, x2
#   6-21   -> sixteen optional <segDDD> values
#   22     -> optional label, running up to ';', '<' or '>'
# A trailing "; " separator between chunks is consumed too.
_SEGMENT_DETECT_RE = re.compile(
    r'(.*?)' +
    r'<loc(\d{4})>' * 4 + r'\s*' +
    '(?:%s)?' % (r'<seg(\d{3})>' * 16) +
    r'\s*([^;<>]+)? ?(?:; )?',
)


def extract_objs(text, width, height, unique_labels=False):
    """Split model output containing "<loc>"/"<seg>" tokens into objects.

    Args:
        text: Decoded model output.
        width: Image width in pixels, used to scale x coordinates.
        height: Image height in pixels, used to scale y coordinates.
        unique_labels: If True, append "'" to a repeated label until it is
            distinct from all labels seen so far.

    Returns:
        A list of dicts in input order. Text outside any detection chunk
        becomes ``{"content": str}``. Each detection becomes
        ``{"content", "xyxy": (x1, y1, x2, y2), "mask": None, "name"}`` with
        pixel coordinates. Segmentation tokens are matched but not decoded,
        so ``mask`` is always None.
    """
    objs = []
    seen = set()
    while text:
        m = _SEGMENT_DETECT_RE.match(text)
        if not m:
            break
        gs = list(m.groups())
        before = gs.pop(0)
        name = gs.pop()
        # Location values are divided by 1024, then scaled to pixels.
        y1, x1, y2, x2 = [int(x) / 1024 for x in gs[:4]]
        y1, x1, y2, x2 = map(round, (y1 * height, x1 * width, y2 * height, x2 * width))
        mask = None  # <seg> tokens are not decoded into a mask here.
        content = m.group()
        if before:
            objs.append(dict(content=before))
            content = content[len(before):]
        while unique_labels and name in seen:
            name = (name or '') + "'"
        seen.add(name)
        objs.append(dict(
            content=content, xyxy=(x1, y1, x2, y2), mask=mask, name=name))
        # Every match consumes at least the four <loc> tokens, so this loop
        # always makes progress.
        text = text[len(before) + len(content):]
    if text:
        objs.append(dict(content=text))
    return objs
#########
if __name__ == "__main__":
    # Enable request queueing (holding at most 10 pending events), then start
    # the server with debug output.
    demo.queue(max_size=10)
    demo.launch(debug=True)