Create app.py

app.py ADDED
@@ -0,0 +1,299 @@
import gradio as gr
from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration
from PIL import Image
import torch
import warnings
import os
import pandas as pd
import logging
import subprocess
import spaces  # Hugging Face Spaces SDK
import openai
import base64

# Run `git lfs pull` to fetch the LFS-tracked model weights
result = subprocess.run(['git', 'lfs', 'pull'], capture_output=True, text=True)

# Print the result of the command (optional)
if result.returncode == 0:
    print("LFS files downloaded successfully.")
else:
    print(f"Error occurred: {result.stderr}")

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger()

warnings.filterwarnings('ignore')
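# A stricter variant (a sketch, not the author's code): let subprocess raise on
# failure instead of checking returncode by hand:
#   subprocess.run(['git', 'lfs', 'pull'], check=True)  # raises CalledProcessError on non-zero exit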
MAX_PATCHES = 512
# Load the models and processor
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Paths to the models
ko_deplot_model_path = './deplot_model_ver_24.11.21_korean_only(exclude NUUA)_epoch1.bin'

# Load first model: ko-deplot
def load_model1():
    processor1 = Pix2StructProcessor.from_pretrained('nuua/ko-deplot')
    model1 = Pix2StructForConditionalGeneration.from_pretrained('nuua/ko-deplot')
    model1.load_state_dict(torch.load(ko_deplot_model_path, map_location="cpu"))
    model1.to(device)  # use the detected device so the app also runs without a GPU
    return processor1, model1

processor1, model1 = load_model1()

# Function to format ko-deplot output: the model emits '<0x0A>' as its newline token
def format_output(prediction):
    return prediction.replace('<0x0A>', '\n')
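# Illustrative example (made-up values) of the conversion format_output performs:
#   raw model output: "TITLE | Monthly Sales <0x0A> Month | Sales <0x0A> Jan | 120"
#   formatted:        "TITLE | Monthly Sales\nMonth | Sales\nJan | 120"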

# First model prediction: ko-deplot
def predict_model1(image):
    images = [image]
    inputs = processor1(images=images, text="What is the title of the chart", return_tensors="pt", padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}  # move input tensors to the model's device

    model1.eval()
    with torch.no_grad():
        predictions = model1.generate(**inputs, max_new_tokens=4096)
    outputs = [processor1.decode(pred, skip_special_tokens=True) for pred in predictions]

    formatted_output = format_output(outputs[0])
    return formatted_output
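# Usage sketch (hypothetical file name):
#   print(predict_model1(Image.open("sample_chart.png")))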

# Set your OpenAI API key (read from the environment; do not hard-code secrets)
openai.api_key = os.environ.get("OPENAI_API_KEY")

# Function to encode the image as base64
def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")
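# The base64 string is wrapped into a data URL below; for a JPEG it looks like
# (truncated, illustrative): "data:image/jpeg;base64,/9j/4AAQSkZJRg..."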

# Second model prediction: gpt-4o-mini
def predict_model2(image):
    # Encode the uploaded image to base64
    image_data = encode_image(image)

    # Prepare the request content
    response = openai.ChatCompletion.create(
        model="gpt-4o-mini",
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "Please extract the chart title and chart data and present them as a table. Provide only the title and the table, without adding any additional comments such as **Chart Title:**."
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{image_data}"
                        }
                    }
                ]
            }
        ]
    )

    # Return the table data from the response
    return response.choices[0]["message"]["content"]
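# Compatibility note: openai.ChatCompletion is the pre-1.0 interface of the
# openai-python package; this call assumes a pinned openai<1.0 (e.g. openai==0.28),
# since it was removed in openai>=1.0.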

# Convert text generated by ko-deplot into a pandas DataFrame.
# Expected input: line 0 is the title row ("<label> | <title>"), line 1 the header
# (or legend) row, and the remaining lines the data rows, delimited by " | ".
def ko_deplot_convert_to_dataframe(label_table_str):
    lines = label_table_str.strip().split("\n")
    data = []
    title = lines[0].split(" | ")[1]

    if len(lines[1].split("|")) == len(lines[2].split("|")):
        # Header row and data rows have the same number of columns
        headers = lines[1].split(" | ")
        for line in lines[2:]:
            data.append(line.split(" | "))
        df = pd.DataFrame(data, columns=headers)
        return df, title
    else:
        # Legend row is one column short because data rows carry a leading row
        # label, so prepend an empty header cell for that column
        legend_row = lines[1].split(" | ")
        legend_row.insert(0, " ")
        for line in lines[2:]:
            data.append(line.split(" | "))
        df = pd.DataFrame(data, columns=legend_row)
        return df, title
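# Illustrative (made-up) ko-deplot outputs the two branches above handle:
#   header case: "TITLE | Monthly Sales\nMonth | Sales\nJan | 120\nFeb | 135"
#   legend case: "TITLE | Monthly Sales\nA | B\nJan | 1 | 2\nFeb | 3 | 4"
#   (in the legend case each data row starts with a row label, hence the
#   prepended empty header cell)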

# Convert the markdown-style table generated by gpt-4o-mini into a pandas DataFrame
def gpt_convert_to_dataframe(table_text):
    try:
        # Expected shape: title line, a line after it, header row, "---" separator, data rows
        lines = table_text.strip().split("\n")
        title = lines[0]
        lines.pop(1)  # drop the line after the title
        lines.pop(2)  # drop the "---" separator row (index 2 after the first pop)
        # Process the remaining lines to create the DataFrame
        data = [line.split("|")[1:-1] for line in lines[1:]]  # split by | and drop the empty first/last items
        dataframe = pd.DataFrame(data[1:], columns=[col.strip() for col in data[0]])  # use the first row as headers
        return dataframe, title
    except Exception as e:
        return f"Error converting table to DataFrame: {e}"
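# Illustrative (made-up) gpt-4o-mini response shape this parser assumes:
#   Monthly Sales
#   <line after the title>
#   | Month | Sales |
#   |-------|-------|
#   | Jan   | 120   |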

def real_time_check(image_file):
    image = Image.open(image_file)
    ko_deplot_generated_txt = predict_model1(image)
    parts = ko_deplot_generated_txt.split("\n")
    del parts[-1]  # drop the last (possibly truncated) line of the generated text
    ko_deplot_generated_txt = "\n".join(parts)
    gpt_generated_txt = predict_model2(image_file)
    try:
        ko_deplot_generated_df, ko_deplot_generated_title = ko_deplot_convert_to_dataframe(ko_deplot_generated_txt)
        gpt_generated_df, gpt_generated_title = gpt_convert_to_dataframe(gpt_generated_txt)
        return gr.DataFrame(ko_deplot_generated_df, label=ko_deplot_generated_title), gr.DataFrame(gpt_generated_df, label=gpt_generated_title), None, None, 0
    except Exception:
        # Either output could not be parsed into a DataFrame: hand back the raw text
        return None, None, ko_deplot_generated_txt, gpt_generated_txt, 1

# Flag recording whether parsing failed (1 = the generated text could not be
# converted to a pandas DataFrame).
flag = 0

def inference(image_uploader, mode_selector):
    # File-upload mode runs on the uploaded file; folder mode runs on the image
    # currently being displayed.
    if mode_selector == "File upload":
        source = image_uploader
    else:
        source = image_files[current_image_index]
    ko_deplot_generated_df, gpt_generated_df, ko_deplot_generated_txt, gpt_generated_txt, flag = real_time_check(source)
    if flag == 1:
        # Parsing failed: show the raw generated text instead of the DataFrames
        return gr.update(visible=False), gr.update(visible=False), gr.Text(ko_deplot_generated_txt, visible=True), gr.Text(gpt_generated_txt, visible=True)
    else:
        return ko_deplot_generated_df, gpt_generated_df, gr.update(visible=False), gr.update(visible=False)

def toggle_model(selected_models, flag):
    # Visibility list for the six output components, in the order they are wired below:
    # [ko_deplot_generated_df, gpt_generated_df, ko_deplot_generated_txt, gpt_generated_txt, md1, md2]
    visibility = [False] * 6
    # Update visibility based on the selected models
    if "VAIV_DePlot" in selected_models:
        visibility[4] = True
        if flag:
            visibility[2] = True
        else:
            visibility[0] = True
    if "gpt-4o-mini" in selected_models:
        visibility[5] = True
        if flag:
            visibility[3] = True
        else:
            visibility[1] = True
    if "all" in selected_models:
        visibility[4] = True
        visibility[5] = True
        if flag:
            visibility[2] = True
            visibility[3] = True
        else:
            visibility[0] = True
            visibility[1] = True
    # Return gr.update for each component with the corresponding visibility status
    return tuple(gr.update(visible=v) for v in visibility)

def toggle_mode(mode):
    # Show the single-file widgets in file mode; show the folder widgets and nav buttons otherwise
    if mode == "File upload":
        return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
    else:
        return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)

def display_image(image_file):
    image = Image.open(image_file)
    return image, os.path.basename(image_file)

# State for stepping through the images of an uploaded folder
image_files = []
current_image_index = 0
image_files_cnt = 0

def display_folder_images(image_file_path_list):
    global image_files, current_image_index, image_files_cnt
    image_files = image_file_path_list
    image_files_cnt = len(image_files)
    current_image_index = 0
    if image_files:
        return Image.open(image_files[current_image_index]), os.path.basename(image_files[current_image_index]), gr.update(interactive=False), gr.update(interactive=True)
    # Match the four outputs this handler is wired to
    return None, "No images found", gr.update(interactive=False), gr.update(interactive=False)


def next_image():
    global current_image_index
    if image_files:
        current_image_index = current_image_index + 1
        prev_disabled = current_image_index == 0
        next_disabled = current_image_index == (len(image_files) - 1)
        return Image.open(image_files[current_image_index]), os.path.basename(image_files[current_image_index]), gr.update(interactive=not prev_disabled), gr.update(interactive=not next_disabled)
    return None, "No images found", gr.update(interactive=False), gr.update(interactive=False)

def prev_image():
    global current_image_index
    if image_files:
        current_image_index = current_image_index - 1
        prev_disabled = current_image_index == 0
        next_disabled = current_image_index == (len(image_files) - 1)
        return Image.open(image_files[current_image_index]), os.path.basename(image_files[current_image_index]), gr.update(interactive=not prev_disabled), gr.update(interactive=not next_disabled)
    return None, "No images found", gr.update(interactive=False), gr.update(interactive=False)
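# The Previous/Next buttons are disabled at the list boundaries by the
# gr.update(interactive=...) returns above, so the unclamped index arithmetic
# in next_image/prev_image stays within range.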

css = """
.dataframe-class {
    overflow-y: auto !important;  /* make the table scrollable */
    height: 250px;
}
"""

with gr.Blocks(css=css) as iface:
    with gr.Row():
        gr.Markdown("<h1 style='text-align: center;'>SKKU-VAIV Automatic chart understanding evaluation tool</h1>")
    gr.Markdown("<hr style='border: 1px solid #ddd;' />")
    with gr.Row():
        with gr.Column():
            mode_selector = gr.Radio(["File upload", "Folder upload"], label="Upload Mode", value="File upload")
            image_uploader = gr.File(file_count="single", file_types=["image"], visible=True)
            folder_uploader = gr.File(file_count="directory", file_types=["image"], visible=False, height=50)
            model_type = gr.Dropdown(["VAIV_DePlot", "gpt-4o-mini", "all"], value="VAIV_DePlot", label="model", multiselect=True)
            image_displayer = gr.Image(visible=True)
            image_name = gr.Text("", visible=True)
            with gr.Row():
                prev_button = gr.Button("Previous", visible=False, interactive=False)
                next_button = gr.Button("Next", visible=False, interactive=False)
            inference_button = gr.Button("Inference")
        with gr.Column():
            md1 = gr.Markdown("# VAIV_DePlot Inference Result")
            ko_deplot_generated_df = gr.DataFrame(visible=True, elem_classes="dataframe-class")
            ko_deplot_generated_txt = gr.Text(visible=False)
        with gr.Column():
            md2 = gr.Markdown("# gpt-4o-mini Inference Result", visible=False)
            gpt_generated_df = gr.DataFrame(visible=False, elem_classes="dataframe-class")
            gpt_generated_txt = gr.Text(visible=False)
    #label_df = gr.DataFrame(visible=False, label="Ground Truth Table", elem_classes="dataframe-class", scale=1)

    model_type.change(
        toggle_model,
        inputs=[model_type, gr.State(flag)],
        outputs=[ko_deplot_generated_df, gpt_generated_df, ko_deplot_generated_txt, gpt_generated_txt, md1, md2]
    )

    mode_selector.change(
        toggle_mode,
        inputs=[mode_selector],
        outputs=[image_uploader, folder_uploader, prev_button, next_button]
    )

    image_uploader.upload(display_image, inputs=[image_uploader], outputs=[image_displayer, image_name])
    folder_uploader.upload(display_folder_images, inputs=[folder_uploader], outputs=[image_displayer, image_name, prev_button, next_button])
    prev_button.click(prev_image, outputs=[image_displayer, image_name, prev_button, next_button])
    next_button.click(next_image, outputs=[image_displayer, image_name, prev_button, next_button])
    inference_button.click(inference, inputs=[image_uploader, mode_selector], outputs=[ko_deplot_generated_df, gpt_generated_df, ko_deplot_generated_txt, gpt_generated_txt])
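    # Note: gr.State(flag) snapshots flag at UI-build time (i.e. 0), so toggle_model
    # always receives 0 here; parsing failures inside inference() do not update it.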

if __name__ == "__main__":
    iface.launch(share=True)
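# Deployment note: on Hugging Face Spaces the app is already publicly hosted and
# Gradio ignores share=True there (it logs a warning); a plain iface.launch()
# would suffice in that environment.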