Spaces:
Running
Running
Update ocr_engine.py
Browse files- ocr_engine.py +18 -9
ocr_engine.py
CHANGED
@@ -10,9 +10,15 @@ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
|
10 |
# Set up logging
|
11 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
12 |
|
13 |
-
# Initialize TrOCR
|
14 |
-
|
15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
# Directory for debug images
|
18 |
DEBUG_DIR = "debug_images"
|
@@ -122,13 +128,16 @@ def detect_roi(img):
|
|
122 |
|
123 |
def perform_ocr(img):
|
124 |
"""Perform OCR using TrOCR for digital displays."""
|
|
|
|
|
|
|
125 |
try:
|
126 |
# Convert to PIL for TrOCR
|
127 |
pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
|
128 |
save_debug_image(pil_img, "06_ocr_input")
|
129 |
# Process image with TrOCR
|
130 |
pixel_values = processor(pil_img, return_tensors="pt").pixel_values
|
131 |
-
generated_ids = model.generate(pixel_values)
|
132 |
text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
133 |
logging.info(f"TrOCR raw output: {text}")
|
134 |
|
@@ -139,7 +148,7 @@ def perform_ocr(img):
|
|
139 |
text = text.strip('.')
|
140 |
if text and re.fullmatch(r"^\d*\.?\d*$", text):
|
141 |
text = text.lstrip('0') or '0'
|
142 |
-
confidence = 95.0 if len(text.replace('.', ''))
|
143 |
logging.info(f"Validated text: {text}, Confidence: {confidence:.2f}%")
|
144 |
return text, confidence
|
145 |
logging.info(f"Text '{text}' failed validation.")
|
@@ -156,17 +165,17 @@ def extract_weight_from_image(pil_img):
|
|
156 |
save_debug_image(img, "00_input_image")
|
157 |
img = correct_rotation(img)
|
158 |
brightness = estimate_brightness(img)
|
159 |
-
conf_threshold = 0.
|
160 |
|
161 |
roi_img, roi_bbox = detect_roi(img)
|
162 |
if roi_bbox:
|
163 |
-
conf_threshold *= 1.
|
164 |
|
165 |
result, confidence = perform_ocr(roi_img)
|
166 |
if result and confidence >= conf_threshold * 100:
|
167 |
try:
|
168 |
weight = float(result)
|
169 |
-
if 0.
|
170 |
logging.info(f"Detected weight: {result} kg, Confidence: {confidence:.2f}%")
|
171 |
return result, confidence
|
172 |
logging.warning(f"Weight {result} out of range.")
|
@@ -178,7 +187,7 @@ def extract_weight_from_image(pil_img):
|
|
178 |
if result and confidence >= conf_threshold * 0.9 * 100:
|
179 |
try:
|
180 |
weight = float(result)
|
181 |
-
if 0.
|
182 |
logging.info(f"Full image weight: {result} kg, Confidence: {confidence:.2f}%")
|
183 |
return result, confidence
|
184 |
logging.warning(f"Full image weight {result} out of range.")
|
|
|
10 |
# Set up logging
|
11 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
12 |
|
13 |
+
# Initialize TrOCR with error handling
|
14 |
+
try:
|
15 |
+
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-small-printed")
|
16 |
+
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-small-printed")
|
17 |
+
logging.info("TrOCR model and processor loaded successfully")
|
18 |
+
except Exception as e:
|
19 |
+
logging.error(f"Failed to load TrOCR model: {str(e)}")
|
20 |
+
processor = None
|
21 |
+
model = None
|
22 |
|
23 |
# Directory for debug images
|
24 |
DEBUG_DIR = "debug_images"
|
|
|
128 |
|
129 |
def perform_ocr(img):
|
130 |
"""Perform OCR using TrOCR for digital displays."""
|
131 |
+
if processor is None or model is None:
|
132 |
+
logging.error("TrOCR model not loaded, cannot perform OCR.")
|
133 |
+
return None, 0.0
|
134 |
try:
|
135 |
# Convert to PIL for TrOCR
|
136 |
pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
|
137 |
save_debug_image(pil_img, "06_ocr_input")
|
138 |
# Process image with TrOCR
|
139 |
pixel_values = processor(pil_img, return_tensors="pt").pixel_values
|
140 |
+
generated_ids = model.generate(pixel_values, max_length=10)
|
141 |
text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
142 |
logging.info(f"TrOCR raw output: {text}")
|
143 |
|
|
|
148 |
text = text.strip('.')
|
149 |
if text and re.fullmatch(r"^\d*\.?\d*$", text):
|
150 |
text = text.lstrip('0') or '0'
|
151 |
+
confidence = 95.0 if len(text.replace('.', '')) >= 2 else 85.0
|
152 |
logging.info(f"Validated text: {text}, Confidence: {confidence:.2f}%")
|
153 |
return text, confidence
|
154 |
logging.info(f"Text '{text}' failed validation.")
|
|
|
165 |
save_debug_image(img, "00_input_image")
|
166 |
img = correct_rotation(img)
|
167 |
brightness = estimate_brightness(img)
|
168 |
+
conf_threshold = 0.7 if brightness > 100 else 0.5
|
169 |
|
170 |
roi_img, roi_bbox = detect_roi(img)
|
171 |
if roi_bbox:
|
172 |
+
conf_threshold *= 1.1 if (roi_bbox[2] * roi_bbox[3]) > (img.shape[0] * img.shape[1] * 0.3) else 1.0
|
173 |
|
174 |
result, confidence = perform_ocr(roi_img)
|
175 |
if result and confidence >= conf_threshold * 100:
|
176 |
try:
|
177 |
weight = float(result)
|
178 |
+
if 0.01 <= weight <= 1000: # Narrowed range for typical scale weights
|
179 |
logging.info(f"Detected weight: {result} kg, Confidence: {confidence:.2f}%")
|
180 |
return result, confidence
|
181 |
logging.warning(f"Weight {result} out of range.")
|
|
|
187 |
if result and confidence >= conf_threshold * 0.9 * 100:
|
188 |
try:
|
189 |
weight = float(result)
|
190 |
+
if 0.01 <= weight <= 1000:
|
191 |
logging.info(f"Full image weight: {result} kg, Confidence: {confidence:.2f}%")
|
192 |
return result, confidence
|
193 |
logging.warning(f"Full image weight {result} out of range.")
|