File size: 3,554 Bytes
9d3f1f1 0980890 9d3f1f1 0980890 9d3f1f1 0980890 9d3f1f1 0980890 9d3f1f1 ae3e99f 9d3f1f1 ae3e99f 9d3f1f1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
import os
import sys
import logging
from datetime import datetime
import torch
from transformers import AutoProcessor, Pix2StructForConditionalGeneration
from PIL import Image
# Configure logging: INFO level, with every record sent both to stdout
# and to the file 'app.log'.
_console_handler = logging.StreamHandler(sys.stdout)
_file_handler = logging.FileHandler('app.log')
logging.basicConfig(
    handlers=[_console_handler, _file_handler],
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger used by everything in this file.
logger = logging.getLogger(__name__)
class ChartAnalyzer:
    """Turn a chart image into its underlying data table.

    Loads a ``Pix2StructForConditionalGeneration`` model and the matching
    ``AutoProcessor`` once at construction time, then reuses them for every
    call to :meth:`process_image`.
    """

    # Checkpoint loaded when no explicit model name is given.
    DEFAULT_MODEL = "google/deplot"
    # Prompt sent to the model when the caller does not supply one.
    DEFAULT_PROMPT = "Generate underlying data table of the figure below:"
    # Literal marker on which the decoded text is split into table rows.
    ROW_SEPARATOR = "<0x0A>"
    # Character on which each row is split into cells.
    CELL_SEPARATOR = "|"

    def __init__(self, model_name=DEFAULT_MODEL):
        """Load the model and processor.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                model and the processor. Defaults to ``"google/deplot"``.

        Raises:
            Exception: Any loading failure is logged (with traceback) and
                re-raised unchanged.
        """
        try:
            logger.info("Initializing model and processor...")
            self.model = Pix2StructForConditionalGeneration.from_pretrained(model_name)
            self.processor = AutoProcessor.from_pretrained(model_name)
            logger.info("Model and processor initialized successfully")
        except Exception as e:
            logger.exception("Error initializing model: %s", e)
            raise

    def process_image(self, image_path, prompt=None):
        """Process an image and generate its data table.

        The parsed table is also written to ``results_<timestamp>.log``
        (relative path, so it lands in the current working directory).

        Args:
            image_path: Path of the chart image to analyze.
            prompt: Text prompt for the model. ``None`` selects
                ``DEFAULT_PROMPT``.

        Returns:
            A list of rows; each row is a list of whitespace-stripped cell
            strings.

        Raises:
            FileNotFoundError: If ``image_path`` does not exist.
            Exception: Any other failure is logged (with traceback) and
                re-raised unchanged.
        """
        try:
            # Verify that the file exists before touching the model.
            if not os.path.exists(image_path):
                raise FileNotFoundError(f"Image file not found: {image_path}")

            if prompt is None:
                prompt = self.DEFAULT_PROMPT

            logger.info(f"Processing image: {image_path}")
            # The context manager closes the underlying file once the
            # processor has read the pixel data into tensors.
            with Image.open(image_path) as image:
                inputs = self.processor(
                    images=image,
                    text=prompt,
                    return_tensors="pt",
                )

            logger.info("Generating predictions...")
            # Inference only: skip gradient tracking to save time and memory.
            with torch.no_grad():
                predictions = self.model.generate(
                    **inputs,
                    max_new_tokens=512,
                    num_beams=4,
                    length_penalty=1.0,
                )

            raw_output = self.processor.decode(
                predictions[0], skip_special_tokens=True
            )
            result_array = self._parse_table(raw_output)

            output_file = self._save_table(result_array)
            logger.info(f"Results saved to {output_file}")
            return result_array
        except Exception as e:
            logger.exception("Error processing image: %s", e)
            raise

    @classmethod
    def _parse_table(cls, raw_output):
        """Split decoded model text into rows of stripped cells, skipping blank rows."""
        return [
            [cell.strip() for cell in line.split(cls.CELL_SEPARATOR)]
            for line in raw_output.split(cls.ROW_SEPARATOR)
            if line.strip()
        ]

    @staticmethod
    def _save_table(rows):
        """Write rows to a timestamped ``results_*.log`` file and return its name."""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_file = f'results_{timestamp}.log'
        with open(output_file, mode='w', encoding='utf-8') as file:
            file.writelines(" | ".join(row) + "\n" for row in rows)
        return output_file
def main(image_path=None):
    """Analyze one chart image and print the extracted table to stdout.

    Args:
        image_path: Path of the image to analyze. When ``None``, the
            hard-coded default file name is used.

    Raises:
        Exception: Any failure is logged and then re-raised unchanged.
    """
    if image_path is None:
        # Default image path (in a Space, use the path of the uploaded image).
        image_path = '05e57f1c9acff69f1eb6fa72d4805d0.jpg'
    try:
        # Create the analyzer instance (loads the model).
        analyzer = ChartAnalyzer()
        # Process the image.
        results = analyzer.process_image(image_path)
        # Print the results, one table row per line.
        print("\nAnalysis Results:")
        for row in results:
            print(" | ".join(row))
    except Exception as e:
        logger.error("Application error: %s", e)
        raise


if __name__ == "__main__":
    # An optional first command-line argument overrides the default image.
    main(sys.argv[1] if len(sys.argv) > 1 else None)
|