|
import os |
|
import sys |
|
import logging |
|
from datetime import datetime |
|
import torch |
|
from transformers import AutoProcessor, Pix2StructForConditionalGeneration |
|
from PIL import Image |
|
|
|
|
|
# Configure the root logger once at import time: INFO and above go both to the
# console (stdout) and to ``app.log`` in the current working directory.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),
        # Explicit UTF-8 so that log records containing non-ASCII text (for
        # example image paths) are written the same way on every platform,
        # matching the UTF-8 results files this script produces.
        logging.FileHandler('app.log', encoding='utf-8'),
    ],
)

# Module-level logger shared by everything in this file.
logger = logging.getLogger(__name__)
|
|
|
class ChartAnalyzer:
    """Turn a chart image into rows of table cells using ``google/deplot``.

    The Pix2Struct model and its processor are loaded once in ``__init__`` and
    reused by every ``process_image`` call.
    """

    # Checkpoint used for both the model and the processor.
    MODEL_NAME = "google/deplot"
    # Prompt sent to the model when the caller does not supply one.
    DEFAULT_PROMPT = "Generate underlying data table of the figure below:"
    # Literal token (hex 0x0A, i.e. line feed) on which the decoded text is
    # split into table rows.
    ROW_SEPARATOR = "<0x0A>"

    def __init__(self):
        """Load the model and processor.

        Raises:
            Exception: Any error raised while loading is logged (with
                traceback) and re-raised unchanged.
        """
        try:
            logger.info("Initializing model and processor...")
            self.model = Pix2StructForConditionalGeneration.from_pretrained(self.MODEL_NAME)
            self.processor = AutoProcessor.from_pretrained(self.MODEL_NAME)
            logger.info("Model and processor initialized successfully")
        except Exception as e:
            logger.exception("Error initializing model: %s", e)
            raise

    def process_image(self, image_path, prompt=None):
        """Process an image and generate its data table.

        The parsed table is also written to a timestamped
        ``results_YYYYmmdd_HHMMSS.log`` file in the current working directory.

        Args:
            image_path: Path of the chart image to analyse.
            prompt: Text prompt passed to the processor; ``DEFAULT_PROMPT`` is
                used when ``None``.

        Returns:
            A list of rows, each row being a list of whitespace-stripped cell
            strings.

        Raises:
            FileNotFoundError: If ``image_path`` does not exist.
            Exception: Any other error is logged (with traceback) and
                re-raised unchanged.
        """
        try:
            if not os.path.exists(image_path):
                raise FileNotFoundError(f"Image file not found: {image_path}")

            logger.info("Processing image: %s", image_path)

            if prompt is None:
                prompt = self.DEFAULT_PROMPT

            # Open the image in a context manager so the underlying file
            # handle is released as soon as the processor has built its
            # tensors, instead of staying open until garbage collection.
            with Image.open(image_path) as image:
                inputs = self.processor(
                    images=image,
                    text=prompt,
                    return_tensors="pt",
                )

            logger.info("Generating predictions...")
            # Inference only: no gradients are needed.
            with torch.no_grad():
                predictions = self.model.generate(
                    **inputs,
                    max_new_tokens=512,
                    num_beams=4,
                    length_penalty=1.0,
                )

            raw_output = self.processor.decode(predictions[0], skip_special_tokens=True)

            result_array = self._parse_table(raw_output)
            output_file = self._save_results(result_array)

            logger.info("Results saved to %s", output_file)
            return result_array

        except Exception as e:
            # logger.exception keeps the traceback in the log; the error is
            # still propagated to the caller.
            logger.exception("Error processing image: %s", e)
            raise

    @staticmethod
    def _parse_table(raw_output):
        """Split decoded model text into rows of stripped cells.

        Rows are separated by the ``<0x0A>`` token and cells by ``|``; rows
        that are empty or whitespace-only are dropped.
        """
        return [
            [cell.strip() for cell in line.split("|")]
            for line in raw_output.split(ChartAnalyzer.ROW_SEPARATOR)
            if line.strip()
        ]

    @staticmethod
    def _save_results(rows):
        """Write ``rows`` to a timestamped results file and return its name."""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_file = f'results_{timestamp}.log'

        with open(output_file, mode='w', encoding='utf-8') as file:
            file.writelines(" | ".join(row) + "\n" for row in rows)

        return output_file
|
|
|
def main(image_path=None):
    """Analyse one chart image and print the extracted table to stdout.

    Args:
        image_path: Path of the image to analyse. When ``None`` the script's
            original sample file name is used, so calling ``main()`` with no
            arguments behaves as before.

    Raises:
        Exception: Any error from model loading or image processing is logged
            and re-raised unchanged.
    """
    if image_path is None:
        image_path = '05e57f1c9acff69f1eb6fa72d4805d0.jpg'

    try:
        analyzer = ChartAnalyzer()
        results = analyzer.process_image(image_path)

        print("\nAnalysis Results:")
        for row in results:
            print(" | ".join(row))

    except Exception as e:
        logger.error("Application error: %s", e)
        raise
|
|
|
# Run the analysis only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
|
|