Sanjayraju30 commited on
Commit
e790db4
·
verified ·
1 Parent(s): 3bd13bb

Update ocr_engine.py

Browse files
Files changed (1) hide show
  1. ocr_engine.py +18 -9
ocr_engine.py CHANGED
@@ -10,9 +10,15 @@ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
10
  # Set up logging
11
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
12
 
13
- # Initialize TrOCR
14
- processor = TrOCRProcessor.from_pretrained("microsoft/trocr-small-printed")
15
- model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-small-printed")
 
 
 
 
 
 
16
 
17
  # Directory for debug images
18
  DEBUG_DIR = "debug_images"
@@ -122,13 +128,16 @@ def detect_roi(img):
122
 
123
  def perform_ocr(img):
124
  """Perform OCR using TrOCR for digital displays."""
 
 
 
125
  try:
126
  # Convert to PIL for TrOCR
127
  pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
128
  save_debug_image(pil_img, "06_ocr_input")
129
  # Process image with TrOCR
130
  pixel_values = processor(pil_img, return_tensors="pt").pixel_values
131
- generated_ids = model.generate(pixel_values)
132
  text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
133
  logging.info(f"TrOCR raw output: {text}")
134
 
@@ -139,7 +148,7 @@ def perform_ocr(img):
139
  text = text.strip('.')
140
  if text and re.fullmatch(r"^\d*\.?\d*$", text):
141
  text = text.lstrip('0') or '0'
142
- confidence = 95.0 if len(text.replace('.', '')) > 1 else 90.0
143
  logging.info(f"Validated text: {text}, Confidence: {confidence:.2f}%")
144
  return text, confidence
145
  logging.info(f"Text '{text}' failed validation.")
@@ -156,17 +165,17 @@ def extract_weight_from_image(pil_img):
156
  save_debug_image(img, "00_input_image")
157
  img = correct_rotation(img)
158
  brightness = estimate_brightness(img)
159
- conf_threshold = 0.6 if brightness > 100 else 0.4
160
 
161
  roi_img, roi_bbox = detect_roi(img)
162
  if roi_bbox:
163
- conf_threshold *= 1.2 if (roi_bbox[2] * roi_bbox[3]) > (img.shape[0] * img.shape[1] * 0.3) else 1.0
164
 
165
  result, confidence = perform_ocr(roi_img)
166
  if result and confidence >= conf_threshold * 100:
167
  try:
168
  weight = float(result)
169
- if 0.00001 <= weight <= 10000:
170
  logging.info(f"Detected weight: {result} kg, Confidence: {confidence:.2f}%")
171
  return result, confidence
172
  logging.warning(f"Weight {result} out of range.")
@@ -178,7 +187,7 @@ def extract_weight_from_image(pil_img):
178
  if result and confidence >= conf_threshold * 0.9 * 100:
179
  try:
180
  weight = float(result)
181
- if 0.00001 <= weight <= 10000:
182
  logging.info(f"Full image weight: {result} kg, Confidence: {confidence:.2f}%")
183
  return result, confidence
184
  logging.warning(f"Full image weight {result} out of range.")
 
10
  # Set up logging
11
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
12
 
13
+ # Initialize TrOCR with error handling
14
+ try:
15
+ processor = TrOCRProcessor.from_pretrained("microsoft/trocr-small-printed")
16
+ model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-small-printed")
17
+ logging.info("TrOCR model and processor loaded successfully")
18
+ except Exception as e:
19
+ logging.error(f"Failed to load TrOCR model: {str(e)}")
20
+ processor = None
21
+ model = None
22
 
23
  # Directory for debug images
24
  DEBUG_DIR = "debug_images"
 
128
 
129
  def perform_ocr(img):
130
  """Perform OCR using TrOCR for digital displays."""
131
+ if processor is None or model is None:
132
+ logging.error("TrOCR model not loaded, cannot perform OCR.")
133
+ return None, 0.0
134
  try:
135
  # Convert to PIL for TrOCR
136
  pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
137
  save_debug_image(pil_img, "06_ocr_input")
138
  # Process image with TrOCR
139
  pixel_values = processor(pil_img, return_tensors="pt").pixel_values
140
+ generated_ids = model.generate(pixel_values, max_length=10)
141
  text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
142
  logging.info(f"TrOCR raw output: {text}")
143
 
 
148
  text = text.strip('.')
149
  if text and re.fullmatch(r"^\d*\.?\d*$", text):
150
  text = text.lstrip('0') or '0'
151
+ confidence = 95.0 if len(text.replace('.', '')) >= 2 else 85.0
152
  logging.info(f"Validated text: {text}, Confidence: {confidence:.2f}%")
153
  return text, confidence
154
  logging.info(f"Text '{text}' failed validation.")
 
165
  save_debug_image(img, "00_input_image")
166
  img = correct_rotation(img)
167
  brightness = estimate_brightness(img)
168
+ conf_threshold = 0.7 if brightness > 100 else 0.5
169
 
170
  roi_img, roi_bbox = detect_roi(img)
171
  if roi_bbox:
172
+ conf_threshold *= 1.1 if (roi_bbox[2] * roi_bbox[3]) > (img.shape[0] * img.shape[1] * 0.3) else 1.0
173
 
174
  result, confidence = perform_ocr(roi_img)
175
  if result and confidence >= conf_threshold * 100:
176
  try:
177
  weight = float(result)
178
+ if 0.01 <= weight <= 1000: # Narrowed range for typical scale weights
179
  logging.info(f"Detected weight: {result} kg, Confidence: {confidence:.2f}%")
180
  return result, confidence
181
  logging.warning(f"Weight {result} out of range.")
 
187
  if result and confidence >= conf_threshold * 0.9 * 100:
188
  try:
189
  weight = float(result)
190
+ if 0.01 <= weight <= 1000:
191
  logging.info(f"Full image weight: {result} kg, Confidence: {confidence:.2f}%")
192
  return result, confidence
193
  logging.warning(f"Full image weight {result} out of range.")