File size: 4,650 Bytes
9d3f1f1 0980890 eee8d68 9d3f1f1 0980890 eee8d68 9d3f1f1 eee8d68 9d3f1f1 eee8d68 9d3f1f1 eee8d68 9d3f1f1 0980890 9d3f1f1 eee8d68 9d3f1f1 eee8d68 0980890 9d3f1f1 eee8d68 9d3f1f1 ae3e99f 9d3f1f1 eee8d68 9d3f1f1 ae3e99f 9d3f1f1 eee8d68 9d3f1f1 eee8d68 9d3f1f1 eee8d68 9d3f1f1 eee8d68 9d3f1f1 eee8d68 9d3f1f1 eee8d68 9d3f1f1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import os
import sys
import logging
from datetime import datetime
import torch
from transformers import AutoProcessor, Pix2StructForConditionalGeneration
from PIL import Image
# Configure logging: records at INFO and above go to both stdout and the
# file ``app.log``.  The file handler is opened explicitly as UTF-8 because
# logged exception messages may contain non-ASCII text, which the platform
# default encoding is not guaranteed to be able to encode.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler('app.log', encoding='utf-8'),
    ],
)
# Module-level logger used by the rest of this file.
logger = logging.getLogger(__name__)
def print_section(title, char='='):
    """Print a section banner with *title* centred between two rules.

    The output is: an empty line, a rule made of ``char`` repeated 50
    times, the title centred in a 50-column field, the same rule again,
    and a final empty line.
    """
    rule = char * 50
    # The leading and trailing empty strings produce the blank lines that
    # surround the banner.
    print("", rule, title.center(50), rule, "", sep="\n")
def print_table(data):
    """Print a list of rows as a left-aligned, pipe-separated text table.

    The first row is treated as the header and is followed by a dashed
    separator line of the same length.  Every cell is converted with
    ``str`` and padded to the widest cell found in its column.

    Rows do not need to have the same number of cells: a column's width is
    computed only from the rows that actually contain that column, and a
    short row is printed with just the cells it has.

    Args:
        data: Sequence of rows, each row a sequence of cell values.  If it
            is empty, ``"No data available"`` is printed instead.

    Returns:
        None.  The table is written to standard output.
    """
    if not data:
        print("No data available")
        return

    # Widest cell per column, grown on demand so that rows longer than the
    # header still get a width entry for their extra columns.
    col_widths = []
    for row in data:
        for idx, cell in enumerate(row):
            cell_len = len(str(cell))
            if idx == len(col_widths):
                col_widths.append(cell_len)
            elif cell_len > col_widths[idx]:
                col_widths[idx] = cell_len

    def _format_row(row):
        # zip() stops at the end of the row; col_widths is never shorter
        # than any row, so no cell is dropped.
        return " | ".join(
            str(cell).ljust(width) for cell, width in zip(row, col_widths)
        )

    # Header followed by a separator as long as the header line.
    header_str = _format_row(data[0])
    print(header_str)
    print("-" * len(header_str))

    # Data rows.
    for row in data[1:]:
        print(_format_row(row))
class ChartAnalyzer:
    """Extract a data table from a chart image with a Pix2Struct model.

    The model and its processor are loaded once in the constructor and
    reused for every call to :meth:`process_image`.
    """

    # Marker on which the decoded model output is split into table rows.
    _ROW_SEPARATOR = "<0x0A>"
    # Prompt sent to the model when the caller does not supply one.
    _DEFAULT_PROMPT = "Generate underlying data table of the figure below:"

    def __init__(self, model_name="google/deplot"):
        """Load the model and the processor.

        Args:
            model_name: Checkpoint identifier handed to ``from_pretrained``
                for both the model and the processor.  Defaults to
                ``"google/deplot"``.

        Raises:
            Exception: Any error raised while loading is logged and then
                re-raised unchanged.
        """
        try:
            print_section("初始化模型")
            print("正在加载模型和处理器...")
            self.model = Pix2StructForConditionalGeneration.from_pretrained(model_name)
            self.processor = AutoProcessor.from_pretrained(model_name)
            print("✓ 模型加载完成")
        except Exception as e:
            print("✗ 模型加载失败")
            logger.error(f"Error initializing model: {str(e)}")
            raise

    @classmethod
    def _parse_table(cls, raw_output):
        """Split decoded model output into rows of stripped cell strings.

        Rows are separated by ``_ROW_SEPARATOR`` and cells by ``|``; rows
        that are empty or whitespace-only are skipped.
        """
        return [
            [cell.strip() for cell in line.split("|")]
            for line in raw_output.split(cls._ROW_SEPARATOR)
            if line.strip()
        ]

    def process_image(self, image_path, prompt=None):
        """Run the model on one image and return the extracted table.

        The parsed rows are also written to ``results_<timestamp>.log``
        (relative to the current working directory), one row per line with
        the cells joined by ``" | "``.

        Args:
            image_path: Path of the image file to analyse.
            prompt: Text prompt passed to the processor.  When ``None`` the
                default table-extraction prompt is used.

        Returns:
            A list of rows, each row being a list of stripped cell strings.

        Raises:
            FileNotFoundError: If ``image_path`` does not exist.
            Exception: Any other error is logged and re-raised unchanged.
        """
        try:
            print_section("图片处理", char='-')

            # Fail early with a clear message when the file is missing.
            if not os.path.exists(image_path):
                raise FileNotFoundError(f"找不到图片文件: {image_path}")

            print(f"正在处理图片: {image_path}")
            if prompt is None:
                prompt = self._DEFAULT_PROMPT

            # Open the image in a context manager so the underlying file
            # handle is released as soon as the processor has turned the
            # image into model inputs.
            with Image.open(image_path) as image:
                inputs = self.processor(
                    images=image,
                    text=prompt,
                    return_tensors="pt",
                )

            # Inference only: no gradients are needed.
            print("\n正在生成数据分析...")
            with torch.no_grad():
                predictions = self.model.generate(
                    **inputs,
                    max_new_tokens=512,
                    num_beams=4,
                    length_penalty=1.0,
                )

            # Decode the first generated sequence and parse it into rows.
            raw_output = self.processor.decode(predictions[0], skip_special_tokens=True)
            result_array = self._parse_table(raw_output)

            # Persist the parsed table to a timestamped file.
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            output_file = f'results_{timestamp}.log'
            with open(output_file, mode='w', encoding='utf-8') as file:
                file.writelines(" | ".join(row) + "\n" for row in result_array)

            print(f"\n✓ 结果已保存至: {output_file}")
            return result_array
        except Exception as e:
            print("\n✗ 处理失败")
            logger.error(f"Error processing image: {str(e)}")
            raise
def main():
    """Run the extraction pipeline once on a hard-coded image path.

    Creates a ``ChartAnalyzer``, processes the image, and prints the
    resulting table.  Any exception is logged, a notice is printed, and
    the exception is re-raised.
    """
    try:
        print_section("图表数据提取系统", char='*')

        # Build the analyzer, then run it on the fixed image file.
        chart_analyzer = ChartAnalyzer()
        table_rows = chart_analyzer.process_image('05e57f1c9acff69f1eb6fa72d4805d0.jpg')

        # Show the extracted table.
        print_section("分析结果")
        print_table(table_rows)
        print_section("处理完成", char='*')
    except Exception as exc:
        logger.error(f"Application error: {exc!s}")
        print("\n✗ 程序执行出错,请查看日志获取详细信息")
        raise


if __name__ == "__main__":
    main()
|