File size: 13,744 Bytes
67f1204
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a06442d
67f1204
7f63b36
fbe889f
67f1204
7f63b36
fbe889f
 
 
 
67f1204
7f63b36
 
67f1204
7f63b36
 
67f1204
 
 
 
ca13eac
 
 
 
 
 
67f1204
 
 
 
 
ca13eac
67f1204
 
 
 
 
 
 
 
 
 
364867a
67f1204
 
3661fce
67f1204
 
 
 
 
 
 
 
 
 
 
1065b75
67f1204
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac6cb27
67f1204
eaa60bb
67f1204
 
 
 
 
 
eaa60bb
67f1204
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a751e10
67f1204
a751e10
67f1204
 
 
 
a751e10
67f1204
 
 
a751e10
67f1204
a751e10
67f1204
 
 
a751e10
67f1204
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb42ef0
67f1204
fb42ef0
67f1204
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb42ef0
67f1204
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb42ef0
 
 
67f1204
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb42ef0
5d570d0
67f1204
 
 
 
 
 
 
 
 
 
 
7c6d735
83663f0
67f1204
 
 
 
 
 
 
 
 
 
 
 
fb42ef0
67f1204
 
 
fb42ef0
67f1204
 
fb42ef0
67f1204
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
import gradio as gr
from transformers import AutoProcessor, Pix2StructForConditionalGeneration, T5Tokenizer, T5ForConditionalGeneration, Pix2StructProcessor, BartConfig,ViTConfig,VisionEncoderDecoderConfig, DonutProcessor, VisionEncoderDecoderModel, AutoTokenizer, AutoModel
from PIL import Image  
import torch
import warnings
import re
import json
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
import argparse
from scipy import optimize
from typing import Optional
import dataclasses
import editdistance
import itertools
import sys
import time
import logging
import subprocess
import spaces
import openai
import base64
from io import StringIO
from huggingface_hub import hf_hub_download

# Run the Git LFS pull command
#result = subprocess.run(['git', 'lfs', 'pull'], capture_output=True, text=True)

# Print the result of the command (optional)
#if result.returncode == 0:
#    print("LFS files were downloaded successfully.")
#else:
#    print(f"An error occurred: {result.stderr}")

#logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
#logger = logging.getLogger()

#warnings.filterwarnings('ignore')
#MAX_PATCHES = 512
# Load the models and processor
# Inference device: CUDA when torch reports it as available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Paths to the models
#ko_deplot_model_path = './deplot_model_ver_24.11.21_korean_only(exclude NUUA)_epoch1.bin'

# Fetch the fine-tuned checkpoint from the Hugging Face Hub; the returned path
# is later handed to torch.load() in load_model1().
file_path = hf_hub_download(
    repo_id="Sangjun2/skku-Deplot", 
    filename="deplot_model_ver_24.11.21_korean_only(exclude NUUA)_epoch3.bin"
)

# Load first model ko-deplot
def load_model1():
    """Build the ko-deplot processor and model with the fine-tuned weights.

    The base ``nuua/ko-deplot`` checkpoint is loaded first; its weights are
    then overwritten with the state dict stored at the module-level
    ``file_path``.

    Returns:
        tuple: ``(processor1, model1)`` -- the ``Pix2StructProcessor`` and the
        ``Pix2StructForConditionalGeneration`` model placed on ``device``.
    """
    processor1 = Pix2StructProcessor.from_pretrained('nuua/ko-deplot')
    model1 = Pix2StructForConditionalGeneration.from_pretrained('nuua/ko-deplot')
    # Deserialize onto CPU first so reading the checkpoint never needs a GPU.
    model1.load_state_dict(torch.load(file_path, map_location="cpu"))
    # Use the module-level `device` (the same one predict_model1 moves its
    # inputs to) rather than a hard-coded "cuda", which raises on CPU-only
    # hosts and could leave the model and its inputs on different devices.
    model1.to(device)
    return processor1, model1

# Load the processor and model once at import time; predict_model1() reads
# these module-level names on every call.
processor1, model1 = load_model1()

# Function to format output
def format_output(prediction):
    """Turn every literal ``<0x0A>`` marker in *prediction* into a line break."""
    newline_marker = '<0x0A>'
    pieces = prediction.split(newline_marker)
    return '\n'.join(pieces)

