Sanjayraju30 committed on
Commit
9f46cc7
·
verified ·
1 Parent(s): 6ae35d6

Update ocr_engine.py

Browse files
Files changed (1) hide show
  1. ocr_engine.py +58 -104
ocr_engine.py CHANGED
@@ -1,16 +1,18 @@
1
- import easyocr
2
  import numpy as np
3
  import cv2
4
  import re
5
  import logging
6
  from datetime import datetime
7
  import os
 
 
8
 
9
  # Set up logging
10
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
11
 
12
- # Initialize EasyOCR
13
- easyocr_reader = easyocr.Reader(['en'], gpu=False)
 
14
 
15
  # Directory for debug images
16
  DEBUG_DIR = "debug_images"
@@ -20,7 +22,9 @@ def save_debug_image(img, filename_suffix, prefix=""):
20
  """Save image to debug directory with timestamp."""
21
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
22
  filename = os.path.join(DEBUG_DIR, f"{prefix}{timestamp}_{filename_suffix}.png")
23
- if len(img.shape) == 3:
 
 
24
  cv2.imwrite(filename, cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
25
  else:
26
  cv2.imwrite(filename, img)
@@ -34,30 +38,36 @@ def estimate_brightness(img):
34
  def preprocess_image(img):
35
  """Preprocess image for OCR with enhanced contrast and noise reduction."""
36
  gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
37
- # Apply Gaussian blur to reduce noise
38
- blurred = cv2.GaussianBlur(gray, (5, 5), 0)
39
- save_debug_image(blurred, "01_preprocess_blur")
40
- # Use adaptive histogram equalization for better contrast
41
- clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
42
- enhanced = clahe.apply(blurred)
43
- save_debug_image(enhanced, "02_preprocess_clahe")
44
- # Morphological operations to enhance digits
 
 
 
 
 
 
45
  kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
46
- morphed = cv2.morphologyEx(enhanced, cv2.MORPH_CLOSE, kernel)
47
- save_debug_image(morphed, "03_preprocess_morph")
48
- return morphed
 
49
 
50
  def correct_rotation(img):
51
  """Correct image rotation using edge detection."""
52
  try:
53
  gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
54
- blurred = cv2.GaussianBlur(gray, (5, 5), 0)
55
- edges = cv2.Canny(blurred, 50, 150)
56
  lines = cv2.HoughLinesP(edges, 1, np.pi / 180, threshold=50, minLineLength=30, maxLineGap=10)
57
  if lines is not None:
58
  angles = [np.arctan2(line[0][3] - line[0][1], line[0][2] - line[0][0]) * 180 / np.pi for line in lines]
59
  angle = np.median(angles)
60
- if abs(angle) > 1.5:
61
  h, w = img.shape[:2]
62
  center = (w // 2, h // 2)
63
  M = cv2.getRotationMatrix2D(center, angle, 1.0)
@@ -73,17 +83,8 @@ def detect_roi(img):
73
  """Detect region of interest (display) with refined contour filtering."""
74
  try:
75
  save_debug_image(img, "04_original")
76
- preprocessed = preprocess_image(img)
77
  brightness_map = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
78
- # Dynamic block size based on image dimensions
79
- block_size = max(11, min(31, int(img.shape[0] / 20) * 2 + 1))
80
- thresh = cv2.adaptiveThreshold(preprocessed, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
81
- cv2.THRESH_BINARY_INV, block_size, 2)
82
- save_debug_image(thresh, "05_roi_threshold")
83
- # Morphological operations to connect digit segments
84
- kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
85
- thresh = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, kernel)
86
- save_debug_image(thresh, "06_roi_morph")
87
  contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
88
 
89
  if contours:
@@ -94,101 +95,54 @@ def detect_roi(img):
94
  x, y, w, h = cv2.boundingRect(c)
95
  roi_brightness = np.mean(brightness_map[y:y+h, x:x+w])
96
  aspect_ratio = w / h
97
- # Relaxed constraints for ROI detection
98
- if (100 < area < (img_area * 0.9) and
99
- 0.3 <= aspect_ratio <= 20.0 and w > 40 and h > 15 and roi_brightness > 20):
100
  valid_contours.append((c, area * roi_brightness))
101
  logging.debug(f"Contour: Area={area}, Aspect={aspect_ratio:.2f}, Brightness={roi_brightness:.2f}")
102
 
103
  if valid_contours:
104
  contour, _ = max(valid_contours, key=lambda x: x[1])
105
  x, y, w, h = cv2.boundingRect(contour)
106
- # Dynamic padding based on ROI size
107
- padding = max(10, min(50, int(min(w, h) * 0.2)))
108
  x, y = max(0, x - padding), max(0, y - padding)
109
  w, h = min(w + 2 * padding, img.shape[1] - x), min(h + 2 * padding, img.shape[0] - y)
110
  roi_img = img[y:y+h, x:x+w]
111
- save_debug_image(roi_img, "07_detected_roi")
112
  logging.info(f"Detected ROI: ({x}, {y}, {w}, {h})")
113
  return roi_img, (x, y, w, h)
114
 
115
  logging.info("No ROI found, using full image.")
116
- save_debug_image(img, "07_no_roi_fallback")
117
  return img, None
118
  except Exception as e:
119
  logging.error(f"ROI detection failed: {str(e)}")
120
- save_debug_image(img, "07_roi_error_fallback")
121
  return img, None
122
 
123
- def perform_ocr(img, roi_bbox):
124
- """Perform OCR optimized for digital displays."""
125
  try:
126
- preprocessed = preprocess_image(img)
127
- brightness = estimate_brightness(img)
128
- # Dynamic thresholding based on brightness
129
- thresh_value = 0 if brightness < 50 else (127 if brightness < 100 else 200)
130
- _, thresh = cv2.threshold(preprocessed, thresh_value, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
131
- save_debug_image(thresh, "08_ocr_threshold")
132
- # Morphological operations to clean up digits
133
- kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
134
- thresh = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel)
135
- save_debug_image(thresh, "09_ocr_morph")
136
-
137
- # Optimized EasyOCR parameters for seven-segment displays
138
- results = easyocr_reader.readtext(thresh, detail=1, paragraph=False,
139
- contrast_ths=0.1, adjust_contrast=1.5,
140
- text_threshold=0.2, mag_ratio=3.0,
141
- allowlist='0123456789.', batch_size=1, y_ths=0.2)
142
-
143
- logging.info(f"EasyOCR results: {results}")
144
- if not results:
145
- logging.info("No text detected, trying fallback parameters.")
146
- results = easyocr_reader.readtext(thresh, detail=1, paragraph=False,
147
- contrast_ths=0.05, adjust_contrast=2.0,
148
- text_threshold=0.1, mag_ratio=4.0,
149
- allowlist='0123456789.', batch_size=1, y_ths=0.2)
150
- save_debug_image(thresh, "09_fallback_threshold")
151
-
152
- if not results:
153
- logging.info("No digits found.")
154
- return None, 0.0
155
-
156
- digits_info = []
157
- for (bbox, text, conf) in results:
158
- (x1, y1), (x2, y2), (x3, y3), (x4, y4) = bbox
159
- h_bbox = max(y1, y2, y3, y4) - min(y1, y2, y3, y4)
160
- if (text.isdigit() or text == '.') and h_bbox > 5 and conf > 0.1:
161
- x_min, x_max = int(min(x1, x4)), int(max(x2, x3))
162
- y_min, y_max = int(min(y1, y2)), int(max(y3, y4))
163
- digits_info.append((x_min, x_max, y_min, y_max, text, conf))
164
-
165
- if not digits_info:
166
- logging.info("No valid digits after filtering.")
167
- return None, 0.0
168
-
169
- digits_info.sort(key=lambda x: x[0])
170
- recognized_text = ""
171
- total_conf = 0.0
172
- conf_count = 0
173
- for _, _, _, _, char, conf in digits_info:
174
- recognized_text += char
175
- total_conf += conf
176
- conf_count += 1
177
-
178
- avg_conf = total_conf / conf_count if conf_count > 0 else 0.0
179
- logging.info(f"Recognized text: {recognized_text}, Average confidence: {avg_conf:.2f}")
180
-
181
- # Validate and clean the recognized text
182
- text = re.sub(r"[^\d\.]", "", recognized_text)
183
  if text.count('.') > 1:
184
  text = text.replace('.', '', text.count('.') - 1)
185
  text = text.strip('.')
186
  if text and re.fullmatch(r"^\d*\.?\d*$", text):
187
  text = text.lstrip('0') or '0'
188
- if text == '0' and avg_conf < 0.9:
189
- avg_conf *= 0.7
190
- return text, avg_conf * 100
191
- logging.info(f"Text '{recognized_text}' failed validation.")
192
  return None, 0.0
193
  except Exception as e:
194
  logging.error(f"OCR failed: {str(e)}")
@@ -202,13 +156,13 @@ def extract_weight_from_image(pil_img):
202
  save_debug_image(img, "00_input_image")
203
  img = correct_rotation(img)
204
  brightness = estimate_brightness(img)
205
- conf_threshold = 0.5 if brightness > 120 else (0.3 if brightness > 60 else 0.2)
206
 
207
  roi_img, roi_bbox = detect_roi(img)
208
  if roi_bbox:
209
- conf_threshold *= 1.1 if (roi_bbox[2] * roi_bbox[3]) > (img.shape[0] * img.shape[1] * 0.4) else 1.0
210
 
211
- result, confidence = perform_ocr(roi_img, roi_bbox)
212
  if result and confidence >= conf_threshold * 100:
213
  try:
214
  weight = float(result)
@@ -220,8 +174,8 @@ def extract_weight_from_image(pil_img):
220
  logging.warning(f"Invalid weight format: {result}")
221
 
222
  logging.info("Primary OCR failed, using full image fallback.")
223
- result, confidence = perform_ocr(img, None)
224
- if result and confidence >= conf_threshold * 0.8 * 100:
225
  try:
226
  weight = float(result)
227
  if 0.00001 <= weight <= 10000:
 
 
1
  import numpy as np
2
  import cv2
3
  import re
4
  import logging
5
  from datetime import datetime
6
  import os
7
+ from PIL import Image
8
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
9
 
10
  # Set up logging
11
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
12
 
13
+ # Initialize TrOCR
14
+ processor = TrOCRProcessor.from_pretrained("microsoft/trocr-small-printed")
15
+ model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-small-printed")
16
 
17
  # Directory for debug images
18
  DEBUG_DIR = "debug_images"
 
22
  """Save image to debug directory with timestamp."""
23
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
24
  filename = os.path.join(DEBUG_DIR, f"{prefix}{timestamp}_{filename_suffix}.png")
25
+ if isinstance(img, Image.Image):
26
+ img.save(filename)
27
+ elif len(img.shape) == 3:
28
  cv2.imwrite(filename, cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
29
  else:
30
  cv2.imwrite(filename, img)
 
38
def preprocess_image(img):
    """Preprocess image for OCR with enhanced contrast and noise reduction."""
    grayscale = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    # Stronger CLAHE clip limit for dim frames, gentler for bright ones.
    if estimate_brightness(img) < 100:
        clip_limit = 4.0
    else:
        clip_limit = 2.0
    equalizer = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=(8, 8))
    boosted = equalizer.apply(grayscale)
    save_debug_image(boosted, "01_preprocess_clahe")
    # Light Gaussian smoothing suppresses sensor noise before thresholding.
    smoothed = cv2.GaussianBlur(boosted, (3, 3), 0)
    save_debug_image(smoothed, "02_preprocess_blur")
    # Odd adaptive-threshold block size scaled to image height, clamped to [11, 31].
    block = int(img.shape[0] / 20) * 2 + 1
    block = max(11, min(31, block))
    binary = cv2.adaptiveThreshold(smoothed, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                   cv2.THRESH_BINARY_INV, block, 2)
    # Open then close with a 3x3 rectangle: drop specks, then bridge digit gaps.
    element = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
    binary = cv2.morphologyEx(binary, cv2.MORPH_OPEN, element, iterations=1)
    binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, element, iterations=1)
    save_debug_image(binary, "03_preprocess_morph")
    # Binary mask for contour/OCR work plus the contrast-boosted grayscale.
    return binary, boosted
