martinsinnona committed on
Commit
e8ba5e8
·
1 Parent(s): 7b0ea0f
Files changed (1) hide show
  1. app.py +8 -6
app.py CHANGED
@@ -14,12 +14,14 @@ processor.image_processor.is_vqa = False
14
  model = Pix2StructForConditionalGeneration.from_pretrained("martinsinnona/visdecode_B").to(device)
15
  model.eval()
16
 
17
- def generate_caption(image):
18
 
19
- inputs = processor(images=image, return_tensors="pt", max_patches=1024).to(device)
20
- generated_ids = model.generate(flattened_patches=inputs.flattened_patches, attention_mask=inputs.attention_mask, max_length=600)
21
- generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
22
 
 
 
23
  # Generate the Vega image
24
  vega = string_to_vega(generated_caption)
25
  vega_image = draw_vega(vega)
@@ -51,7 +53,7 @@ def string_to_vega(string):
51
  # Create the Gradio interface
52
  iface = gr.Interface(
53
 
54
- fn=generate_caption,
55
  inputs=gr.Image(type="pil"),
56
  outputs=[gr.Textbox(), gr.Image(type="pil")],
57
  title="Image to Vega-Lite",
@@ -60,4 +62,4 @@ iface = gr.Interface(
60
 
61
  # Launch the interface
62
  if __name__ == "__main__":
63
- iface.launch(share=True)
 
14
  model = Pix2StructForConditionalGeneration.from_pretrained("martinsinnona/visdecode_B").to(device)
15
  model.eval()
16
 
17
+ def generate(image):
18
 
19
+ #inputs = processor(images=image, return_tensors="pt", max_patches=1024).to(device)
20
+ #generated_ids = model.generate(flattened_patches=inputs.flattened_patches, attention_mask=inputs.attention_mask, max_length=600)
21
+ #generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
22
 
23
+ generated_caption = "{'mark': 'bar', 'encoding': {'x': {'field': '', 'type': 'ordinal'}, 'y': {'field': '', 'type': 'quantitative'}}, 'data': {'values': [{'x': 0, 'y': 5.6}, {'x': 1, 'y': 6.7}, {'x': 2, 'y': 5.0}, {'x': 3, 'y': 18.7}]}}"
24
+
25
  # Generate the Vega image
26
  vega = string_to_vega(generated_caption)
27
  vega_image = draw_vega(vega)
 
53
  # Create the Gradio interface
54
  iface = gr.Interface(
55
 
56
+ fn=generate,
57
  inputs=gr.Image(type="pil"),
58
  outputs=[gr.Textbox(), gr.Image(type="pil")],
59
  title="Image to Vega-Lite",
 
62
 
63
  # Launch the interface
64
  if __name__ == "__main__":
65
+ iface.launch()