Sanjayraju30 commited on
Commit
ded0d50
·
verified ·
1 Parent(s): 013fbf8

Update ocr_engine.py

Browse files
Files changed (1) hide show
  1. ocr_engine.py +177 -69
ocr_engine.py CHANGED
@@ -1,24 +1,21 @@
 
1
  import numpy as np
2
  import cv2
3
  import re
4
  import logging
5
  from datetime import datetime
6
  import os
7
- from PIL import Image
8
- from transformers import TrOCRProcessor, VisionEncoderDecoderModel
9
 
10
  # Set up logging
11
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
12
 
13
- # Initialize TrOCR with error handling
14
  try:
15
- processor = TrOCRProcessor.from_pretrained("microsoft/trocr-small-printed")
16
- model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-small-printed")
17
- logging.info("TrOCR model and processor loaded successfully")
18
  except Exception as e:
19
- logging.error(f"Failed to load TrOCR model: {str(e)}")
20
- processor = None
21
- model = None
22
 
23
  # Directory for debug images
24
  DEBUG_DIR = "debug_images"
@@ -28,9 +25,7 @@ def save_debug_image(img, filename_suffix, prefix=""):
28
  """Save image to debug directory with timestamp."""
29
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
30
  filename = os.path.join(DEBUG_DIR, f"{prefix}{timestamp}_{filename_suffix}.png")
31
- if isinstance(img, Image.Image):
32
- img.save(filename)
33
- elif len(img.shape) == 3:
34
  cv2.imwrite(filename, cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
35
  else:
36
  cv2.imwrite(filename, img)
@@ -44,23 +39,23 @@ def estimate_brightness(img):
44
  def preprocess_image(img):
45
  """Preprocess image for OCR with enhanced contrast and noise reduction."""
46
  gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
47
- # Dynamic contrast adjustment based on brightness
48
  brightness = estimate_brightness(img)
49
- clahe_clip = 4.0 if brightness < 100 else 2.0
 
50
  clahe = cv2.createCLAHE(clipLimit=clahe_clip, tileGridSize=(8, 8))
51
  enhanced = clahe.apply(gray)
52
  save_debug_image(enhanced, "01_preprocess_clahe")
53
  # Gaussian blur to reduce noise
54
  blurred = cv2.GaussianBlur(enhanced, (3, 3), 0)
55
  save_debug_image(blurred, "02_preprocess_blur")
56
- # Adaptive thresholding for digit segmentation
57
- block_size = max(11, min(31, int(img.shape[0] / 20) * 2 + 1))
58
  thresh = cv2.adaptiveThreshold(blurred, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
59
- cv2.THRESH_BINARY_INV, block_size, 2)
60
- # Morphological operations to clean up digits
61
  kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
62
  thresh = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel, iterations=1)
63
- thresh = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, kernel, iterations=1)
64
  save_debug_image(thresh, "03_preprocess_morph")
65
  return thresh, enhanced
66
 
@@ -68,7 +63,7 @@ def correct_rotation(img):
68
  """Correct image rotation using edge detection."""
69
  try:
70
  gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
71
- edges = cv2.Canny(gray, 50, 150)
72
  lines = cv2.HoughLinesP(edges, 1, np.pi / 180, threshold=50, minLineLength=30, maxLineGap=10)
73
  if lines is not None:
74
  angles = [np.arctan2(line[0][3] - line[0][1], line[0][2] - line[0][0]) * 180 / np.pi for line in lines]
@@ -86,72 +81,185 @@ def correct_rotation(img):
86
  return img
87
 
88
  def detect_roi(img):
89
- """Detect region of interest (display) with refined contour filtering."""
90
  try:
91
  save_debug_image(img, "04_original")
92
  thresh, enhanced = preprocess_image(img)
93
  brightness_map = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
94
- contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
 
 
 
95
 
96
- if contours:
97
- img_area = img.shape[0] * img.shape[1]
98
- valid_contours = []
 
 
 
 
 
99
  for c in contours:
100
  area = cv2.contourArea(c)
101
  x, y, w, h = cv2.boundingRect(c)
102
  roi_brightness = np.mean(brightness_map[y:y+h, x:x+w])
103
  aspect_ratio = w / h
104
- # Relaxed constraints for digital displays
105
- if (200 < area < (img_area * 0.8) and
106
- 0.5 <= aspect_ratio <= 15.0 and w > 50 and h > 20 and roi_brightness > 30):
107
  valid_contours.append((c, area * roi_brightness))
