File size: 4,650 Bytes
9d3f1f1
 
 
 
 
 
0980890
 
eee8d68
9d3f1f1
 
 
 
 
 
 
 
 
0980890
eee8d68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d3f1f1
 
 
eee8d68
 
9d3f1f1
 
eee8d68
9d3f1f1
eee8d68
9d3f1f1
 
0980890
9d3f1f1
 
 
eee8d68
 
9d3f1f1
 
eee8d68
0980890
9d3f1f1
eee8d68
9d3f1f1
 
 
 
 
 
 
 
 
 
 
ae3e99f
9d3f1f1
eee8d68
 
9d3f1f1
 
 
 
 
 
ae3e99f
9d3f1f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eee8d68
9d3f1f1
 
 
eee8d68
9d3f1f1
 
 
 
 
eee8d68
 
9d3f1f1
 
 
eee8d68
9d3f1f1
 
 
 
 
 
eee8d68
 
 
 
9d3f1f1
 
 
eee8d68
9d3f1f1
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import os
import sys
import logging
from datetime import datetime
import torch
from transformers import AutoProcessor, Pix2StructForConditionalGeneration
from PIL import Image

# Configure logging: INFO and above goes to both stdout and app.log.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),
        # Explicit UTF-8: logged exception messages can contain non-ASCII text
        # (e.g. the Chinese FileNotFoundError message), which the platform's
        # default locale encoding may be unable to represent.
        logging.FileHandler('app.log', encoding='utf-8'),
    ],
)
logger = logging.getLogger(__name__)

def print_section(title, char='='):
    """Print *title* centred in a 50-column banner framed by rules of *char*.

    The output is a blank line, a rule, the centred title, a rule, and a
    trailing blank line.
    """
    width = 50
    rule = char * width
    for line in (f"\n{rule}", f"{title.center(width)}", f"{rule}\n"):
        print(line)

def print_table(data):
    """Print *data* as a pipe-separated text table.

    The first row is treated as the header and is followed by a dashed rule.
    Every cell is converted with ``str`` and left-justified to the widest
    value in its column.

    Rows may have different lengths (model-generated tables are often
    ragged): each column is sized from the rows that actually contain it,
    and every row prints only the cells it has.

    Args:
        data: Sequence of rows, each a sequence of cell values. When empty,
            ``"No data available"`` is printed instead.
    """
    if not data:
        print("No data available")
        return

    # Size columns over the widest row, not just the header, so short or
    # long rows never index past the end of a row or of col_widths.
    n_cols = max(len(row) for row in data)
    col_widths = [
        max(len(str(row[i])) for row in data if i < len(row))
        for i in range(n_cols)
    ]

    def format_row(row):
        return " | ".join(
            str(cell).ljust(col_widths[i]) for i, cell in enumerate(row)
        )

    print(format_row(data[0]))
    # Rule spans the full table width (equals the header width for
    # rectangular tables).
    rule_width = sum(col_widths) + len(" | ") * max(n_cols - 1, 0)
    print("-" * rule_width)

    for row in data[1:]:
        print(format_row(row))

class ChartAnalyzer:
    """Extract the underlying data table from a chart image with DePlot.

    Wraps a Pix2Struct model/processor pair (``google/deplot`` by default).
    It runs generation on a single image, parses the generated text into rows
    of cells, and writes those rows to a timestamped ``results_*.log`` file.
    """

    def __init__(self, model_name="google/deplot"):
        """Load the model and processor.

        Args:
            model_name: Model id or local path passed to ``from_pretrained``
                for both the model and the processor.

        Raises:
            Exception: Whatever ``from_pretrained`` raises; it is logged
                (with traceback) and re-raised unchanged.
        """
        try:
            print_section("初始化模型")
            print("正在加载模型和处理器...")
            self.model = Pix2StructForConditionalGeneration.from_pretrained(model_name)
            self.processor = AutoProcessor.from_pretrained(model_name)
            print("✓ 模型加载完成")
        except Exception as e:
            print("✗ 模型加载失败")
            logger.exception("Error initializing model: %s", e)
            raise

    def process_image(self, image_path, prompt=None):
        """Run the model on one image and return the extracted table.

        Args:
            image_path: Path to the chart image.
            prompt: Text prompt given to the processor. Defaults to
                ``"Generate underlying data table of the figure below:"``.

        Returns:
            A list of rows, one per non-empty line of the decoded output.
            Each row is a list of ``|``-separated cells with whitespace
            stripped. The rows are also written to
            ``results_<YYYYmmdd_HHMMSS>.log`` in the current directory.

        Raises:
            FileNotFoundError: If ``image_path`` does not exist.
            Exception: Any error from image loading, inference, or writing
                the results file; it is logged (with traceback) and re-raised.
        """
        try:
            print_section("图片处理", char='-')

            if not os.path.exists(image_path):
                raise FileNotFoundError(f"找不到图片文件: {image_path}")

            print(f"正在处理图片: {image_path}")

            if prompt is None:
                prompt = "Generate underlying data table of the figure below:"

            # Context manager releases the image file handle once the
            # processor has converted the pixels into tensors.
            with Image.open(image_path) as image:
                inputs = self.processor(
                    images=image,
                    text=prompt,
                    return_tensors="pt"
                )

            print("\n正在生成数据分析...")
            with torch.no_grad():
                predictions = self.model.generate(
                    **inputs,
                    max_new_tokens=512,
                    num_beams=4,
                    length_penalty=1.0
                )

            raw_output = self.processor.decode(predictions[0], skip_special_tokens=True)
            result_array = self._parse_table(raw_output)
            output_file = self._save_results(result_array)

            print(f"\n✓ 结果已保存至: {output_file}")
            return result_array

        except Exception as e:
            print("\n✗ 处理失败")
            logger.exception("Error processing image: %s", e)
            raise

    @staticmethod
    def _parse_table(raw_output):
        """Split decoded text into rows on "<0x0A>" and cells on "|"."""
        # The decoded text uses the literal string "<0x0A>" as the row
        # separator; blank rows are dropped.
        return [
            [cell.strip() for cell in line.split("|")]
            for line in raw_output.split("<0x0A>")
            if line.strip()
        ]

    @staticmethod
    def _save_results(rows):
        """Write rows as " | "-joined lines to a timestamped file; return its name."""
        # NOTE(review): the name has one-second resolution, so two runs in
        # the same second overwrite each other's results file.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_file = f'results_{timestamp}.log'
        with open(output_file, mode='w', encoding='utf-8') as file:
            file.writelines(" | ".join(row) + "\n" for row in rows)
        return output_file

def main(image_path=None):
    """Run the chart-extraction pipeline on one image and print the table.

    Args:
        image_path: Chart image to analyse. When omitted, the sample file
            ``05e57f1c9acff69f1eb6fa72d4805d0.jpg`` in the current directory
            is used.

    Raises:
        Exception: Any error from model loading or image processing; it is
            logged (with traceback) and re-raised.
    """
    if image_path is None:
        image_path = '05e57f1c9acff69f1eb6fa72d4805d0.jpg'

    try:
        print_section("图表数据提取系统", char='*')

        analyzer = ChartAnalyzer()
        results = analyzer.process_image(image_path)

        print_section("分析结果")
        print_table(results)

        print_section("处理完成", char='*')

    except Exception as e:
        logger.exception("Application error: %s", e)
        print("\n✗ 程序执行出错,请查看日志获取详细信息")
        raise


if __name__ == "__main__":
    # An optional first command-line argument overrides the default image.
    main(sys.argv[1] if len(sys.argv) > 1 else None)