# ocrsensitive / app / utils.py
# mashaelalbu - "Update app/utils.py" (commit c10f4ec, verified)
# utils.py
import logging
import os
import tempfile

import torch
from PIL import Image, ImageEnhance, ImageFilter
from transformers import AutoModel, AutoTokenizer
from transformers import BertTokenizer
# Module-level logger named after this module; used by OCRModel for all of
# its info/warning/error records.
logger = logging.getLogger(__name__)
class OCRModel:
    """Singleton wrapper around the ``stepfun-ai/GOT-OCR2_0`` OCR model.

    The first successful ``OCRModel()`` call loads the tokenizer and model;
    every later call returns that same instance. If loading fails, nothing
    is cached, so a later call retries instead of handing back a
    half-initialized object.
    """

    # Hub identifier used for both the tokenizer and the model weights.
    MODEL_ID = 'stepfun-ai/GOT-OCR2_0'

    # The cached, fully initialized instance (None until the first success).
    _instance = None

    def __new__(cls):
        """Return the shared instance, creating and initializing it if needed.

        Raises:
            Exception: whatever ``initialize`` raises; in that case no
                instance is cached.
        """
        if cls._instance is None:
            instance = super().__new__(cls)
            # Publish the instance only after initialization succeeded, so a
            # failed load never leaves a broken singleton behind.
            instance.initialize()
            cls._instance = instance
        return cls._instance

    def initialize(self):
        """Load the tokenizer and model and move the model to the device.

        Sets ``self.tokenizer``, ``self.model`` and ``self.device``
        (``"cuda"`` when ``torch.cuda.is_available()``, else ``"cpu"``).

        Raises:
            Exception: any loading error is logged and re-raised.
        """
        try:
            logger.info("Initializing OCR model...")

            try:
                # First try the standard (slow) auto tokenizer.
                self.tokenizer = AutoTokenizer.from_pretrained(
                    self.MODEL_ID,
                    trust_remote_code=True,
                    use_fast=False,
                )
            except Exception as e:
                logger.warning(
                    "Standard tokenizer failed, trying BertTokenizer: %s", e
                )
                # Fall back to BertTokenizer if AutoTokenizer fails.
                self.tokenizer = BertTokenizer.from_pretrained(
                    self.MODEL_ID,
                    trust_remote_code=True,
                )

            self.model = AutoModel.from_pretrained(
                self.MODEL_ID,
                trust_remote_code=True,
                low_cpu_mem_usage=True,
                device_map='auto',
                use_safetensors=True,
            )

            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            # NOTE(review): device_map='auto' already places the weights; the
            # explicit .to() is kept from the original code - confirm it is
            # safe when the model ends up spread over several devices.
            self.model = self.model.eval().to(self.device)

            logger.info("Model initialization completed successfully")
        except Exception as e:
            logger.error("Error initializing model: %s", e, exc_info=True)
            raise

    def preprocess_image(self, image):
        """Image preprocessing to improve text recognition quality.

        Converts the image to RGB if needed, then boosts contrast (1.5),
        sharpness (1.5) and brightness (1.2), and finally applies the
        ``SMOOTH`` filter.

        Args:
            image: a PIL image.

        Returns:
            The processed PIL image.

        Raises:
            Exception: any PIL error is logged and re-raised.
        """
        try:
            # Convert image to RGB if it is not already.
            if image.mode != 'RGB':
                image = image.convert('RGB')

            # Order matters: each enhancement runs on the previous result.
            enhancements = (
                (ImageEnhance.Contrast, 1.5),
                (ImageEnhance.Sharpness, 1.5),
                (ImageEnhance.Brightness, 1.2),
            )
            for enhancer_cls, factor in enhancements:
                image = enhancer_cls(image).enhance(factor)

            # Apply a filter to soften the image a little.
            return image.filter(ImageFilter.SMOOTH)
        except Exception as e:
            logger.error("Error in image preprocessing: %s", e, exc_info=True)
            raise

    def process_image(self, image_stream):
        """Run OCR on an image stream and return the extracted text.

        The preprocessed image is written to a uniquely named temporary JPEG
        because ``model.chat`` is given a file path; the file is always
        removed afterwards.

        Args:
            image_stream: a seekable binary stream (e.g. ``BytesIO``)
                containing an image PIL can open.

        Returns:
            The stripped OCR result on success, or the string
            ``"Error processing image: <message>"`` on any failure (errors
            are logged, not raised).
        """
        temp_image_path = None
        try:
            logger.info("Starting image processing")

            # Reset the read pointer in case the stream was already consumed.
            image_stream.seek(0)
            image = Image.open(image_stream).convert('RGB')
            processed_image = self.preprocess_image(image)

            # A unique file per call keeps concurrent requests from
            # overwriting or deleting each other's image.
            fd, temp_image_path = tempfile.mkstemp(suffix=".jpg")
            os.close(fd)
            processed_image.save(temp_image_path)

            try:
                result = self.model.chat(
                    self.tokenizer, temp_image_path, ocr_type='format'
                )
            except Exception as e:
                logger.error("Error in OCR processing: %s", e, exc_info=True)
                raise

            logger.info(f"Successfully extracted text: {result[:100]}...")
            return result.strip()
        except Exception as e:
            logger.error("Error in image processing: %s", e, exc_info=True)
            return f"Error processing image: {str(e)}"
        finally:
            # Single cleanup point, reached on success and on every failure.
            if temp_image_path is not None:
                try:
                    os.remove(temp_image_path)
                except FileNotFoundError:
                    pass
                except OSError as e:
                    logger.warning(
                        "Could not remove temporary file %s: %s",
                        temp_image_path, e,
                    )