Sangjun2 committed
Commit 67f1204 Β· verified Β· 1 Parent(s): 27f2ca7

Create app.py

Files changed (1)
  1. app.py +299 -0
app.py ADDED
@@ -0,0 +1,299 @@
import gradio as gr
from transformers import AutoProcessor, Pix2StructForConditionalGeneration, T5Tokenizer, T5ForConditionalGeneration, Pix2StructProcessor, BartConfig, ViTConfig, VisionEncoderDecoderConfig, DonutProcessor, VisionEncoderDecoderModel, AutoTokenizer, AutoModel
from PIL import Image
import torch
import warnings
import re
import json
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
import argparse
from scipy import optimize
from typing import Optional
import dataclasses
import editdistance
import itertools
import sys
import time
import logging
import subprocess
import spaces
import openai
import base64
from io import StringIO

# Run `git lfs pull` to fetch the model weights
result = subprocess.run(['git', 'lfs', 'pull'], capture_output=True, text=True)

# Report the command result (optional)
if result.returncode == 0:
    print("LFS files downloaded successfully.")
else:
    print(f"Error occurred: {result.stderr}")

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger()

warnings.filterwarnings('ignore')
MAX_PATCHES = 512
# Load the models and processor
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Paths to the models
ko_deplot_model_path = './deplot_model_ver_24.11.21_korean_only(exclude NUUA)_epoch1.bin'

# Load the first model: ko-deplot
def load_model1():
    processor1 = Pix2StructProcessor.from_pretrained('nuua/ko-deplot')
    model1 = Pix2StructForConditionalGeneration.from_pretrained('nuua/ko-deplot')
    model1.load_state_dict(torch.load(ko_deplot_model_path, map_location="cpu"))
    model1.to(device)  # fall back to CPU when no GPU is available
    return processor1, model1

processor1, model1 = load_model1()

# Function to format output
def format_output(prediction):
    return prediction.replace('<0x0A>', '\n')

# First model prediction: ko-deplot
def predict_model1(image):
    images = [image]
    inputs = processor1(images=images, text="What is the title of the chart", return_tensors="pt", padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}  # move tensors to the model's device

    model1.eval()
    with torch.no_grad():
        predictions = model1.generate(**inputs, max_new_tokens=4096)
    outputs = [processor1.decode(pred, skip_special_tokens=True) for pred in predictions]

    formatted_output = format_output(outputs[0])
    return formatted_output

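# Minimal usage sketch (illustration only, not called by the app; the file
# name below is hypothetical):
#
#   img = Image.open("sample_chart.png").convert("RGB")
#   print(predict_model1(img))
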
# Read the OpenAI API key from the environment; a secret key must not be
# committed to the repository in plain text.
openai.api_key = os.environ.get("OPENAI_API_KEY")

# Function to encode the image as base64
def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")

# Second model prediction: gpt-4o-mini
def predict_model2(image):
    # Encode the uploaded image to base64
    image_data = encode_image(image)

    # Prepare the request content (legacy openai<1.0 SDK interface)
    response = openai.ChatCompletion.create(
        model="gpt-4o-mini",
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "please extract chart title and chart data manually and present them as a table. you should only provide title and table without adding any additional comments such as **Chart Title:** ."
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{image_data}"
                        }
                    }
                ]
            }
        ]
    )

    # Return the table data from the response
    return response.choices[0]["message"]["content"]

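# Note: the call above uses the legacy (openai<1.0) ChatCompletion surface.
# A hedged sketch of the same request against the 1.x client, in case the
# pinned SDK is ever upgraded (not exercised by this app):
#
#   from openai import OpenAI
#   client = OpenAI()  # reads OPENAI_API_KEY from the environment
#   resp = client.chat.completions.create(model="gpt-4o-mini", messages=[...])
#   text = resp.choices[0].message.content
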
# Convert the text generated by ko-deplot into a pandas DataFrame
def ko_deplot_convert_to_dataframe(label_table_str):
    lines = label_table_str.strip().split("\n")
    data = []
    title = lines[0].split(" | ")[1]

    if len(lines[1].split("|")) == len(lines[2].split("|")):
        headers = lines[1].split(" | ")
        for line in lines[2:]:
            data.append(line.split(" | "))
        df = pd.DataFrame(data, columns=headers)
        return df, title
    else:
        # Data rows carry one more cell than the header row: prepend an
        # empty cell so the legend row lines up as column names
        legend_row = lines[1].split("|")
        legend_row.insert(0, " ")
        for line in lines[2:]:
            data.append(line.split(" | "))
        df = pd.DataFrame(data, columns=legend_row)
        return df, title

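# Hedged parsing sketch (defined for illustration only; never called by the
# app). It assumes the ko-deplot output shape handled above: a "제λͺ© | <title>"
# line, a header line, then data rows, all delimited by " | ".
def _ko_deplot_parse_example():
    example = "제λͺ© | Monthly sales\nMonth | Sales\nJan | 120\nFeb | 150"
    df, title = ko_deplot_convert_to_dataframe(example)
    return df, title  # title == "Monthly sales"; df columns == ["Month", "Sales"]
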
# Convert the markdown table generated by gpt into a pandas DataFrame
def gpt_convert_to_dataframe(table_text):
    try:
        # Split the text into lines: title, blank line, then a markdown table
        lines = table_text.strip().split("\n")
        title = lines[0]
        # Remove the blank line and the markdown separator row (the second
        # pop targets index 2 because indices shift after the first pop)
        lines.pop(1)
        lines.pop(2)
        # Process the remaining lines to create the DataFrame
        data = [line.split("|")[1:-1] for line in lines[1:]]  # split by | and drop the empty first/last items
        dataframe = pd.DataFrame(data[1:], columns=[col.strip() for col in data[0]])  # use the first row as headers
        return dataframe, title
    except Exception as e:
        return f"Error converting table to DataFrame: {e}"

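# Hedged parsing sketch (illustration only; never called by the app). It
# assumes the gpt output shape handled above: a title line, a blank line,
# then a markdown table with a separator row.
def _gpt_parse_example():
    example = (
        "Monthly sales\n"
        "\n"
        "| Month | Sales |\n"
        "|-------|-------|\n"
        "| Jan | 120 |\n"
        "| Feb | 150 |"
    )
    df, title = gpt_convert_to_dataframe(example)
    return df, title  # title == "Monthly sales"; df columns == ["Month", "Sales"]
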
def real_time_check(image_file):
    image = Image.open(image_file)
    ko_deplot_generated_txt = predict_model1(image)
    # Drop the last line of the generated text before parsing
    parts = ko_deplot_generated_txt.split("\n")
    del parts[-1]
    ko_deplot_generated_txt = "\n".join(parts)
    gpt_generated_txt = predict_model2(image_file)
    try:
        ko_deplot_generated_df, ko_deplot_generated_title = ko_deplot_convert_to_dataframe(ko_deplot_generated_txt)
        gpt_generated_df, gpt_generated_title = gpt_convert_to_dataframe(gpt_generated_txt)
        return gr.DataFrame(ko_deplot_generated_df, label=ko_deplot_generated_title), gr.DataFrame(gpt_generated_df, label=gpt_generated_title), None, None, 0
    except Exception as e:
        return None, None, ko_deplot_generated_txt, gpt_generated_txt, 1

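# real_time_check returns a 5-tuple:
#   (deplot DataFrame, gpt DataFrame, deplot raw text, gpt raw text, flag)
# where flag == 1 signals that the generated text could not be parsed into a
# DataFrame and the raw text should be shown instead.
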
# flag == 1 means an exception occurred (the generated text could not be
# converted to a pandas DataFrame).
flag = 0

def inference(image_uploader, mode_selector):
    if mode_selector == "파일 μ—…λ‘œλ“œ":
        source = image_uploader
    else:
        source = image_files[current_image_index]
    ko_deplot_generated_df, gpt_generated_df, ko_deplot_generated_txt, gpt_generated_txt, flag = real_time_check(source)
    if flag == 1:
        # Parsing failed: hide the DataFrames and show the raw generated text
        return gr.update(visible=False), gr.update(visible=False), gr.Text(ko_deplot_generated_txt, visible=True), gr.Text(gpt_generated_txt, visible=True)
    else:
        return ko_deplot_generated_df, gpt_generated_df, gr.update(visible=False), gr.update(visible=False)

def toggle_model(selected_models, flag):
    # Create a visibility list initialized to False for all components, in
    # the order wired to model_type.change below:
    # [ko_deplot_generated_df, gpt_generated_df, ko_deplot_generated_txt,
    #  gpt_generated_txt, md1, md2]
    visibility = [False] * 6
    # Update visibility based on the selected models
    if "VAIV_DePlot" in selected_models:
        visibility[4] = True
        if flag:
            visibility[2] = True
        else:
            visibility[0] = True
    if "gpt-4o-mini" in selected_models:
        visibility[5] = True
        if flag:
            visibility[3] = True
        else:
            visibility[1] = True
    if "all" in selected_models:
        visibility[4] = True
        visibility[5] = True
        if flag:
            visibility[2] = True
            visibility[3] = True
        else:
            visibility[0] = True
            visibility[1] = True
    # Return gr.update for each component with the corresponding visibility status
    return tuple(gr.update(visible=v) for v in visibility)

def toggle_mode(mode):
    if mode == "파일 μ—…λ‘œλ“œ":
        return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
    else:
        return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)

def display_image(image_file):
    image = Image.open(image_file)
    return image, os.path.basename(image_file)

# State for stepping through the images in an uploaded folder sequentially
image_files = []
current_image_index = 0
image_files_cnt = 0

def display_folder_images(image_file_path_list):
    global image_files, current_image_index, image_files_cnt
    image_files = image_file_path_list
    image_files_cnt = len(image_files)
    current_image_index = 0
    if image_files:
        return Image.open(image_files[current_image_index]), os.path.basename(image_files[current_image_index]), gr.update(interactive=False), gr.update(interactive=True)
    # Pad the failure case to match the four outputs the event handler expects
    return None, "No images found", gr.update(interactive=False), gr.update(interactive=False)


def next_image():
    global current_image_index
    if image_files:
        current_image_index = current_image_index + 1
        prev_disabled = current_image_index == 0
        next_disabled = current_image_index == (len(image_files) - 1)
        return Image.open(image_files[current_image_index]), os.path.basename(image_files[current_image_index]), gr.update(interactive=not prev_disabled), gr.update(interactive=not next_disabled)
    return None, "No images found", gr.update(interactive=False), gr.update(interactive=False)

def prev_image():
    global current_image_index
    if image_files:
        current_image_index = current_image_index - 1
        prev_disabled = current_image_index == 0
        next_disabled = current_image_index == (len(image_files) - 1)
        return Image.open(image_files[current_image_index]), os.path.basename(image_files[current_image_index]), gr.update(interactive=not prev_disabled), gr.update(interactive=not next_disabled)
    return None, "No images found", gr.update(interactive=False), gr.update(interactive=False)

css = """
.dataframe-class {
    overflow-y: auto !important;  /* make the table scrollable */
    height: 250px;
}
"""

with gr.Blocks(css=css) as iface:
    with gr.Row():
        gr.Markdown("<h1 style='text-align: center;'>SKKU-VAIV Automatic chart understanding evaluation tool</h1>")
    gr.Markdown("<hr style='border: 1px solid #ddd;' />")
    with gr.Row():
        with gr.Column():
            mode_selector = gr.Radio(["파일 μ—…λ‘œλ“œ", "폴더 μ—…λ‘œλ“œ"], label="Upload Mode", value="파일 μ—…λ‘œλ“œ")
            image_uploader = gr.File(file_count="single", file_types=["image"], visible=True)
            folder_uploader = gr.File(file_count="directory", file_types=["image"], visible=False, height=50)
            model_type = gr.Dropdown(["VAIV_DePlot", "gpt-4o-mini", "all"], value="VAIV_DePlot", label="model", multiselect=True)
            image_displayer = gr.Image(visible=True)
            image_name = gr.Text("", visible=True)
            with gr.Row():
                prev_button = gr.Button("이전", visible=False, interactive=False)
                next_button = gr.Button("λ‹€μŒ", visible=False, interactive=False)
            inference_button = gr.Button("μΆ”λ‘ ")
        with gr.Column():
            md1 = gr.Markdown("# VAIV_DePlot Inference Result")
            ko_deplot_generated_df = gr.DataFrame(visible=True, elem_classes="dataframe-class")
            ko_deplot_generated_txt = gr.Text(visible=False)
        with gr.Column():
            md2 = gr.Markdown("# gpt-4o-mini Inference Result", visible=False)
            gpt_generated_df = gr.DataFrame(visible=False, elem_classes="dataframe-class")
            gpt_generated_txt = gr.Text(visible=False)
            # label_df = gr.DataFrame(visible=False, label="Ground Truth Table", elem_classes="dataframe-class", scale=1)

    model_type.change(
        toggle_model,
        inputs=[model_type, gr.State(flag)],
        outputs=[ko_deplot_generated_df, gpt_generated_df, ko_deplot_generated_txt, gpt_generated_txt, md1, md2]
    )

    mode_selector.change(
        toggle_mode,
        inputs=[mode_selector],
        outputs=[image_uploader, folder_uploader, prev_button, next_button]
    )

    image_uploader.upload(display_image, inputs=[image_uploader], outputs=[image_displayer, image_name])
    folder_uploader.upload(display_folder_images, inputs=[folder_uploader], outputs=[image_displayer, image_name, prev_button, next_button])
    prev_button.click(prev_image, outputs=[image_displayer, image_name, prev_button, next_button])
    next_button.click(next_image, outputs=[image_displayer, image_name, prev_button, next_button])
    inference_button.click(inference, inputs=[image_uploader, mode_selector], outputs=[ko_deplot_generated_df, gpt_generated_df, ko_deplot_generated_txt, gpt_generated_txt])

if __name__ == "__main__":
    iface.launch(share=True)