import gradio as gr
from transformers import AutoProcessor, Pix2StructForConditionalGeneration
import torch
from PIL import Image
import json
import vl_convert as vlc
from io import BytesIO
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the processor and model
processor = AutoProcessor.from_pretrained("google/matcha-base")
processor.image_processor.is_vqa = False  # plain image-to-text; no VQA question is rendered onto the image
model = Pix2StructForConditionalGeneration.from_pretrained("martinsinnona/visdecode_B").to(device)
model.eval()
def generate(image):
    # Model inference is currently disabled; a fixed example spec is returned
    # in its place (the commented lines show the intended inference path).
    # inputs = processor(images=image, return_tensors="pt", max_patches=1024).to(device)
    # generated_ids = model.generate(flattened_patches=inputs.flattened_patches, attention_mask=inputs.attention_mask, max_length=600)
    # generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    generated_caption = "{'mark': 'bar', 'encoding': {'x': {'field': '', 'type': 'ordinal'}, 'y': {'field': '', 'type': 'quantitative'}}, 'data': {'values': [{'x': 0, 'y': 5.6}, {'x': 1, 'y': 6.7}, {'x': 2, 'y': 5.0}, {'x': 3, 'y': 18.7}]}}"

    # Parse the caption into a Vega-Lite spec and render it to an image
    vega = string_to_vega(generated_caption)
    vega_image = draw_vega(vega)
    return generated_caption, vega_image
def draw_vega(vega, scale=3):
    spec = json.dumps(vega, indent=4)
    png_data = vlc.vegalite_to_png(vl_spec=spec, scale=scale)
    image = Image.open(BytesIO(png_data))
    image.thumbnail((500, 500))  # thumbnail() resizes in place and returns None
    return image
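# Illustrative standalone use of draw_vega (a sketch; the spec and output
# path below are examples, not part of the app):
#
#   img = draw_vega({"mark": "bar",
#                    "encoding": {"x": {"field": "x", "type": "ordinal"},
#                                 "y": {"field": "y", "type": "quantitative"}},
#                    "data": {"values": [{"x": 0, "y": 5.6}, {"x": 1, "y": 6.7}]}})
#   img.save("chart.png")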
def string_to_vega(string):
    # The model emits a Python-dict-like string with single quotes;
    # swap them for double quotes so it parses as JSON.
    string = string.replace("'", "\"")
    vega = json.loads(string)
    for axis in ["x", "y"]:
        field = vega["encoding"][axis]["field"]
        if field == "":
            # No field name predicted: use the axis key itself and hide the title
            vega["encoding"][axis]["field"] = axis
            vega["encoding"][axis]["title"] = ""
        else:
            # Rename the generic x/y keys in each data row to the predicted field name
            for entry in vega["data"]["values"]:
                entry[field] = entry.pop(axis)
    return vega
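# A minimal sketch of what the normalisation does, using the placeholder spec
# from generate() (shown for illustration only):
#
#   before: {"encoding": {"x": {"field": "", ...}}, "data": {"values": [{"x": 0, "y": 5.6}, ...]}}
#   after:  {"encoding": {"x": {"field": "x", "title": "", ...}}, "data": {"values": [{"x": 0, "y": 5.6}, ...]}}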
# Create the Gradio interface
iface = gr.Interface(
    fn=generate,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Textbox(label="Generated Vega-Lite spec"),
        gr.Image(type="pil", label="Generated Vega Image", height=500, width=500),
    ],
    title="Image to Vega-Lite",
    description="Upload a chart image to generate its Vega-Lite spec",
)
# Launch the interface
if __name__ == "__main__":
    iface.launch(share=True)