# First model prediction: ko-deplot
@spaces.GPU(enable_queue=True,duration=100)
def predict_model1(image):
    """Run ko-deplot on one chart image and return the generated table text.

    Args:
        image: The chart image handed to ``processor1``.

    Returns:
        str: The first decoded prediction, with ``<0x0A>`` markers replaced by
        newlines via ``format_output``.
    """
    prompt = "Generate underlying data table of the figure below:"
    encoded = processor1(images=[image], text=prompt, return_tensors="pt", padding=True)

    # Every processor output has to live on the same device as the model.
    model_inputs = {}
    for name, tensor in encoded.items():
        model_inputs[name] = tensor.to(device)

    model1.eval()
    with torch.no_grad():
        generated = model1.generate(**model_inputs, max_new_tokens=4096)

    decoded = [processor1.decode(sequence, skip_special_tokens=True) for sequence in generated]
    return format_output(decoded[0])

# Configure the OpenAI client from the `gpt_api_key` environment variable
# (os.getenv yields None when the variable is not set).
openai.api_key = os.getenv('gpt_api_key')

# Function to encode the image as base64
def encode_image(image_path):
    """Read the file at *image_path* and return its bytes as base64 text."""
    with open(image_path, "rb") as image_file:
        raw_bytes = image_file.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode("utf-8")

# Second model prediction: gpt-4o-mini
def predict_model2(image):
    """Ask gpt-4o-mini to transcribe a chart image into a title plus table.

    Args:
        image: Path of the image file to send; it is read and base64-encoded
            by ``encode_image``.

    Returns:
        The message content of the first choice in the chat completion
        response.
    """
    import mimetypes

    # Encode the uploaded image to base64
    image_data = encode_image(image)

    # Label the data URL with the file's actual MIME type. Fall back to JPEG
    # (the value that used to be hard-coded) when the type cannot be guessed
    # from the file name or is not an image type.
    mime_type, _ = mimetypes.guess_type(image)
    if not mime_type or not mime_type.startswith("image/"):
        mime_type = "image/jpeg"

    # NOTE(review): openai.ChatCompletion looks like the pre-1.0 SDK interface;
    # confirm the pinned openai version still provides it.
    response = openai.ChatCompletion.create(
        model="gpt-4o-mini",
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "please extract chart title and chart data manually and present them as a table. you should only provide title and table without adding any additional comments such as **Chart Title:** ."
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:{mime_type};base64,{image_data}"
                        }
                    }
                ]
            }
        ]
    )

    # Return the table data from the response
    return response.choices[0]["message"]["content"]

def ko_deplot_convert_to_dataframe(label_table_str):
    """Convert text generated by ko-deplot into a pandas DataFrame.

    The first line supplies the title (the part after the first ``" | "``),
    the second line the column names, and the remaining lines the data rows;
    cells are split on ``"|"`` and are not stripped.

    Args:
        label_table_str: Newline-separated table text.

    Returns:
        tuple: ``(df, title)``.
    """
    rows = label_table_str.strip().split("\n")
    title = rows[0].split(" | ")[1]

    column_names = rows[1].split("|")
    # When the header row and the first data row differ in width, the header
    # is treated as a legend row and a blank name is prepended for the first
    # column.
    if len(column_names) != len(rows[2].split("|")):
        column_names.insert(0, " ")

    records = [row.split("|") for row in rows[2:]]
    return pd.DataFrame(records, columns=column_names), title

def gpt_convert_to_dataframe(table_text):
    """Convert text generated by gpt into a pandas DataFrame.

    The first line is taken as the title. The line at index 1 is removed,
    then the line at index 2 of what remains (index 3 of the original text);
    the remaining lines after the title are pipe-delimited rows, the first of
    which supplies the (stripped) column names.

    Args:
        table_text: The raw text to parse.

    Returns:
        ``(dataframe, title)`` on success. On any exception an error-message
        string is returned instead of a tuple.
    """
    try:
        rows = table_text.strip().split("\n")
        title = rows[0]
        rows.pop(1)
        rows.pop(2)

        # Split on "|" and drop the empty first/last fragments of each row.
        parsed_rows = []
        for row in rows[1:]:
            parsed_rows.append(row.split("|")[1:-1])

        column_names = list(map(str.strip, parsed_rows[0]))
        return pd.DataFrame(parsed_rows[1:], columns=column_names), title
    except Exception as err:
        return f"Error converting table to DataFrame: {err}"