60
 
61
  def correct_rotation(img):
62
  """Correct image rotation using edge detection."""
63
  try:
64
  gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
65
+ edges = cv2.Canny(gray, 50, 150)
 
66
  lines = cv2.HoughLinesP(edges, 1, np.pi / 180, threshold=50, minLineLength=30, maxLineGap=10)
67
  if lines is not None:
68
  angles = [np.arctan2(line[0][3] - line[0][1], line[0][2] - line[0][0]) * 180 / np.pi for line in lines]
69
  angle = np.median(angles)
70
+ if abs(angle) > 1.0:
71
  h, w = img.shape[:2]
72
  center = (w // 2, h // 2)
73
  M = cv2.getRotationMatrix2D(center, angle, 1.0)
 
83
  """Detect region of interest (display) with refined contour filtering."""
84
  try:
85
  save_debug_image(img, "04_original")
86
+ thresh, enhanced = preprocess_image(img)
87
  brightness_map = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
 
 
 
 
 
 
 
 
 
88
  contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
89
 
90
  if contours:
 
95
  x, y, w, h = cv2.boundingRect(c)
96
  roi_brightness = np.mean(brightness_map[y:y+h, x:x+w])
97
  aspect_ratio = w / h
98
+ # Relaxed constraints for digital displays
99
+ if (200 < area < (img_area * 0.8) and
100
+ 0.5 <= aspect_ratio <= 15.0 and w > 50 and h > 20 and roi_brightness > 30):
101
  valid_contours.append((c, area * roi_brightness))
102
  logging.debug(f"Contour: Area={area}, Aspect={aspect_ratio:.2f}, Brightness={roi_brightness:.2f}")
103
 
104
  if valid_contours:
105
  contour, _ = max(valid_contours, key=lambda x: x[1])
106
  x, y, w, h = cv2.boundingRect(contour)
107
+ padding = max(15, min(50, int(min(w, h) * 0.3)))
 
108
  x, y = max(0, x - padding), max(0, y - padding)
109
  w, h = min(w + 2 * padding, img.shape[1] - x), min(h + 2 * padding, img.shape[0] - y)
110
  roi_img = img[y:y+h, x:x+w]
111
+ save_debug_image(roi_img, "05_detected_roi")
112
  logging.info(f"Detected ROI: ({x}, {y}, {w}, {h})")
113
  return roi_img, (x, y, w, h)
114
 
115
  logging.info("No ROI found, using full image.")
116
+ save_debug_image(img, "05_no_roi_fallback")
117
  return img, None
118
  except Exception as e:
119
  logging.error(f"ROI detection failed: {str(e)}")
120
+ save_debug_image(img, "05_roi_error_fallback")
121
  return img, None
122
 
123
+ def perform_ocr(img):
124
+ """Perform OCR using TrOCR for digital displays."""
125
  try:
126
+ # Convert to PIL for TrOCR
127
+ pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
128
+ save_debug_image(pil_img, "06_ocr_input")
129
+ # Process image with TrOCR
130
+ pixel_values = processor(pil_img, return_tensors="pt").pixel_values
131
+ generated_ids = model.generate(pixel_values)
132
+ text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
133
+ logging.info(f"TrOCR raw output: {text}")
134
+
135
+ # Clean and validate text
136
+ text = re.sub(r"[^\d\.]", "", text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  if text.count('.') > 1:
138
  text = text.replace('.', '', text.count('.') - 1)
139
  text = text.strip('.')
140
  if text and re.fullmatch(r"^\d*\.?\d*$", text):
141
  text = text.lstrip('0') or '0'
142
+ confidence = 95.0 if len(text.replace('.', '')) > 1 else 90.0
143
+ logging.info(f"Validated text: {text}, Confidence: {confidence:.2f}%")
144
+ return text, confidence
145
+ logging.info(f"Text '{text}' failed validation.")
146
  return None, 0.0
147
  except Exception as e:
148
  logging.error(f"OCR failed: {str(e)}")
 
156
  save_debug_image(img, "00_input_image")
157
  img = correct_rotation(img)
158
  brightness = estimate_brightness(img)
159
+ conf_threshold = 0.6 if brightness > 100 else 0.4
160
 
161
  roi_img, roi_bbox = detect_roi(img)
162
  if roi_bbox:
163
+ conf_threshold *= 1.2 if (roi_bbox[2] * roi_bbox[3]) > (img.shape[0] * img.shape[1] * 0.3) else 1.0
164
 
165
+ result, confidence = perform_ocr(roi_img)
166
  if result and confidence >= conf_threshold * 100:
167
  try:
168
  weight = float(result)
 
174
  logging.warning(f"Invalid weight format: {result}")
175
 
176
  logging.info("Primary OCR failed, using full image fallback.")
177
+ result, confidence = perform_ocr(img)
178
+ if result and confidence >= conf_threshold * 0.9 * 100:
179
  try:
180
  weight = float(result)
181
  if 0.00001 <= weight <= 10000: