c2t / app.py
realkun's picture
Update app.py
eee8d68 verified
raw
history blame
4.65 kB
import os
import sys
import logging
from datetime import datetime
import torch
from transformers import AutoProcessor, Pix2StructForConditionalGeneration
from PIL import Image
# Configure logging: every record goes both to stdout and to app.log.
# The log file is opened as UTF-8 explicitly because logged exception
# messages can contain non-ASCII text (for example the Chinese
# "file not found" message raised in ChartAnalyzer.process_image); relying
# on the platform's default locale encoding can fail on such records.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler('app.log', encoding='utf-8'),
    ],
)
logger = logging.getLogger(__name__)
def print_section(title, char='='):
    """Print a section banner: *title* centred between two rules.

    Each rule is *char* repeated 50 times and the title is centred in a
    50-column field.  A blank line precedes the top rule and follows the
    bottom rule.
    """
    width = 50
    rule = char * width
    # Same three print calls as before, expressed as a loop over the lines.
    for line in (f"\n{rule}", f"{title.center(width)}", f"{rule}\n"):
        print(line)
def print_table(data):
    """Print *data* as a text table with left-aligned, padded columns.

    Args:
        data: A sequence of rows; each row is a sequence of cells.  The
            first row is treated as the header and is followed by a line of
            dashes as long as the rendered header.

    Rows may have different lengths.  Column widths are computed over every
    row that actually has a cell in that column, and each row prints only
    the cells it contains, so a row shorter or longer than the header no
    longer raises ``IndexError``.  Output for rectangular tables is
    unchanged.  When *data* is empty, ``No data available`` is printed.
    """
    if not data:
        print("No data available")
        return

    # Size every column that appears in any row, not just the header's.
    column_count = max(len(row) for row in data)
    col_widths = [
        max(len(str(row[i])) for row in data if i < len(row))
        for i in range(column_count)
    ]

    def _format_row(row):
        # zip stops at the row's own length, so short rows are simply shorter.
        return " | ".join(
            str(cell).ljust(width) for cell, width in zip(row, col_widths)
        )

    header_str = _format_row(data[0])
    print(header_str)
    print("-" * len(header_str))

    for row in data[1:]:
        print(_format_row(row))
class ChartAnalyzer:
    """Extract the underlying data table from a chart image.

    Holds a ``Pix2StructForConditionalGeneration`` model and the matching
    processor, both loaded from the same pretrained checkpoint.
    """

    # Marker on which the decoded model output is split into table rows.
    _ROW_SEPARATOR = "<0x0A>"
    # Prompt used when the caller does not supply one.
    _DEFAULT_PROMPT = "Generate underlying data table of the figure below:"

    def __init__(self, model_name="google/deplot"):
        """Load the model and processor.

        Args:
            model_name: Checkpoint identifier passed to ``from_pretrained``
                for both the model and the processor.  Defaults to
                ``"google/deplot"``, the checkpoint that was previously
                hard-coded.

        Raises:
            Exception: Any loading failure is logged and re-raised.
        """
        try:
            print_section("初始化模型")
            print("正在加载模型和处理器...")
            self.model = Pix2StructForConditionalGeneration.from_pretrained(model_name)
            self.processor = AutoProcessor.from_pretrained(model_name)
            print("✓ 模型加载完成")
        except Exception as e:
            print("✗ 模型加载失败")
            logger.error("Error initializing model: %s", e)
            raise

    @classmethod
    def _parse_table(cls, raw_output):
        """Split decoded model text into rows of stripped cell strings."""
        return [
            [cell.strip() for cell in line.split("|")]
            for line in raw_output.split(cls._ROW_SEPARATOR)
            if line.strip()  # skip empty rows
        ]

    @staticmethod
    def _save_results(rows):
        """Write *rows* to a timestamped ``results_*.log`` file; return its name."""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_file = f'results_{timestamp}.log'
        with open(output_file, mode='w', encoding='utf-8') as file:
            file.writelines(" | ".join(row) + "\n" for row in rows)
        return output_file

    def process_image(self, image_path, prompt=None):
        """Run the model on one image and return the extracted table.

        Args:
            image_path: Path of the chart image to analyse.
            prompt: Text prompt given to the processor together with the
                image.  When ``None``, the default "generate underlying
                data table" prompt is used.

        Returns:
            A list of rows, each a list of stripped cell strings.  The same
            rows are also written to ``results_<timestamp>.log`` (relative
            path, so it lands in the current working directory).

        Raises:
            FileNotFoundError: If ``image_path`` does not exist.
            Exception: Any other failure is logged and re-raised unchanged.
        """
        try:
            print_section("图片处理", char='-')

            # Fail early with a clear message when the path is missing.
            if not os.path.exists(image_path):
                raise FileNotFoundError(f"找不到图片文件: {image_path}")

            print(f"正在处理图片: {image_path}")
            if prompt is None:
                prompt = self._DEFAULT_PROMPT

            # The context manager releases the image file handle even when
            # preprocessing raises; the processor output does not depend on
            # the file staying open afterwards.
            with Image.open(image_path) as image:
                inputs = self.processor(
                    images=image,
                    text=prompt,
                    return_tensors="pt"
                )

            print("\n正在生成数据分析...")
            # Inference only: no gradients are needed.
            with torch.no_grad():
                predictions = self.model.generate(
                    **inputs,
                    max_new_tokens=512,
                    num_beams=4,
                    length_penalty=1.0
                )

            raw_output = self.processor.decode(predictions[0], skip_special_tokens=True)
            result_array = self._parse_table(raw_output)

            output_file = self._save_results(result_array)
            print(f"\n✓ 结果已保存至: {output_file}")
            return result_array
        except Exception as e:
            print("\n✗ 处理失败")
            logger.error("Error processing image: %s", e)
            raise
def main(image_path='05e57f1c9acff69f1eb6fa72d4805d0.jpg'):
    """Run the whole pipeline: load the model, analyse one image, print the table.

    Args:
        image_path: Chart image to analyse.  The default is the file name
            that used to be hard-coded, so calling ``main()`` with no
            arguments behaves exactly as before.

    Raises:
        Exception: Any failure is logged, reported on stdout, and re-raised.
    """
    try:
        print_section("图表数据提取系统", char='*')

        # Build the analyzer (loads the model) and process the image.
        analyzer = ChartAnalyzer()
        results = analyzer.process_image(image_path)

        # Show the extracted table.
        print_section("分析结果")
        print_table(results)
        print_section("处理完成", char='*')
    except Exception as e:
        logger.error("Application error: %s", e)
        print("\n✗ 程序执行出错,请查看日志获取详细信息")
        raise
# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()