def real_time_check(image_file):
    """Run both models on one image and try to tabulate their outputs.

    Args:
        image_file: Path of the image; opened with PIL for ko-deplot and
            passed as-is to ``predict_model2``.

    Returns:
        A 7-tuple ``(ko_deplot_df, ko_deplot_title, gpt_df, gpt_title,
        ko_deplot_txt, gpt_txt, flag)``. When both conversions succeed the
        two text slots are ``None`` and ``flag`` is 0; when either conversion
        raises, the four table slots are ``None``, the raw texts are returned
        and ``flag`` is 1.
    """
    deplot_text = predict_model1(Image.open(image_file))
    # The last line of the ko-deplot output is discarded before parsing.
    deplot_text = "\n".join(deplot_text.split("\n")[:-1])
    gpt_text = predict_model2(image_file)

    try:
        deplot_df, deplot_title = ko_deplot_convert_to_dataframe(deplot_text)
        gpt_df, gpt_title = gpt_convert_to_dataframe(gpt_text)
    except Exception:
        return None, None, None, None, deplot_text, gpt_text, 1
    return deplot_df, deplot_title, gpt_df, gpt_title, None, None, 0
    
# Conversion-failure flag: 1 means the generated text could not be converted
# to a pandas DataFrame.
# NOTE(review): inference() binds its own local `flag`, so this module-level
# value is never reassigned in the visible code -- confirm that is intended.
flag = 0
def inference(image_uploader,mode_selector):
    """Run both models on the selected image and build the four result widgets.

    In single-file mode the uploaded file is used; otherwise the image at
    ``image_files[current_image_index]`` is used.

    Returns:
        A 4-tuple ``(ko-deplot table, gpt table, ko-deplot text, gpt text)``.
        On successful conversion the tables are shown and the text boxes
        hidden; when conversion failed the raw texts are shown and the tables
        hidden.
    """
    if mode_selector == "파일 μ—…λ‘œλ“œ":
        target_image = image_uploader
    else:
        target_image = image_files[current_image_index]

    (deplot_df, deplot_title, gpt_df, gpt_title,
     deplot_txt, gpt_txt, failed) = real_time_check(target_image)

    if failed == 1:
        return (
            gr.update(visible=False),
            gr.update(visible=False),
            gr.Text(deplot_txt, visible=True),
            gr.Text(gpt_txt, visible=True),
        )
    return (
        gr.DataFrame(deplot_df, label=deplot_title, visible=True),
        gr.DataFrame(gpt_df, label=gpt_title, visible=True),
        gr.update(visible=False),
        gr.update(visible=False),
    )
        
def toggle_model(selected_models,flag):
    """Compute visibility updates for the six result components.

    The returned tuple follows the output order wired to ``model_type.change``:
    ko-deplot table, gpt table, ko-deplot text, gpt text, ko-deplot heading,
    gpt heading. For each selected model its heading is shown together with
    either its text box (``flag`` truthy) or its table (``flag`` falsy).

    Args:
        selected_models: Collection of chosen model names; ``"all"`` selects
            both models.
        flag: Truthy when the text boxes should be shown instead of tables.

    Returns:
        tuple: Six ``gr.update(visible=...)`` values.
    """
    show_deplot = "VAIV_DePlot" in selected_models
    show_gpt = "gpt-4o-mini" in selected_models
    if "all" in selected_models:
        show_deplot = show_gpt = True

    states = (
        show_deplot and not flag,   # ko-deplot table
        show_gpt and not flag,      # gpt table
        show_deplot and bool(flag),  # ko-deplot text
        show_gpt and bool(flag),     # gpt text
        show_deplot,                # ko-deplot heading
        show_gpt,                   # gpt heading
    )
    return tuple(gr.update(visible=state) for state in states)

