import gradio as gr
import PIL.Image
import transformers
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
import torch
import os
import string
import functools
import re
import numpy as np
import spaces
# Model IDs
MODEL_IDS = {
    "paligemma-3b-ft-widgetcap-waveui-448": "agentsea/paligemma-3b-ft-widgetcap-waveui-448",
    "paligemma-3b-ft-waveui-896": "agentsea/paligemma-3b-ft-waveui-896",
}

COLORS = ['#4285f4', '#db4437', '#f4b400', '#0f9d58', '#e48ef1']

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load models and processors
models = {name: PaliGemmaForConditionalGeneration.from_pretrained(model_id).eval().to(device)
          for name, model_id in MODEL_IDS.items()}
processors = {name: PaliGemmaProcessor.from_pretrained(processor_id)
              for name, processor_id in MODEL_IDS.items()}
###### Transformers Inference
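# ZeroGPU Spaces expect the GPU-bound entry point to carry the `spaces.GPU`
# decorator (which is why `spaces` is imported above). Adding it here is an
# assumption based on this Space running on Zero hardware.
@spaces.GPU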
def infer(
    image: PIL.Image.Image,
    text: str,
    max_new_tokens: int,
    model_choice: str
) -> str:
    model = models[model_choice]
    processor = processors[model_choice]
    inputs = processor(text=text, images=image, return_tensors="pt").to(device)
    with torch.inference_mode():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False
        )
    result = processor.batch_decode(generated_ids, skip_special_tokens=True)
    # The decoded sequence echoes the prompt, so strip it before returning.
    return result[0][len(text):].lstrip("\n")
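# Example call (hypothetical file name, shown for illustration only):
#   infer(PIL.Image.open("screenshot.png"), "detect sign in button",
#         max_new_tokens=20, model_choice="paligemma-3b-ft-waveui-896")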
def parse_segmentation(input_image, input_text, model_choice):
    out = infer(input_image, input_text, max_new_tokens=100, model_choice=model_choice)
    objs = extract_objs(out.lstrip("\n"), input_image.size[0], input_image.size[1], unique_labels=True)
    labels = set(obj.get('name') for obj in objs if obj.get('name'))
    color_map = {l: COLORS[i % len(COLORS)] for i, l in enumerate(labels)}
    highlighted_text = [(obj['content'], obj.get('name')) for obj in objs]
    annotated_img = (
        input_image,
        [
            (
                obj['mask'] if obj.get('mask') is not None else obj['xyxy'],
                obj['name'] or '',
            )
            for obj in objs
            if 'mask' in obj or 'xyxy' in obj
        ],
    )
    has_annotations = bool(annotated_img[1])
    return annotated_img
######## Demo
INTRO_TEXT = """## PaliGemma WaveUI\n\n
Two models fine-tuned on the [WaveUI dataset](https://huggingface.co/datasets/agentsea/wave-ui) from different base checkpoints:\n\n
- [paligemma-3b-ft-widgetcap-waveui-448](https://huggingface.co/agentsea/paligemma-3b-ft-widgetcap-waveui-448)
- [paligemma-3b-ft-waveui-896](https://huggingface.co/agentsea/paligemma-3b-ft-waveui-896)
Note:\n\n
- These models were fine-tuned on the detection task, so they may not generalize to other tasks.
Usage: write the task keyword "detect" before the element you want the model to detect. For example, "detect profile picture".
"""
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(INTRO_TEXT)
    with gr.Tab("Detection"):
        model_choice = gr.Dropdown(label="Select Model", choices=list(MODEL_IDS.keys()))
        image = gr.Image(type="pil")
        seg_input = gr.Text(label="Detect instruction (e.g. 'detect sign in button')")
        seg_btn = gr.Button("Submit")
        annotated_image = gr.AnnotatedImage(label="Output")

        examples = [["./airbnb.jpg", "detect 'Amazing pools' button"]]
        gr.Examples(
            examples=examples,
            inputs=[image, seg_input],
        )

        seg_inputs = [
            image,
            seg_input,
            model_choice,
        ]
        seg_outputs = [
            annotated_image,
        ]
        seg_btn.click(
            fn=parse_segmentation,
            inputs=seg_inputs,
            outputs=seg_outputs,
        )
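# PaliGemma detection output has the form "<loc####><loc####><loc####><loc####> label",
# where each <loc####> token is a coordinate on a 0-1023 grid in y1, x1, y2, x2 order;
# up to sixteen optional <seg###> tokens may precede the label for segmentation masks.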
_SEGMENT_DETECT_RE = re.compile(
    r'(.*?)' +
    r'<loc(\d{4})>' * 4 + r'\s*' +
    '(?:%s)?' % (r'<seg(\d{3})>' * 16) +
    r'\s*([^;<>]+)? ?(?:; )?',
)
def extract_objs(text, width, height, unique_labels=False):
    """Returns objs for a string with "<loc>" and "<seg>" tokens."""
    objs = []
    seen = set()
    while text:
        m = _SEGMENT_DETECT_RE.match(text)
        if not m:
            break
        gs = list(m.groups())
        before = gs.pop(0)
        name = gs.pop()
        # <loc> tokens lie on a 1024-step grid; rescale them to pixel coordinates.
        y1, x1, y2, x2 = [int(x) / 1024 for x in gs[:4]]
        y1, x1, y2, x2 = map(round, (y1 * height, x1 * width, y2 * height, x2 * width))
        mask = None
        content = m.group()
        if before:
            objs.append(dict(content=before))
            content = content[len(before):]
        while unique_labels and name in seen:
            name = (name or '') + "'"
        seen.add(name)
        objs.append(dict(
            content=content, xyxy=(x1, y1, x2, y2), mask=mask, name=name))
        text = text[len(before) + len(content):]
    if text:
        objs.append(dict(content=text))
    return objs
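# For example, extract_objs("<loc0100><loc0200><loc0500><loc0600> sign in button", 1280, 720)
# would return a single entry whose xyxy box is scaled to the 1280x720 image and whose
# name is "sign in button" (illustrative values, not taken from a real model run).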
#########

if __name__ == "__main__":
    demo.queue(max_size=10).launch(debug=True)