Sanjayraju30 committed on
Commit
a4b646d
·
verified ·
1 Parent(s): 44745c5

Update ocr_engine.py

Browse files
Files changed (1) hide show
  1. ocr_engine.py +8 -10
ocr_engine.py CHANGED
@@ -1,9 +1,9 @@
1
  from transformers import DonutProcessor, VisionEncoderDecoderModel
2
  from PIL import Image
3
- import torch
4
  import re
 
5
 
6
- # Load model
7
  processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
8
  model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
9
 
@@ -11,21 +11,19 @@ def extract_weight(image: Image.Image) -> str:
11
  image = image.convert("RGB")
12
  pixel_values = processor(image, return_tensors="pt").pixel_values
13
 
14
- # Generate output
15
  outputs = model.generate(pixel_values, max_length=512)
16
  decoded = processor.batch_decode(outputs, skip_special_tokens=True)[0]
17
 
18
- print("OCR Output:", decoded) # Optional for debug
19
-
20
- # Extract number
21
- match = re.search(r"(\d+(\.\d+)?)", decoded)
22
  weight = match.group(1) if match else None
23
 
24
  # Detect unit
25
- text_lower = decoded.lower().replace(" ", "")
26
- if any(u in text_lower for u in ["kg", "kgs", "kilogram", "kilo"]):
27
  unit = "kg"
28
- elif any(u in text_lower for u in ["g", "gram", "grams"]):
29
  unit = "grams"
30
  else:
31
  unit = "kg" if weight and float(weight) >= 5 else "grams"
 
1
  from transformers import DonutProcessor, VisionEncoderDecoderModel
2
  from PIL import Image
 
3
  import re
4
+ import torch
5
 
6
+ # Load processor + model
7
  processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
8
  model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
9
 
 
11
  image = image.convert("RGB")
12
  pixel_values = processor(image, return_tensors="pt").pixel_values
13
 
14
+ # Generate text prediction
15
  outputs = model.generate(pixel_values, max_length=512)
16
  decoded = processor.batch_decode(outputs, skip_special_tokens=True)[0]
17
 
18
+ # Clean & extract weight
19
+ cleaned = decoded.lower().replace(" ", "")
20
+ match = re.search(r"(\d+(\.\d+)?)", cleaned)
 
21
  weight = match.group(1) if match else None
22
 
23
  # Detect unit
24
+ if any(u in cleaned for u in ["kg", "kgs", "kilogram", "kilo"]):
 
25
  unit = "kg"
26
+ elif any(u in cleaned for u in ["g", "gram", "grams"]):
27
  unit = "grams"
28
  else:
29
  unit = "kg" if weight and float(weight) >= 5 else "grams"