# AutoWeightLogger1 / ocr_engine.py
# (Hugging Face file-viewer residue: author Sanjayraju30, "Update ocr_engine.py",
# commit 9f46cc7, 8.54 kB.)
import numpy as np
import cv2
import re
import logging
from datetime import datetime
import os
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
# Configure root logging once, at import time: INFO level with timestamped records.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Load the TrOCR processor and encoder-decoder model at import time so that every
# perform_ocr call shares the same module-level instances.
# NOTE(review): from_pretrained presumably fetches/caches the checkpoint on first
# use, which would make importing this module slow or network-dependent - confirm.
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-small-printed")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-small-printed")
# Directory that receives the intermediate images written by save_debug_image;
# it is created on import if it does not exist yet.
DEBUG_DIR = "debug_images"
os.makedirs(DEBUG_DIR, exist_ok=True)
def save_debug_image(img, filename_suffix, prefix=""):
    """Save an intermediate image to ``DEBUG_DIR`` under a timestamped name.

    The file name is ``{prefix}{timestamp}_{filename_suffix}.png`` where the
    timestamp has microsecond resolution.

    Args:
        img: A ``PIL.Image.Image`` or a NumPy image array. Arrays are handed to
            ``cv2.imwrite`` unchanged, so 3-channel arrays must be in OpenCV's
            BGR order (which is what the callers in this module pass).
        filename_suffix: Stage label placed after the timestamp.
        prefix: Optional text placed before the timestamp.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    filename = os.path.join(DEBUG_DIR, f"{prefix}{timestamp}_{filename_suffix}.png")
    if isinstance(img, Image.Image):
        img.save(filename)
    # cv2.imwrite already interprets 3-channel data as BGR. Converting to RGB
    # first (as this function used to) stored colour images with red and blue
    # swapped, so arrays are now written as-is.
    elif not cv2.imwrite(filename, img):
        # imwrite signals failure through its return value, not an exception.
        logging.warning(f"Could not write debug image: {filename}")
        return
    logging.info(f"Saved debug image: {filename}")
def estimate_brightness(img):
    """Return the mean grayscale intensity of a BGR image."""
    return np.mean(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY))
def preprocess_image(img):
    """Produce a binary mask and a contrast-enhanced grayscale copy of a BGR image.

    Pipeline: CLAHE, 3x3 Gaussian blur, inverted adaptive Gaussian threshold,
    then one morphological open followed by one close. A debug snapshot is
    saved after the CLAHE, blur and morphology stages.

    Returns:
        tuple: ``(thresh, enhanced)`` - the cleaned binary mask and the
        CLAHE-enhanced grayscale image.
    """
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    # Darker frames (mean intensity below 100) get the stronger clip limit.
    if estimate_brightness(img) < 100:
        clip_limit = 4.0
    else:
        clip_limit = 2.0
    enhanced = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=(8, 8)).apply(gray)
    save_debug_image(enhanced, "01_preprocess_clahe")

    # Light smoothing before thresholding.
    blurred = cv2.GaussianBlur(enhanced, (3, 3), 0)
    save_debug_image(blurred, "02_preprocess_blur")

    # Odd neighbourhood size that grows with image height, clamped to [11, 31].
    height_scaled = int(img.shape[0] / 20) * 2 + 1
    block_size = max(11, min(31, height_scaled))
    thresh = cv2.adaptiveThreshold(
        blurred,
        255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY_INV,
        block_size,
        2,
    )

    # Open, then close, with the same 3x3 rectangular kernel.
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
    for operation in (cv2.MORPH_OPEN, cv2.MORPH_CLOSE):
        thresh = cv2.morphologyEx(thresh, operation, kernel, iterations=1)
    save_debug_image(thresh, "03_preprocess_morph")

    return thresh, enhanced
def correct_rotation(img, max_skew_deg=45.0):
    """Deskew a BGR image using the median angle of near-horizontal edges.

    Line segments are found with Canny + probabilistic Hough. Each segment's
    angle is folded into [-90, 90) degrees, and only segments within
    ``max_skew_deg`` of horizontal contribute to the median. Without that
    filter, vertical edges (digit strokes, display borders) could pull the
    median toward +/-90 degrees and rotate an already level image onto its
    side, cropping it because the output canvas keeps the input size.

    Args:
        img: BGR image array.
        max_skew_deg: Largest absolute segment angle, in degrees, that is
            treated as a skewed horizontal edge.

    Returns:
        The rotated image when the median skew exceeds 1 degree; otherwise,
        or on any error, the input image unchanged.
    """
    try:
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        edges = cv2.Canny(gray, 50, 150)
        lines = cv2.HoughLinesP(edges, 1, np.pi / 180, threshold=50, minLineLength=30, maxLineGap=10)
        if lines is None:
            return img

        # One row per segment: x1, y1, x2, y2.
        segments = np.asarray(lines, dtype=np.float64).reshape(-1, 4)
        angles = np.degrees(
            np.arctan2(segments[:, 3] - segments[:, 1], segments[:, 2] - segments[:, 0])
        )
        # A segment has no direction, so angles 180 degrees apart are the same line.
        angles = (angles + 90.0) % 180.0 - 90.0
        angles = angles[np.abs(angles) <= max_skew_deg]
        if angles.size == 0:
            return img

        angle = float(np.median(angles))
        if abs(angle) > 1.0:
            h, w = img.shape[:2]
            center = (w // 2, h // 2)
            M = cv2.getRotationMatrix2D(center, angle, 1.0)
            img = cv2.warpAffine(img, M, (w, h))
            save_debug_image(img, "00_rotated_image")
            logging.info(f"Applied rotation: {angle:.2f} degrees")
        return img
    except Exception as e:
        # Best effort: rotation is optional, so fall back to the input image.
        logging.error(f"Rotation correction failed: {str(e)}")
        return img
def detect_roi(img):
    """Find the most likely display region in a BGR image.

    External contours of the preprocessed mask are filtered by area, aspect
    ratio, minimum width/height and mean brightness. The surviving contour
    with the highest ``area * brightness`` score wins, and its bounding box is
    padded and clipped to the image.

    Returns:
        tuple: ``(roi_img, (x, y, w, h))`` when a region is found, otherwise
        ``(img, None)`` (also returned when detection raises).
    """
    try:
        save_debug_image(img, "04_original")
        thresh, _enhanced = preprocess_image(img)
        brightness_map = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        frame_h, frame_w = img.shape[0], img.shape[1]
        area_ceiling = frame_h * frame_w * 0.8
        best_box = None
        best_score = None
        for candidate in contours:
            area = cv2.contourArea(candidate)
            bx, by, bw, bh = cv2.boundingRect(candidate)
            roi_brightness = np.mean(brightness_map[by:by + bh, bx:bx + bw])
            aspect_ratio = bw / bh

            # Relaxed constraints for digital displays.
            size_ok = 200 < area < area_ceiling and bw > 50 and bh > 20
            shape_ok = 0.5 <= aspect_ratio <= 15.0
            if not (size_ok and shape_ok and roi_brightness > 30):
                continue

            score = area * roi_brightness
            # Strict comparison keeps the earliest contour on ties.
            if best_score is None or score > best_score:
                best_box, best_score = (bx, by, bw, bh), score
            logging.debug(f"Contour: Area={area}, Aspect={aspect_ratio:.2f}, Brightness={roi_brightness:.2f}")

        if best_box is not None:
            x, y, w, h = best_box
            # Pad by 30% of the shorter side, limited to 15..50 pixels.
            padding = max(15, min(50, int(min(w, h) * 0.3)))
            x = max(0, x - padding)
            y = max(0, y - padding)
            w = min(w + 2 * padding, frame_w - x)
            h = min(h + 2 * padding, frame_h - y)
            roi_img = img[y:y + h, x:x + w]
            save_debug_image(roi_img, "05_detected_roi")
            logging.info(f"Detected ROI: ({x}, {y}, {w}, {h})")
            return roi_img, (x, y, w, h)

        logging.info("No ROI found, using full image.")
        save_debug_image(img, "05_no_roi_fallback")
        return img, None
    except Exception as e:
        logging.error(f"ROI detection failed: {str(e)}")
        save_debug_image(img, "05_roi_error_fallback")
        return img, None
def perform_ocr(img):
    """Run TrOCR on a BGR image and return a cleaned numeric reading.

    The decoded text is reduced to digits and at most one decimal point (the
    last one is kept), outer dots are stripped, and redundant leading zeros
    are removed while keeping a single zero before a decimal point
    (``"0.5"`` stays ``"0.5"`` rather than becoming ``".5"``).

    Args:
        img: BGR image array.

    Returns:
        tuple: ``(text, confidence)`` where ``text`` is the numeric string and
        ``confidence`` is a fixed heuristic (95.0 for more than one
        significant digit, else 90.0), or ``(None, 0.0)`` if nothing numeric
        was read or an error occurred.
    """
    try:
        # Convert to PIL for TrOCR
        pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        save_debug_image(pil_img, "06_ocr_input")

        # Process image with TrOCR
        pixel_values = processor(pil_img, return_tensors="pt").pixel_values
        generated_ids = model.generate(pixel_values)
        text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        logging.info(f"TrOCR raw output: {text}")

        # Keep digits and dots only; with several dots, drop all but the last.
        text = re.sub(r"[^\d\.]", "", text)
        dot_count = text.count('.')
        if dot_count > 1:
            text = text.replace('.', '', dot_count - 1)
        text = text.strip('.')

        if text and re.fullmatch(r"^\d*\.?\d*$", text):
            significant = text.lstrip('0') or '0'
            # The confidence heuristic still counts digits after stripping
            # leading zeros, so its value is unchanged by the fix below.
            confidence = 95.0 if len(significant.replace('.', '')) > 1 else 90.0
            # lstrip('0') also eats the zero in front of a decimal point;
            # put it back so the reading is a well-formed decimal.
            text = '0' + significant if significant.startswith('.') else significant
            logging.info(f"Validated text: {text}, Confidence: {confidence:.2f}%")
            return text, confidence

        logging.info(f"Text '{text}' failed validation.")
        return None, 0.0
    except Exception as e:
        logging.error(f"OCR failed: {str(e)}")
        return None, 0.0
def _classify_reading(result, confidence, min_confidence):
    """Classify an OCR reading as "ok", "rejected", "invalid" or "out_of_range"."""
    if not result or confidence < min_confidence:
        return "rejected"
    try:
        weight = float(result)
    except ValueError:
        return "invalid"
    return "ok" if 0.00001 <= weight <= 10000 else "out_of_range"


def extract_weight_from_image(pil_img):
    """Extract the weight shown on a digital scale from an RGB PIL image.

    The image is deskewed, a display region is searched for, and OCR is run
    on that region. If that reading is rejected, the full image is used as a
    fallback with a 10% lower confidence threshold.

    Args:
        pil_img: Input image; converted with ``np.array`` and treated as RGB.

    Returns:
        tuple: ``(weight_text, confidence)`` for an accepted reading between
        0.00001 and 10000, otherwise ``("Not detected", 0.0)`` (also on error).
    """
    try:
        img = np.array(pil_img)
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        save_debug_image(img, "00_input_image")
        img = correct_rotation(img)

        # Brighter frames must clear a higher confidence threshold.
        conf_threshold = 0.6 if estimate_brightness(img) > 100 else 0.4
        roi_img, roi_bbox = detect_roi(img)
        if roi_bbox:
            roi_area = roi_bbox[2] * roi_bbox[3]
            # A ROI covering over 30% of the frame raises the threshold by 20%.
            if roi_area > (img.shape[0] * img.shape[1] * 0.3):
                conf_threshold *= 1.2

        result, confidence = perform_ocr(roi_img)
        status = _classify_reading(result, confidence, conf_threshold * 100)
        if status == "ok":
            logging.info(f"Detected weight: {result} kg, Confidence: {confidence:.2f}%")
            return result, confidence
        if status == "out_of_range":
            logging.warning(f"Weight {result} out of range.")
        elif status == "invalid":
            logging.warning(f"Invalid weight format: {result}")

        logging.info("Primary OCR failed, using full image fallback.")
        if roi_bbox:
            result, confidence = perform_ocr(img)
        # Otherwise detect_roi returned the full image itself, so the first
        # pass already read it; reuse that reading instead of running the
        # same inference again (assumes generation is deterministic) and just
        # re-check it against the relaxed threshold.
        status = _classify_reading(result, confidence, conf_threshold * 0.9 * 100)
        if status == "ok":
            logging.info(f"Full image weight: {result} kg, Confidence: {confidence:.2f}%")
            return result, confidence
        if status == "out_of_range":
            logging.warning(f"Full image weight {result} out of range.")
        elif status == "invalid":
            logging.warning(f"Invalid full image weight format: {result}")

        logging.info("No valid weight detected.")
        return "Not detected", 0.0
    except Exception as e:
        logging.error(f"Weight extraction failed: {str(e)}")
        return "Not detected", 0.0