Sanjayraju30 commited on
Commit
dd1ae7e
·
verified ·
1 Parent(s): 5217dbe

Update ocr_engine.py

Browse files
Files changed (1) hide show
  1. ocr_engine.py +14 -11
ocr_engine.py CHANGED
@@ -1,25 +1,28 @@
 
1
  from PIL import Image
2
- from transformers import AutoProcessor, VisionEncoderDecoderModel
3
  import re
4
 
5
- # Load model fine-tuned for 7-segment displays
6
- processor = AutoProcessor.from_pretrained("roboflow/ocr-7segment")
7
- model = VisionEncoderDecoderModel.from_pretrained("roboflow/ocr-7segment")
8
 
9
  def extract_weight(image: Image.Image) -> str:
10
  image = image.convert("RGB")
11
- pixel_values = processor(images=image, return_tensors="pt").pixel_values
12
- generated_ids = model.generate(pixel_values)
13
- full_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
14
 
15
- print("OCR Text:", full_text) # optional debug
 
 
16
 
17
- # Extract number (weight)
18
- match = re.search(r"(\d+(\.\d+)?)", full_text)
 
 
19
  weight = match.group(1) if match else None
20
 
21
  # Detect unit
22
- text_lower = full_text.lower().replace(" ", "")
23
  if any(u in text_lower for u in ["kg", "kgs", "kilogram", "kilo"]):
24
  unit = "kg"
25
  elif any(u in text_lower for u in ["g", "gram", "grams"]):
 
1
+ from transformers import DonutProcessor, VisionEncoderDecoderModel
2
  from PIL import Image
3
+ import torch
4
  import re
5
 
6
+ # Load model
7
+ processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
8
+ model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
9
 
10
  def extract_weight(image: Image.Image) -> str:
11
  image = image.convert("RGB")
12
+ pixel_values = processor(image, return_tensors="pt").pixel_values
 
 
13
 
14
+ # Generate output
15
+ outputs = model.generate(pixel_values, max_length=512)
16
+ decoded = processor.batch_decode(outputs, skip_special_tokens=True)[0]
17
 
18
+ print("OCR Output:", decoded) # Optional for debug
19
+
20
+ # Extract number
21
+ match = re.search(r"(\d+(\.\d+)?)", decoded)
22
  weight = match.group(1) if match else None
23
 
24
  # Detect unit
25
+ text_lower = decoded.lower().replace(" ", "")
26
  if any(u in text_lower for u in ["kg", "kgs", "kilogram", "kilo"]):
27
  unit = "kg"
28
  elif any(u in text_lower for u in ["g", "gram", "grams"]):