108
- logging.debug(f"Contour: Area={area}, Aspect={aspect_ratio:.2f}, Brightness={roi_brightness:.2f}")
109
-
110
- if valid_contours:
111
- contour, _ = max(valid_contours, key=lambda x: x[1])
112
- x, y, w, h = cv2.boundingRect(contour)
113
- padding = max(15, min(50, int(min(w, h) * 0.3)))
114
- x, y = max(0, x - padding), max(0, y - padding)
115
- w, h = min(w + 2 * padding, img.shape[1] - x), min(h + 2 * padding, img.shape[0] - y)
116
- roi_img = img[y:y+h, x:x+w]
117
- save_debug_image(roi_img, "05_detected_roi")
118
- logging.info(f"Detected ROI: ({x}, {y}, {w}, {h})")
119
- return roi_img, (x, y, w, h)
120
 
121
  logging.info("No ROI found, using full image.")
122
- save_debug_image(img, "05_no_roi_fallback")
123
  return img, None
124
  except Exception as e:
125
  logging.error(f"ROI detection failed: {str(e)}")
126
- save_debug_image(img, "05_roi_error_fallback")
127
  return img, None
128
 
129
- def perform_ocr(img):
130
- """Perform OCR using TrOCR for digital displays."""
131
- if processor is None or model is None:
132
- logging.error("TrOCR model not loaded, cannot perform OCR.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  return None, 0.0
134
  try:
135
- # Convert to PIL for TrOCR
136
- pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
137
- save_debug_image(pil_img, "06_ocr_input")
138
- # Process image with TrOCR
139
- pixel_values = processor(pil_img, return_tensors="pt").pixel_values
140
- generated_ids = model.generate(pixel_values, max_length=10)
141
- text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
142
- logging.info(f"TrOCR raw output: {text}")
143
-
144
- # Clean and validate text
145
- text = re.sub(r"[^\d\.]", "", text)
146
- if text.count('.') > 1:
147
- text = text.replace('.', '', text.count('.') - 1)
148
- text = text.strip('.')
149
- if text and re.fullmatch(r"^\d*\.?\d*$", text):
150
- text = text.lstrip('0') or '0'
151
- confidence = 95.0 if len(text.replace('.', '')) >= 2 else 85.0
152
- logging.info(f"Validated text: {text}, Confidence: {confidence:.2f}%")
153
- return text, confidence
154
- logging.info(f"Text '{text}' failed validation.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  return None, 0.0
156
  except Exception as e:
157
  logging.error(f"OCR failed: {str(e)}")
@@ -171,11 +279,11 @@ def extract_weight_from_image(pil_img):
171
  if roi_bbox:
172
  conf_threshold *= 1.1 if (roi_bbox[2] * roi_bbox[3]) > (img.shape[0] * img.shape[1] * 0.3) else 1.0
173
 
174
- result, confidence = perform_ocr(roi_img)
175
  if result and confidence >= conf_threshold * 100:
176
  try:
177
  weight = float(result)
178
- if 0.01 <= weight <= 1000: # Narrowed range for typical scale weights
179
  logging.info(f"Detected weight: {result} kg, Confidence: {confidence:.2f}%")
180
  return result, confidence
181
  logging.warning(f"Weight {result} out of range.")
@@ -183,7 +291,7 @@ def extract_weight_from_image(pil_img):
183
  logging.warning(f"Invalid weight format: {result}")
184
 
185
  logging.info("Primary OCR failed, using full image fallback.")
186
- result, confidence = perform_ocr(img)
187
  if result and confidence >= conf_threshold * 0.9 * 100:
188
  try:
189
  weight = float(result)
 
1
+ import easyocr
2
  import numpy as np
3
  import cv2
4
  import re
5
  import logging
6
  from datetime import datetime
7
  import os
 
 
8
 
9
  # Set up logging
10
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
11
 
12
+ # Initialize EasyOCR
13
  try:
14
+ easyocr_reader = easyocr.Reader(['en'], gpu=False)
15
+ logging.info("EasyOCR initialized successfully")
 
16
  except Exception as e:
17
+ logging.error(f"Failed to initialize EasyOCR: {str(e)}")
18
+ easyocr_reader = None
 
19
 
20
  # Directory for debug images
21
  DEBUG_DIR = "debug_images"
 
25
  """Save image to debug directory with timestamp."""
26
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
27
  filename = os.path.join(DEBUG_DIR, f"{prefix}{timestamp}_{filename_suffix}.png")
28
+ if len(img.shape) == 3:
 
 
29
  cv2.imwrite(filename, cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
30
  else:
31
  cv2.imwrite(filename, img)
 
39
  def preprocess_image(img):
40
  """Preprocess image for OCR with enhanced contrast and noise reduction."""
41
  gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
 
42
  brightness = estimate_brightness(img)
43
+ # Dynamic CLAHE based on brightness
44
+ clahe_clip = 4.0 if brightness < 80 else 2.0
45
  clahe = cv2.createCLAHE(clipLimit=clahe_clip, tileGridSize=(8, 8))
46
  enhanced = clahe.apply(gray)
47
  save_debug_image(enhanced, "01_preprocess_clahe")
48
  # Gaussian blur to reduce noise
49
  blurred = cv2.GaussianBlur(enhanced, (3, 3), 0)
50
  save_debug_image(blurred, "02_preprocess_blur")
51
+ # Adaptive thresholding with dynamic block size
52
+ block_size = max(11, min(31, int(img.shape[0] / 15) * 2 + 1))
53
  thresh = cv2.adaptiveThreshold(blurred, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
54
+ cv2.THRESH_BINARY_INV, block_size, 5)
55
+ # Morphological operations to enhance digits
56
  kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
57
  thresh = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel, iterations=1)
58
+ thresh = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, kernel, iterations=2)
59
  save_debug_image(thresh, "03_preprocess_morph")
60
  return thresh, enhanced
61
 
 
63
  """Correct image rotation using edge detection."""
64
  try:
65
  gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
66
+ edges = cv2.Canny(gray, 50, 150, apertureSize=3)
67
  lines = cv2.HoughLinesP(edges, 1, np.pi / 180, threshold=50, minLineLength=30, maxLineGap=10)
68
  if lines is not None:
69
  angles = [np.arctan2(line[0][3] - line[0][1], line[0][2] - line[0][0]) * 180 / np.pi for line in lines]
 
81
  return img
82
 
83
  def detect_roi(img):
84
+ """Detect region of interest (display) with multi-scale contour filtering."""
85
  try:
86
  save_debug_image(img, "04_original")
87
  thresh, enhanced = preprocess_image(img)
88
  brightness_map = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
89
+ # Try multiple block sizes for robust ROI detection
90
+ block_sizes = [max(11, min(31, int(img.shape[0] / s) * 2 + 1)) for s in [15, 20, 25]]
91
+ valid_contours = []
92
+ img_area = img.shape[0] * img.shape[1]
93
 
94
+ for block_size in block_sizes:
95
+ temp_thresh = cv2.adaptiveThreshold(enhanced, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
96
+ cv2.THRESH_BINARY_INV, block_size, 5)
97
+ kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
98
+ temp_thresh = cv2.morphologyEx(temp_thresh, cv2.MORPH_CLOSE, kernel, iterations=2)
99
+ save_debug_image(temp_thresh, f"05_roi_threshold_block{block_size}")
100
+ contours, _ = cv2.findContours(temp_thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
101
+
102
  for c in contours:
103
  area = cv2.contourArea(c)
104
  x, y, w, h = cv2.boundingRect(c)
105
  roi_brightness = np.mean(brightness_map[y:y+h, x:x+w])
106
  aspect_ratio = w / h
107
+ if (300 < area < (img_area * 0.7) and
108
+ 0.5 <= aspect_ratio <= 10.0 and w > 60 and h > 25 and roi_brightness > 40):
 
109
  valid_contours.append((c, area * roi_brightness))
110
+ logging.debug(f"Contour (block {block_size}): Area={area}, Aspect={aspect_ratio:.2f}, Brightness={roi_brightness:.2f}")
111
+
112
+ if valid_contours:
113
+ contour, _ = max(valid_contours, key=lambda x: x[1])
114
+ x, y, w, h = cv2.boundingRect(contour)
115
+ padding = max(20, min(60, int(min(w, h) * 0.3)))
116
+ x, y = max(0, x - padding), max(0, y - padding)
117
+ w, h = min(w + 2 * padding, img.shape[1] - x), min(h + 2 * padding, img.shape[0] - y)
118
+ roi_img = img[y:y+h, x:x+w]
119
+ save_debug_image(roi_img, "06_detected_roi")
120
+ logging.info(f"Detected ROI: ({x}, {y}, {w}, {h})")
121
+ return roi_img, (x, y, w, h)
122
 
123
  logging.info("No ROI found, using full image.")
124
+ save_debug_image(img, "06_no_roi_fallback")
125
  return img, None
126
  except Exception as e:
127
  logging.error(f"ROI detection failed: {str(e)}")
128
+ save_debug_image(img, "06_roi_error_fallback")
129
  return img, None
130
 
131
+ def detect_segments(digit_img, brightness):
132
+ """Detect seven-segment digits with adaptive thresholds."""
133
+ try:
134
+ h, w = digit_img.shape
135
+ if h < 10 or w < 5:
136
+ logging.debug("Digit image too small for segment detection.")
137
+ return None
138
+
139
+ # Dynamic segment threshold based on brightness
140
+ segment_threshold = 0.2 if brightness < 80 else 0.3
141
+ segments = {
142
+ 'top': (int(w*0.1), int(w*0.9), 0, int(h*0.25)),
143
+ 'middle': (int(w*0.1), int(w*0.9), int(h*0.45), int(h*0.55)),
144
+ 'bottom': (int(w*0.1), int(w*0.9), int(h*0.75), h),
145
+ 'left_top': (0, int(w*0.3), int(h*0.1), int(h*0.5)),
146
+ 'left_bottom': (0, int(w*0.3), int(h*0.5), int(h*0.9)),
147
+ 'right_top': (int(w*0.7), w, int(h*0.1), int(h*0.5)),
148
+ 'right_bottom': (int(w*0.7), w, int(h*0.5), int(h*0.9))
149
+ }
150
+
151
+ segment_presence = {}
152
+ for name, (x1, x2, y1, y2) in segments.items():
153
+ x1, y1 = max(0, x1), max(0, y1)
154
+ x2, y2 = min(w, x2), min(h, y2)
155
+ region = digit_img[y1:y2, x1:x2]
156
+ if region.size == 0:
157
+ segment_presence[name] = False
158
+ continue
159
+ pixel_count = np.sum(region == 255)
160
+ total_pixels = region.size
161
+ segment_presence[name] = pixel_count / total_pixels > segment_threshold
162
+ logging.debug(f"Segment {name}: {pixel_count}/{total_pixels} = {pixel_count/total_pixels:.2f}")
163
+
164
+ digit_patterns = {
165
+ '0': ('top', 'bottom', 'left_top', 'left_bottom', 'right_top', 'right_bottom'),
166
+ '1': ('right_top', 'right_bottom'),
167
+ '2': ('top', 'middle', 'bottom', 'left_bottom', 'right_top'),
168
+ '3': ('top', 'middle', 'bottom', 'right_top', 'right_bottom'),
169
+ '4': ('middle', 'left_top', 'right_top', 'right_bottom'),
170
+ '5': ('top', 'middle', 'bottom', 'left_top', 'right_bottom'),
171
+ '6': ('top', 'middle', 'bottom', 'left_top', 'left_bottom', 'right_bottom'),
172
+ '7': ('top', 'right_top', 'right_bottom'),
173
+ '8': ('top', 'middle', 'bottom', 'left_top', 'left_bottom', 'right_top', 'right_bottom'),
174
+ '9': ('top', 'middle', 'bottom', 'left_top', 'right_top', 'right_bottom')
175
+ }
176
+
177
+ best_match, best_score = None, -1
178
+ for digit, pattern in digit_patterns.items():
179
+ matches = sum(1 for segment in pattern if segment_presence.get(segment, False))
180
+ non_matches = sum(1 for segment in segment_presence if segment not in pattern and segment_presence[segment])
181
+ score = matches - 0.2 * non_matches
182
+ if matches >= len(pattern) * 0.6:
183
+ score += 1.0
184
+ if score > best_score:
185
+ best_score = score
186
+ best_match = digit
187
+ logging.debug(f"Segment detection: {segment_presence}, Digit: {best_match}, Score: {best_score:.2f}")
188
+ return best_match
189
+ except Exception as e:
190
+ logging.error(f"Segment detection failed: {str(e)}")
191
+ return None
192
+
193
+ def perform_ocr(img, roi_bbox):
194
+ """Perform OCR with EasyOCR and seven-segment fallback."""
195
+ if easyocr_reader is None:
196
+ logging.error("EasyOCR not initialized, cannot perform OCR.")
197
  return None, 0.0
198
  try:
199
+ thresh, enhanced = preprocess_image(img)
200
+ brightness = estimate_brightness(img)
201
+ # Dynamic EasyOCR parameters
202
+ results = easyocr_reader.readtext(thresh, detail=1, paragraph=False,
203
+ contrast_ths=0.1, adjust_contrast=1.5,
204
+ text_threshold=0.3, mag_ratio=3.0,
205
+ allowlist='0123456789.', batch_size=1, y_ths=0.2)
206
+ save_debug_image(thresh, "07_ocr_threshold")
207
+ logging.info(f"EasyOCR results: {results}")
208
+
209
+ if not results:
210
+ logging.info("EasyOCR failed, trying fallback parameters.")
211
+ results = easyocr_reader.readtext(thresh, detail=1, paragraph=False,
212
+ contrast_ths=0.05, adjust_contrast=2.0,
213
+ text_threshold=0.2, mag_ratio=4.0,
214
+ allowlist='0123456789.', batch_size=1, y_ths=0.2)
215
+ save_debug_image(thresh, "07_fallback_threshold")
216
+
217
+ digits_info = []
218
+ for (bbox, text, conf) in results:
219
+ (x1, y1), (x2, y2), (x3, y3), (x4, y4) = bbox
220
+ h_bbox = max(y1, y2, y3, y4) - min(y1, y2, y3, y4)
221
+ if (text.isdigit() or text == '.') and h_bbox > 10 and conf > 0.2:
222
+ x_min, x_max = int(min(x1, x4)), int(max(x2, x3))
223
+ y_min, y_max = int(min(y1, y2)), int(max(y3, y4))
224
+ digits_info.append((x_min, x_max, y_min, y_max, text, conf))
225
+
226
+ if digits_info:
227
+ digits_info.sort(key=lambda x: x[0])
228
+ recognized_text = ""
229
+ total_conf = 0.0
230
+ conf_count = 0
231
+ for idx, (x_min, x_max, y_min, y_max, char, conf) in enumerate(digits_info):
232
+ x_min, y_min = max(0, x_min), max(0, y_min)
233
+ x_max, y_max = min(thresh.shape[1], x_max), min(thresh.shape[0], y_max)
234
+ if x_max <= x_min or y_max <= y_min:
235
+ continue
236
+ if conf < 0.7 and char != '.':
237
+ digit_crop = thresh[y_min:y_max, x_min:x_max]
238
+ save_debug_image(digit_crop, f"08_digit_crop_{idx}_{char}")
239
+ segment_digit = detect_segments(digit_crop, brightness)
240
+ if segment_digit:
241
+ recognized_text += segment_digit
242
+ total_conf += 0.85
243
+ logging.debug(f"Used segment detection for char {char}: {segment_digit}")
244
+ else:
245
+ recognized_text += char
246
+ total_conf += conf
247
+ conf_count += 1
248
+ else:
249
+ recognized_text += char
250
+ total_conf += conf
251
+ conf_count += 1
252
+
253
+ avg_conf = total_conf / conf_count if conf_count > 0 else 0.0
254
+ text = re.sub(r"[^\d\.]", "", recognized_text)
255
+ if text.count('.') > 1:
256
+ text = text.replace('.', '', text.count('.') - 1)
257
+ text = text.strip('.')
258
+ if text and re.fullmatch(r"^\d*\.?\d*$", text):
259
+ text = text.lstrip('0') or '0'
260
+ logging.info(f"Validated text: {text}, Confidence: {avg_conf:.2f}")
261
+ return text, avg_conf * 100
262
+ logging.info("No valid digits detected.")
263
  return None, 0.0
264
  except Exception as e:
265
  logging.error(f"OCR failed: {str(e)}")
 
279
  if roi_bbox:
280
  conf_threshold *= 1.1 if (roi_bbox[2] * roi_bbox[3]) > (img.shape[0] * img.shape[1] * 0.3) else 1.0
281
 
282
+ result, confidence = perform_ocr(roi_img, roi_bbox)
283
  if result and confidence >= conf_threshold * 100:
284
  try:
285
  weight = float(result)
286
+ if 0.01 <= weight <= 1000:
287
  logging.info(f"Detected weight: {result} kg, Confidence: {confidence:.2f}%")
288
  return result, confidence
289
  logging.warning(f"Weight {result} out of range.")
 
291
  logging.warning(f"Invalid weight format: {result}")
292
 
293
  logging.info("Primary OCR failed, using full image fallback.")
294
+ result, confidence = perform_ocr(img, None)
295
  if result and confidence >= conf_threshold * 0.9 * 100:
296
  try:
297
  weight = float(result)