# utils.py
import logging
import os

import torch
from PIL import Image, ImageEnhance, ImageFilter
from transformers import AutoModel, AutoTokenizer, BertTokenizer

logger = logging.getLogger(__name__)
class OCRModel:
    """Singleton wrapper around the GOT-OCR2.0 model so the weights are loaded only once."""

    _instance = None

    def __new__(cls):
        # Lazily create and initialize the single shared instance.
        if cls._instance is None:
            cls._instance = super(OCRModel, cls).__new__(cls)
            cls._instance.initialize()
        return cls._instance
    def initialize(self):
        try:
            logger.info("Initializing OCR model...")

            # Try different tokenizer approaches
            try:
                # First try the standard approach
                self.tokenizer = AutoTokenizer.from_pretrained(
                    'stepfun-ai/GOT-OCR2_0',
                    trust_remote_code=True,
                    use_fast=False
                )
            except Exception as e:
                logger.warning(f"Standard tokenizer failed, trying BertTokenizer: {str(e)}")
                # Fall back to BertTokenizer if AutoTokenizer fails
                self.tokenizer = BertTokenizer.from_pretrained(
                    'stepfun-ai/GOT-OCR2_0',
                    trust_remote_code=True
                )
            self.model = AutoModel.from_pretrained(
                'stepfun-ai/GOT-OCR2_0',
                trust_remote_code=True,
                low_cpu_mem_usage=True,
                device_map='auto',
                use_safetensors=True
            )

            # device_map='auto' already places the weights, so skip a redundant
            # .to() call, which accelerate rejects for dispatched models.
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            self.model = self.model.eval()
            logger.info("Model initialization completed successfully")
        except Exception as e:
            logger.error(f"Error initializing model: {str(e)}", exc_info=True)
            raise
    def preprocess_image(self, image):
        """Image preprocessing to improve text recognition quality."""
        try:
            # Convert the image to RGB if it is not already
            if image.mode != 'RGB':
                image = image.convert('RGB')

            # Boost contrast
            enhancer = ImageEnhance.Contrast(image)
            image = enhancer.enhance(1.5)

            # Boost sharpness
            enhancer = ImageEnhance.Sharpness(image)
            image = enhancer.enhance(1.5)

            # Boost brightness
            enhancer = ImageEnhance.Brightness(image)
            image = enhancer.enhance(1.2)

            # Apply a smoothing filter to soften the image slightly
            image = image.filter(ImageFilter.SMOOTH)

            return image
        except Exception as e:
            logger.error(f"Error in image preprocessing: {str(e)}", exc_info=True)
            raise
    def process_image(self, image_stream):
        try:
            logger.info("Starting image processing")

            # Save the image temporarily because the model expects a file path.
            temp_image_path = "temp_image.jpg"

            # Reset the stream position in case the BytesIO has already been read
            image_stream.seek(0)

            # Open, preprocess and save the image temporarily
            image = Image.open(image_stream).convert('RGB')
            processed_image = self.preprocess_image(image)
            processed_image.save(temp_image_path)

            # Run OCR and always clean up the temporary file afterwards
            try:
                result = self.model.chat(self.tokenizer, temp_image_path, ocr_type='format')
                logger.info(f"Successfully extracted text: {result[:100]}...")
                return result.strip()
            except Exception as e:
                logger.error(f"Error in OCR processing: {str(e)}", exc_info=True)
                raise
            finally:
                if os.path.exists(temp_image_path):
                    os.remove(temp_image_path)
        except Exception as e:
            logger.error(f"Error in image processing: {str(e)}", exc_info=True)
            return f"Error processing image: {str(e)}"