def toggle_mode(mode):
    """Show or hide the upload/navigation widgets for the chosen upload mode.

    Returns five visibility updates in the output order wired to
    ``mode_selector.change``: file uploader, folder uploader, prev button,
    next button, folder re-upload button. The folder uploader is hidden in
    both modes.
    """
    file_mode = mode == "파일 μ—…λ‘œλ“œ"
    states = (file_mode, False, not file_mode, not file_mode, not file_mode)
    return tuple(gr.update(visible=state) for state in states)

def display_image(image_file):
    """Open *image_file* for preview and return it with its base file name."""
    preview = Image.open(image_file)
    file_name = os.path.basename(image_file)
    return preview, file_name

# Module-level state for browsing an uploaded folder one image at a time.
image_files = []          # paths set by display_folder_images()
current_image_index = 0   # index into image_files of the image being shown
image_files_cnt=0         # len(image_files), set by display_folder_images()

def display_folder_images(image_file_path_list):
    """Start browsing an uploaded folder at its first image.

    Stores the uploaded paths in the module-level ``image_files`` list,
    records their count in ``image_files_cnt`` and resets
    ``current_image_index`` to 0.

    Args:
        image_file_path_list: Paths of the uploaded images; may be empty or
            ``None``.

    Returns:
        A 6-tuple matching the outputs bound to ``folder_uploader.upload``:
        image, file-name text, prev-button update, next-button update,
        re-upload-button update, folder-uploader update.
    """
    global image_files, current_image_index,image_files_cnt
    # Treat None like an empty upload instead of failing on len(None).
    image_files = list(image_file_path_list or [])
    image_files_cnt=len(image_files)
    current_image_index = 0
    if not image_files:
        # Return one value per bound output; returning only two values makes
        # Gradio reject the result. Visibility is left unchanged here.
        return (
            None,
            "No images found",
            gr.update(interactive=False),
            gr.update(interactive=False),
            gr.update(),
            gr.update(),
        )
    first_image = image_files[0]
    return (
        Image.open(first_image),
        os.path.basename(first_image),
        gr.update(interactive=False),
        # Only enable "next" when there is a second image to move to;
        # otherwise a click would index past the end of the list.
        gr.update(interactive=image_files_cnt > 1),
        gr.update(visible = True),
        gr.update(visible = False),
    )
    

def next_image():
    """Advance to the next image of the uploaded folder.

    Updates the module-level ``current_image_index``; the index is clamped to
    the valid range so a stray click on the last image cannot raise
    ``IndexError``.

    Returns:
        A 4-tuple matching the outputs bound to ``next_button.click``:
        image, file-name text, prev-button update, next-button update.
    """
    global current_image_index
    if not image_files:
        # One value per bound output (image, name, prev button, next button).
        return None, "No images found", gr.update(interactive=False), gr.update(interactive=False)
    last_index = len(image_files) - 1
    current_image_index = max(0, min(current_image_index + 1, last_index))
    current_path = image_files[current_image_index]
    return (
        Image.open(current_path),
        os.path.basename(current_path),
        gr.update(interactive=current_image_index != 0),
        gr.update(interactive=current_image_index != last_index),
    )

def prev_image():
    """Step back to the previous image of the uploaded folder.

    Updates the module-level ``current_image_index``; the index is clamped to
    the valid range so it can never become negative and silently wrap around
    to the end of the list.

    Returns:
        A 4-tuple matching the outputs bound to ``prev_button.click``:
        image, file-name text, prev-button update, next-button update.
    """
    global current_image_index
    if not image_files:
        # One value per bound output (image, name, prev button, next button).
        return None, "No images found", gr.update(interactive=False), gr.update(interactive=False)
    last_index = len(image_files) - 1
    current_image_index = min(last_index, max(current_image_index - 1, 0))
    current_path = image_files[current_image_index]
    return (
        Image.open(current_path),
        os.path.basename(current_path),
        gr.update(interactive=current_image_index != 0),
        gr.update(interactive=current_image_index != last_index),
    )

def folder_reupload():
    """Hide the re-upload button and reveal the folder uploader again."""
    hide_button = gr.update(visible=False)
    show_uploader = gr.update(visible=True)
    return hide_button, show_uploader

