import json

import requests
from PIL import Image
from transformers import (
    AutoProcessor,
    Pix2StructForConditionalGeneration,
    Pix2StructProcessor,
)

# Checkpoint used for both the model weights and its processor.
MODEL_NAME = "google/deplot"
# Chart image to read and the file the extracted table is written to.
IMAGE_PATH = "222.png"
OUTPUT_PATH = "test.log"
# Instruction text passed to the processor together with the image.
PROMPT = "Generate underlying data table of the figure below:"
# The decoded text carries row breaks as this literal token (0x0A is the
# line-feed byte), so rows are split on it rather than on "\n".
ROW_SEPARATOR = "<0x0A>"
# Cells inside one row are separated by a pipe in the decoded text.
CELL_SEPARATOR = "|"


def parse_table(raw_output, row_separator=ROW_SEPARATOR,
                cell_separator=CELL_SEPARATOR):
    """Split decoded model text into rows of whitespace-stripped cells.

    Args:
        raw_output: Decoded text, with rows separated by ``row_separator``
            and cells separated by ``cell_separator``.
        row_separator: Token that marks the end of a row.
        cell_separator: Token that separates cells inside a row.

    Returns:
        A list of rows, each a list of cell strings. Empty segments are
        kept, so an empty input yields ``[[""]]``.
    """
    return [
        [cell.strip() for cell in row.split(cell_separator)]
        for row in raw_output.split(row_separator)
    ]


def format_table(rows):
    """Render rows as text: cells joined by ``" | "``, one row per line.

    Args:
        rows: Iterable of rows, each an iterable of cell strings.

    Returns:
        The rendered text; every row, including the last, ends with
        ``"\\n"``. An empty ``rows`` yields an empty string.
    """
    return "".join(" | ".join(row) + "\n" for row in rows)


def main():
    """Run the model on ``IMAGE_PATH`` and save the table to ``OUTPUT_PATH``.

    Prints the raw decoded prediction and the parsed rows, then writes the
    rows to the output file.
    """
    model = Pix2StructForConditionalGeneration.from_pretrained(MODEL_NAME)
    processor = AutoProcessor.from_pretrained(MODEL_NAME)

    # Keep the image file open only while the processor reads it.
    with Image.open(IMAGE_PATH) as image:
        inputs = processor(images=image, text=PROMPT, return_tensors="pt")

    predictions = model.generate(**inputs, max_new_tokens=512)

    # Decode once and reuse the text for both printing and parsing.
    raw_output = processor.decode(predictions[0], skip_special_tokens=True)
    print("prediction")
    print(raw_output)

    result_array = parse_table(raw_output)
    print("result:")
    print(result_array)

    # Explicit UTF-8 so non-ASCII chart labels do not depend on the
    # platform's default encoding.
    with open(OUTPUT_PATH, mode="w", encoding="utf-8") as file:
        file.write(format_table(result_array))


if __name__ == "__main__":
    main()