File size: 3,554 Bytes
9d3f1f1
 
 
 
 
 
0980890
 
9d3f1f1
 
 
 
 
 
 
 
 
 
0980890
9d3f1f1
 
 
 
 
 
 
 
 
 
0980890
9d3f1f1
 
 
 
 
 
0980890
9d3f1f1
 
 
 
 
 
 
 
 
 
 
 
 
ae3e99f
9d3f1f1
 
 
 
 
 
 
 
 
ae3e99f
9d3f1f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
import sys
import logging
from datetime import datetime
import torch
from transformers import AutoProcessor, Pix2StructForConditionalGeneration
from PIL import Image

# Configure logging: records at INFO and above are sent to stdout and to app.log.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
_LOG_HANDLERS = [logging.StreamHandler(sys.stdout), logging.FileHandler('app.log')]
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO, handlers=_LOG_HANDLERS)
logger = logging.getLogger(__name__)

class ChartAnalyzer:
    """Extract the underlying data table of a chart image with the DePlot model.

    The model and processor for the ``google/deplot`` checkpoint are loaded once
    in the constructor and reused by every ``process_image`` call.
    """

    # Token that stands for a line break between rows in the decoded output.
    _ROW_SEPARATOR = "<0x0A>"
    # Character separating the cells inside one row.
    _CELL_SEPARATOR = "|"
    # Prompt used when the caller does not supply one.
    _DEFAULT_PROMPT = "Generate underlying data table of the figure below:"

    def __init__(self):
        """Load the pretrained model and processor.

        Raises:
            Exception: Whatever ``from_pretrained`` raises; the error is logged
                (with traceback) and re-raised unchanged.
        """
        try:
            logger.info("Initializing model and processor...")
            self.model = Pix2StructForConditionalGeneration.from_pretrained("google/deplot")
            self.processor = AutoProcessor.from_pretrained("google/deplot")
            logger.info("Model and processor initialized successfully")
        except Exception as e:
            # logger.exception keeps the same message but also records the traceback.
            logger.exception("Error initializing model: %s", e)
            raise

    def process_image(self, image_path, prompt=None):
        """Convert a chart image into a data table, save it to a file and return it.

        Args:
            image_path: Path of the image file to analyze.
            prompt: Text passed to the processor together with the image. When
                ``None``, the default table-extraction prompt is used.

        Returns:
            A list of rows; each row is a list of whitespace-stripped cell strings.

        Raises:
            FileNotFoundError: If ``image_path`` is not an existing regular file.
            Exception: Any other error raised while loading the image, generating
                the prediction or writing the results file; it is logged and
                re-raised unchanged.
        """
        try:
            # A directory passes a plain existence check and would only fail
            # later inside PIL with a less helpful error, so require a file.
            if not os.path.isfile(image_path):
                raise FileNotFoundError(f"Image file not found: {image_path}")

            if prompt is None:
                prompt = self._DEFAULT_PROMPT

            logger.info(f"Processing image: {image_path}")
            # Image.open keeps the underlying file open; the context manager
            # releases it as soon as the processor has turned it into tensors.
            with Image.open(image_path) as image:
                inputs = self.processor(
                    images=image,
                    text=prompt,
                    return_tensors="pt"
                )

            logger.info("Generating predictions...")
            with torch.no_grad():  # inference only: no gradient tracking needed
                predictions = self.model.generate(
                    **inputs,
                    max_new_tokens=512,
                    num_beams=4,
                    length_penalty=1.0
                )

            raw_output = self.processor.decode(predictions[0], skip_special_tokens=True)
            result_array = self._parse_table(raw_output)

            output_file = self._save_table(result_array)
            logger.info(f"Results saved to {output_file}")
            return result_array

        except Exception as e:
            # Same message as before, now with the traceback attached to the log.
            logger.exception("Error processing image: %s", e)
            raise

    @classmethod
    def _parse_table(cls, raw_output):
        """Split decoded model output into rows of stripped cells, skipping blank rows."""
        return [
            [cell.strip() for cell in row.split(cls._CELL_SEPARATOR)]
            for row in raw_output.split(cls._ROW_SEPARATOR)
            if row.strip()
        ]

    @staticmethod
    def _save_table(rows):
        """Write the rows to a timestamped ``results_*.log`` file and return its name."""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_file = f'results_{timestamp}.log'
        with open(output_file, mode='w', encoding='utf-8') as file:
            file.writelines(" | ".join(row) + "\n" for row in rows)
        return output_file

def main():
    """Analyze a hard-coded image file and print the extracted table rows.

    Any exception is logged as an application error and then re-raised.
    """
    # Path of the image to analyze (in a Space, use the uploaded image's path).
    image_path = '05e57f1c9acff69f1eb6fa72d4805d0.jpg'

    try:
        # Build the analyzer and run it on the image in one go.
        table = ChartAnalyzer().process_image(image_path)

        # Print one line per row, cells joined the same way as in the saved file.
        print("\nAnalysis Results:")
        for cells in table:
            print(" | ".join(cells))
    except Exception as err:
        logger.error(f"Application error: {str(err)}")
        raise

# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()