# Cap_deploy/anpr.py
import os

import cv2
import numpy as np
import pytesseract as pt
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
# LOAD MODELS: YOLO plate detector + TrOCR text recognizer
INPUT_WIDTH = 640
INPUT_HEIGHT = 640
#onnx_file_path = os.path.abspath('./static/models/best.onnx')
onnx_file_path = os.path.abspath(r'./static/model/best.onnx')
print(f"Attempting to load ONNX file from: {onnx_file_path}")
processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-printed')
model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-printed')
if not os.path.exists(onnx_file_path):
    raise FileNotFoundError(f"ONNX file not found at {onnx_file_path}")
try:
    net = cv2.dnn.readNetFromONNX(onnx_file_path)
    # run inference on CPU through OpenCV's built-in DNN backend
    net.setPreferableBackend(cv2.dnn.DNN_BACKEND_OPENCV)
    net.setPreferableTarget(cv2.dnn.DNN_TARGET_CPU)
except cv2.error as e:
    print(f"Error loading ONNX file: {onnx_file_path}")
    print(f"OpenCV error: {e}")
    raise  # re-raise so the app never runs with an undefined `net`
def get_detections(img,net):
# CONVERT IMAGE TO YOLO FORMAT
image = img.copy()
row, col, d = image.shape
max_rc = max(row,col)
    # pad to a square canvas so the 640x640 resize preserves aspect ratio
    input_image = np.zeros((max_rc,max_rc,3),dtype=np.uint8)
    input_image[0:row,0:col] = image
# GET PREDICTION FROM YOLO MODEL
blob = cv2.dnn.blobFromImage(input_image,1/255,(INPUT_WIDTH,INPUT_HEIGHT),swapRB=True,crop=False)
net.setInput(blob)
preds = net.forward()
detections = preds[0]
return input_image, detections
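# The detector is assumed to be a YOLOv5-style ONNX export: each row of
# `detections` is [cx, cy, w, h, objectness, class_score] in 640x640 input
# coordinates, which non_maximum_suppression() below rescales and filters.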
def non_maximum_suppression(input_image, detections):
    # FILTER DETECTIONS BASED ON CONFIDENCE AND PROBABILITY SCORE
    # each row: center x, center y, w, h, conf, proba
    boxes = []
    confidences = []
    image_h, image_w = input_image.shape[:2]  # shape is (rows, cols) = (h, w)
    x_factor = image_w / INPUT_WIDTH
    y_factor = image_h / INPUT_HEIGHT
for i in range(len(detections)):
row = detections[i]
confidence = row[4] # confidence of detecting license plate
if confidence > 0.4:
class_score = row[5] # probability score of license plate
if class_score > 0.25:
cx, cy , w, h = row[0:4]
left = int((cx - 0.5*w)*x_factor)
top = int((cy-0.5*h)*y_factor)
width = int(w*x_factor)
height = int(h*y_factor)
box = np.array([left,top,width,height])
confidences.append(confidence)
boxes.append(box)
# clean
boxes_np = np.array(boxes).tolist()
confidences_np = np.array(confidences).tolist()
# NMS
index = np.array(cv2.dnn.NMSBoxes(boxes_np,confidences_np,0.25,0.45)).flatten()
return boxes_np, confidences_np, index
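# Tesseract-based fallback OCR: pytesseract is only a thin wrapper, so the
# `tesseract` binary itself must be installed on the host for this path.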
def extract_text_py(image,bbox):
x,y,w,h = bbox
roi = image[y:y+h, x:x+w]
if 0 in roi.shape:
return ''
else:
        # roi comes from cv2.imread(), so it is already BGR; RGB2BGR and
        # BGR2RGB are the same channel swap
        roi_bgr = cv2.cvtColor(roi,cv2.COLOR_RGB2BGR)
        gray = cv2.cvtColor(roi_bgr,cv2.COLOR_BGR2GRAY)
        magic_color = apply_brightness_contrast(gray,brightness=40,contrast=70)
        # --psm 6: treat the ROI as a single uniform block of text
        text = pt.image_to_string(magic_color,lang='eng',config='--psm 6')
text = text.strip()
return text
# extract the plate text from a detected ROI with TrOCR
def extract_text(image, bbox):
    x, y, w, h = bbox
    roi = image[y:y+h, x:x+w]
    # guard against degenerate boxes before any conversion
    if 0 in roi.shape:
        return 'no number'
    # cv2 images are BGR; TrOCR expects RGB input
    img_rgb = cv2.cvtColor(roi, cv2.COLOR_BGR2RGB)
    pixel_values = processor(images=img_rgb, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values)
    text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return filter_string(text)
def filter_string(input_string):
filtered_chars = [char for char in input_string if char.isalnum() and (char.isupper() or char.isdigit())]
filtered_string = ''.join(filtered_chars)
return filtered_string
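# e.g. filter_string("MH-12 de 1433") -> "MH121433": punctuation, spaces and
# lowercase characters are stripped, keeping only A-Z and 0-9.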
def drawings(image,boxes_np,confidences_np,index):
# drawings
text_list = []
for ind in index:
x,y,w,h = boxes_np[ind]
bb_conf = confidences_np[ind]
conf_text = 'plate: {:.0f}%'.format(bb_conf*100)
license_text = extract_text(image,boxes_np[ind])
        # plate box, confidence banner above, OCR-text banner below
        cv2.rectangle(image,(x,y),(x+w,y+h),(255,0,255),2)
        cv2.rectangle(image,(x,y-30),(x+w,y),(255,0,255),-1)
        cv2.rectangle(image,(x,y+h),(x+w,y+h+30),(0,0,0),-1)
cv2.putText(image,conf_text,(x,y-10),cv2.FONT_HERSHEY_SIMPLEX,0.7,(255,255,255),1)
cv2.putText(image,license_text,(x,y+h+27),cv2.FONT_HERSHEY_SIMPLEX,0.7,(0,255,0),1)
text_list.append(license_text)
return image, text_list
# predictions
def yolo_predictions(img,net):
## step-1: detections
input_image, detections = get_detections(img,net)
## step-2: NMS
    boxes_np, confidences_np, index = non_maximum_suppression(input_image, detections)
## step-3: Drawings
result_img, text = drawings(img,boxes_np,confidences_np,index)
return result_img, text
def object_detection(path, filename):
    # read image (cv2.imread returns a BGR numpy array, or None on failure)
    image = cv2.imread(path)
    if image is None:
        raise ValueError(f"Could not read image: {path}")
    image = np.array(image, dtype=np.uint8)  # 8-bit array (0-255)
    result_img, text_list = yolo_predictions(image, net)
    cv2.imwrite('./static/predict/{}'.format(filename), result_img)
return text_list
# Legacy Tesseract-based OCR flow, kept for reference; it relies on load_img()
# and save_text() helpers that are not defined in this module.
# def OCR(path,filename):
# img = np.array(load_img(path))
# cods = object_detection(path,filename)
# xmin ,xmax,ymin,ymax = cods[0]
# roi = img[ymin:ymax,xmin:xmax]
# roi_bgr = cv2.cvtColor(roi,cv2.COLOR_RGB2BGR)
# gray = cv2.cvtColor(roi_bgr,cv2.COLOR_BGR2GRAY)
# magic_color = apply_brightness_contrast(gray,brightness=40,contrast=70)
# cv2.imwrite('./static/roi/{}'.format(filename),roi_bgr)
# print(text)
# save_text(filename,text)
# return text
def apply_brightness_contrast(input_img, brightness = 0, contrast = 0):
    # linear adjustment out = alpha*in + gamma, applied via cv2.addWeighted
if brightness != 0:
if brightness > 0:
shadow = brightness
highlight = 255
else:
shadow = 0
highlight = 255 + brightness
alpha_b = (highlight - shadow)/255
gamma_b = shadow
buf = cv2.addWeighted(input_img, alpha_b, input_img, 0, gamma_b)
else:
buf = input_img.copy()
if contrast != 0:
f = 131*(contrast + 127)/(127*(131-contrast))
alpha_c = f
gamma_c = 127*(1-f)
buf = cv2.addWeighted(buf, alpha_c, buf, 0, gamma_c)
return buf
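

# Minimal usage sketch, assuming an uploaded image saved under ./static/upload/
# and an existing ./static/predict/ output directory; the filename below is
# hypothetical.
if __name__ == '__main__':
    sample = 'car.jpg'  # hypothetical test image
    plates = object_detection(os.path.join('./static/upload', sample), sample)
    print('Detected plate text:', plates)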