# Custom CSS passed to gr.Blocks: elements carrying the `dataframe-class`
# class get a 250px height with vertical scrolling.
css = """
.dataframe-class {
    overflow-y: auto !important; /* μŠ€ν¬λ‘€μ„ κ°€λŠ₯ν•˜κ²Œ */
    height: 250px
}
"""

# Build the UI: a left column with upload/navigation controls and two result
# columns (ko-deplot and gpt-4o-mini), then wire the event handlers.
with gr.Blocks(css=css) as iface:
    with gr.Row():
        gr.Markdown("<h1 style='text-align: center;'>SKKU-VAIV Automatic chart understanding evaluation tool</h1>")
    gr.Markdown("<hr style='border: 1px solid #ddd;' />")
    with gr.Row():
        # Left column: upload mode, uploaders, model selection, preview, navigation.
        with gr.Column():
            mode_selector = gr.Radio(["파일 μ—…λ‘œλ“œ", "폴더 μ—…λ‘œλ“œ"], label="Upload Mode", value="파일 μ—…λ‘œλ“œ")
            image_uploader = gr.File(file_count="single", file_types=["image"], visible=True)
            folder_uploader = gr.File(file_count="directory", file_types=["image"], visible=False, height=50)
            folder_reupload_button = gr.Button("폴더 μ—…λ‘œλ“œ", visible=False)
            model_type=gr.Dropdown(["VAIV_DePlot","gpt-4o-mini","all"],value="all",label="model",multiselect=True)
            image_displayer = gr.Image(visible=True)
            image_name = gr.Text("", visible=True)
            with gr.Row():
                prev_button = gr.Button("이전", visible=False, interactive=False)
                next_button = gr.Button("λ‹€μŒ", visible=False, interactive=False)
            inference_button = gr.Button("μΆ”λ‘ ")
        # Middle column: ko-deplot result, shown as a table or as raw text.
        with gr.Column():
            md1 = gr.Markdown("# VAIV_DePlot Inference Result")
            ko_deplot_generated_df = gr.DataFrame(visible=True, elem_classes="dataframe-class")
            ko_deplot_generated_txt = gr.Text(visible=False)
        # Right column: gpt-4o-mini result, shown as a table or as raw text.
        with gr.Column():    
            md2 = gr.Markdown("# gpt-4o-mini Inference Result", visible=True)
            gpt_generated_df = gr.DataFrame(visible=True, elem_classes="dataframe-class")
            gpt_generated_txt = gr.Text(visible=False)
            #label_df = gr.DataFrame(visible=False, label="Ground Truth Table", elem_classes="dataframe-class",scale=1)

    # Show/hide result components when the model selection changes. The output
    # order here is the index order toggle_model() fills in.
    # NOTE(review): gr.State(flag) is created from the module-level `flag` as it
    # is at build time (0); confirm toggle_model is meant to always see that value.
    model_type.change(
                        toggle_model,
                        inputs=[model_type, gr.State(flag)],
                        outputs=[ko_deplot_generated_df,gpt_generated_df,ko_deplot_generated_txt,gpt_generated_txt,md1,md2]
                        )

    # Switch between single-file and folder upload widgets.
    mode_selector.change(
        toggle_mode,
        inputs=[mode_selector],
        outputs=[image_uploader, folder_uploader, prev_button, next_button, folder_reupload_button]
    )

    # Preview, folder navigation and inference wiring. Each handler must return
    # one value per component listed in `outputs`.
    image_uploader.upload(display_image,inputs=[image_uploader],outputs=[image_displayer,image_name])
    folder_uploader.upload(display_folder_images, inputs=[folder_uploader], outputs=[image_displayer, image_name, prev_button, next_button, folder_reupload_button, folder_uploader])
    prev_button.click(prev_image, outputs=[image_displayer, image_name, prev_button, next_button])
    next_button.click(next_image, outputs=[image_displayer, image_name, prev_button, next_button])
    folder_reupload_button.click(folder_reupload, outputs =[folder_reupload_button, folder_uploader])
    inference_button.click(inference,inputs=[image_uploader,mode_selector],outputs=[ko_deplot_generated_df, gpt_generated_df, ko_deplot_generated_txt, gpt_generated_txt])

    # Launch only when executed as a script (this check sits inside the Blocks context).
    if __name__ == "__main__":
        iface.launch(share=True)