# c2t / app.py
# Author: realkun · last commit 9d3f1f1 ("Update app.py", verified) · 3.55 kB
import os
import sys
import logging
from datetime import datetime
import torch
from transformers import AutoProcessor, Pix2StructForConditionalGeneration
from PIL import Image
# Configure logging: every record goes both to stdout and to app.log.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),
        # Explicit UTF-8 (matching the results file) so non-ASCII message
        # content such as image paths cannot fail to encode when the
        # platform's default locale encoding is not UTF-8.
        logging.FileHandler('app.log', encoding='utf-8'),
    ],
)
logger = logging.getLogger(__name__)
class ChartAnalyzer:
    """Extract the underlying data table from a chart image with a Pix2Struct model (DePlot by default)."""

    DEFAULT_MODEL = "google/deplot"
    DEFAULT_PROMPT = "Generate underlying data table of the figure below:"

    def __init__(self, model_name=DEFAULT_MODEL):
        """Load the model and processor.

        Args:
            model_name: Model id or local path passed to both
                ``from_pretrained`` calls. Defaults to ``"google/deplot"``.

        Raises:
            Exception: Whatever ``from_pretrained`` raises; it is logged with
                its traceback and re-raised.
        """
        try:
            logger.info("Initializing model and processor...")
            self.model = Pix2StructForConditionalGeneration.from_pretrained(model_name)
            self.processor = AutoProcessor.from_pretrained(model_name)
            logger.info("Model and processor initialized successfully")
        except Exception as e:
            logger.exception("Error initializing model: %s", e)
            raise

    def process_image(self, image_path, prompt=None, *, max_new_tokens=512, num_beams=4):
        """Generate the data table for a chart image and save it to a results file.

        Args:
            image_path: Path to the chart image file.
            prompt: Text prompt for the model; ``None`` uses ``DEFAULT_PROMPT``.
            max_new_tokens: Upper bound on generated tokens (keyword-only).
            num_beams: Beam-search width (keyword-only).

        Returns:
            list[list[str]]: One list of whitespace-stripped cell strings per
            non-empty output row, cells split on ``|``.

        Raises:
            FileNotFoundError: If ``image_path`` is not an existing regular file.
            Exception: Any loading/model/IO error; logged with traceback and
                re-raised.
        """
        try:
            # isfile (not exists) so a directory is reported as "not found"
            # here instead of failing later inside Image.open.
            if not os.path.isfile(image_path):
                raise FileNotFoundError(f"Image file not found: {image_path}")

            logger.info("Processing image: %s", image_path)
            image = self._load_image(image_path)

            if prompt is None:
                prompt = self.DEFAULT_PROMPT
            inputs = self.processor(images=image, text=prompt, return_tensors="pt")

            logger.info("Generating predictions...")
            with torch.no_grad():  # inference only: skip autograd bookkeeping
                predictions = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    num_beams=num_beams,
                    length_penalty=1.0,
                )

            raw_output = self.processor.decode(predictions[0], skip_special_tokens=True)
            result_array = self._parse_table(raw_output)

            output_file = self._save_results(result_array)
            logger.info("Results saved to %s", output_file)
            return result_array
        except Exception as e:
            logger.exception("Error processing image: %s", e)
            raise

    @staticmethod
    def _load_image(image_path):
        """Read the image fully into memory and release the file handle."""
        # Image.open is lazy and keeps the file open; copy() forces the pixel
        # data to load and returns an image independent of the file, so the
        # `with` block can close the handle.
        with Image.open(image_path) as img:
            return img.copy()

    @staticmethod
    def _parse_table(raw_output):
        """Split the decoded model text into rows of stripped cell strings."""
        # The decoded text separates rows with the literal token "<0x0A>";
        # real newlines are accepted too in case a tokenizer emits them.
        normalized = raw_output.replace("<0x0A>", "\n")
        return [
            [cell.strip() for cell in line.split("|")]
            for line in normalized.split("\n")
            if line.strip()  # skip blank rows
        ]

    @staticmethod
    def _save_results(rows):
        """Write rows as ' | '-joined lines to a timestamped file in the CWD; return its name."""
        # Microseconds keep results from back-to-back calls within the same
        # second from overwriting each other.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        output_file = f'results_{timestamp}.log'
        with open(output_file, mode='w', encoding='utf-8') as file:
            file.writelines(" | ".join(row) + "\n" for row in rows)
        return output_file
def main(image_path='05e57f1c9acff69f1eb6fa72d4805d0.jpg'):
    """Run the chart analyzer on one image and print the extracted table.

    Args:
        image_path: Image to analyze. Defaults to the sample JPEG the script
            originally hard-coded (resolved relative to the current working
            directory). In a Space, pass the uploaded image's path instead.

    Raises:
        Exception: Any error from model loading or image processing; logged
            and re-raised.
    """
    try:
        analyzer = ChartAnalyzer()
        results = analyzer.process_image(image_path)

        print("\nAnalysis Results:")
        for row in results:
            print(" | ".join(row))
    except Exception as e:
        logger.error("Application error: %s", e)
        raise


if __name__ == "__main__":
    # An optional first command-line argument overrides the default image path.
    if len(sys.argv) > 1:
        main(sys.argv[1])
    else:
        main()