Spaces:
Running
Running
File size: 9,806 Bytes
49d297d 46c59fe 49d297d 46c59fe 49d297d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 |
import streamlit as st
import os
import numpy as np
import cv2
from PIL import Image
import torch
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from ultralytics import YOLO
import Levenshtein
# Configure the browser tab title/icon and the page layout.
st.set_page_config(page_title="Thai License Plate Detection", page_icon="🚗", layout="centered")

# Per-session flag read by the main app to decide whether models still need
# to be loaded; only initialized when the key is not already present.
st.session_state.setdefault('models_loaded', False)
# Load models
@st.cache_resource
def load_models():
    """Load the OCR processor, the OCR model and the YOLO detector.

    The result is cached by ``st.cache_resource``, so the heavy loading only
    happens when the cache is empty.

    Returns:
        ``(processor, ocr_model, yolo_model)``.

    Raises:
        Exception: whatever the underlying loaders raise. The traceback is
        displayed in the app and the exception is then re-raised, so that a
        failed load is NOT cached (returning ``(None, None, None)`` here would
        be cached and served on every later rerun).
    """
    try:
        # YOLO detector weights, loaded from the relative path 'best.pt'.
        yolo_model = YOLO('best.pt')

        # NOTE(review): trust_remote_code=True executes Python code shipped in
        # the Hub repository; keep it only if that repository is trusted.
        # use_auth_token is deprecated in newer transformers releases in favor
        # of `token`; confirm against the pinned transformers version.
        processor = TrOCRProcessor.from_pretrained(
            'openthaigpt/thai-trocr',
            use_auth_token=False,
            trust_remote_code=True,
        )
        ocr_model = VisionEncoderDecoderModel.from_pretrained(
            'openthaigpt/thai-trocr',
            use_auth_token=False,
            trust_remote_code=True,
        )

        # process_image() feeds CPU tensors to generate(), so the model must
        # stay on CPU. from_pretrained already loads onto CPU; this only makes
        # that explicit (the original moved it to CPU without CUDA and left it
        # on CPU otherwise).
        ocr_model = ocr_model.to('cpu')
        return processor, ocr_model, yolo_model
    except Exception:
        import traceback

        st.error("Detailed error information for debugging:")
        st.code(traceback.format_exc())
        # Re-raise: st.cache_resource does not cache calls that raise, so the
        # next rerun retries instead of reusing a broken result.
        raise
# Thai provinces list: Bangkok plus the 76 provinces (77 names), used to snap
# noisy OCR output of the province line to a valid province name.
thai_provinces = [
    "กรุงเทพมหานคร", "กระบี่", "กาญจนบุรี", "กาฬสินธุ์", "กำแพงเพชร", "ขอนแก่น", "จันทบุรี", "ฉะเชิงเทรา",
    "ชลบุรี", "ชัยนาท", "ชัยภูมิ", "ชุมพร", "เชียงราย", "เชียงใหม่", "ตรัง", "ตราด", "ตาก", "นครนายก",
    "นครปฐม", "นครพนม", "นครราชสีมา", "นครศรีธรรมราช", "นครสวรรค์", "นนทบุรี", "นราธิวาส", "น่าน", "บึงกาฬ",
    "บุรีรัมย์", "ปทุมธานี", "ประจวบคีรีขันธ์", "ปราจีนบุรี", "ปัตตานี", "พระนครศรีอยุธยา", "พะเยา", "พังงา", "พัทลุง",
    "พิจิตร", "พิษณุโลก", "เพชรบูรณ์", "เพชรบุรี", "แพร่", "ภูเก็ต", "มหาสารคาม", "มุกดาหาร", "แม่ฮ่องสอน",
    "ยโสธร", "ยะลา", "ร้อยเอ็ด", "ระนอง", "ระยอง", "ราชบุรี", "ลพบุรี", "ลำปาง", "ลำพูน", "เลย",
    "ศรีสะเกษ", "สกลนคร", "สงขลา", "สตูล", "สมุทรปราการ", "สมุทรสงคราม", "สมุทรสาคร", "สระแก้ว", "สระบุรี",
    "สิงห์บุรี", "สุโขทัย", "สุพรรณบุรี", "สุราษฎร์ธานี", "สุรินทร์", "หนองคาย", "หนองบัวลำภู", "อำนาจเจริญ",
    "อุดรธานี", "อุตรดิตถ์", "อุทัยธานี", "อุบลราชธานี", "อ่างทอง",
]
def get_closest_province(input_text, provinces):
    """Find the candidate nearest to *input_text* by Levenshtein edit distance.

    Args:
        input_text: text to match (typically OCR output).
        provinces: iterable of candidate names.

    Returns:
        ``(best_name, best_distance)``. On ties the candidate that appears
        first wins; an empty *provinces* yields ``(None, float('inf'))``.
    """
    scored = ((name, Levenshtein.distance(input_text, name)) for name in provinces)
    # min() keeps the first of several equal minima, matching a strict '<' scan.
    return min(scored, key=lambda pair: pair[1], default=(None, float('inf')))
def _enhance_contrast(image_bgr):
    """Apply CLAHE to the lightness channel of a BGR image and return the BGR result."""
    lab = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2LAB)
    lightness, chan_a, chan_b = cv2.split(lab)
    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
    enhanced = cv2.merge((clahe.apply(lightness), chan_a, chan_b))
    return cv2.cvtColor(enhanced, cv2.COLOR_LAB2BGR)


def _collect_detections(results, conf_threshold, width, height):
    """Return ``(class_id, confidence, box)`` tuples, grouped by class, best first.

    Detections below *conf_threshold* are dropped. Boxes are clamped to the
    ``width`` x ``height`` image so an out-of-bounds coordinate can never
    become a negative (wrap-around) slice index.
    """
    detections = []
    for result in results:
        for box in result.boxes:
            confidence = float(box.conf)
            if confidence < conf_threshold:
                continue
            class_id = int(box.cls.item())
            x1, y1, x2, y2 = map(int, box.xyxy.flatten())
            clamped = (max(0, x1), max(0, y1), min(width, x2), min(height, y2))
            detections.append((class_id, confidence, clamped))
    # Sort by class, then by DESCENDING confidence, so the first usable box of
    # each class is its most confident one.
    detections.sort(key=lambda det: (det[0], -det[1]))
    return detections


def _ocr_crop(crop_bgr, processor, ocr_model):
    """Binarize a BGR crop and return the text the OCR model decodes from it."""
    gray = cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2GRAY)
    # Inverted adaptive threshold: dark characters become white on black.
    binary = cv2.adaptiveThreshold(
        gray,
        255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY_INV,
        11,
        2,
    )
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
    binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel)
    rgb = cv2.cvtColor(binary, cv2.COLOR_GRAY2RGB)
    # NOTE(review): shrinking to 128x32 before the processor (which presumably
    # resizes again to its own input size) discards detail, and the inverted
    # binarization may not match the model's training data — confirm both.
    resized = cv2.resize(rgb, (128, 32))
    pixel_values = processor(resized, return_tensors="pt").pixel_values
    generated_ids = ocr_model.generate(pixel_values)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]


def process_image(image, processor, ocr_model, yolo_model, conf_threshold=0.2):
    """Detect the plate-number and province regions in *image* and OCR each one.

    Args:
        image: PIL image (any mode) or an RGB(A) ndarray.
        processor: OCR processor returned by ``load_models``.
        ocr_model: OCR model returned by ``load_models``.
        yolo_model: detector returned by ``load_models``; class 0 is treated as
            the plate number and class 1 as the province line.
        conf_threshold: detections with lower confidence are ignored.

    Returns:
        dict with keys ``plate_number``, ``province`` (closest entry of
        ``thai_provinces``), ``raw_province`` (unmatched OCR text),
        ``plate_crop`` and ``province_crop`` (PIL crops, or None). For each
        class only the highest-confidence non-empty detection is used.
    """
    plate_class, province_class = 0, 1
    data = {"plate_number": "", "province": "", "raw_province": "", "plate_crop": None, "province_crop": None}

    # Palette ("P"), grayscale ("L") and CMYK uploads are not 3/4-channel RGB;
    # normalize them before the RGB->BGR conversion.
    if isinstance(image, Image.Image):
        image = image.convert("RGB")
    image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    image = _enhance_contrast(image)

    height, width = image.shape[:2]
    detections = _collect_detections(yolo_model(image), conf_threshold, width, height)

    handled = set()
    for class_id, _confidence, (x1, y1, x2, y2) in detections:
        # Only plate/province regions are reported; skip OCR for anything else
        # and for classes already filled by a more confident detection.
        if class_id not in (plate_class, province_class) or class_id in handled:
            continue
        cropped_image = image[y1:y2, x1:x2]
        if cropped_image.size == 0:
            continue
        generated_text = _ocr_crop(cropped_image, processor, ocr_model)
        handled.add(class_id)

        # Convert crop to PIL (RGB) for display.
        cropped_pil = Image.fromarray(cv2.cvtColor(cropped_image, cv2.COLOR_BGR2RGB))
        if class_id == plate_class:
            data["plate_number"] = generated_text
            data["plate_crop"] = cropped_pil
        else:
            matched_province, _distance = get_closest_province(generated_text, thai_provinces)
            data["raw_province"] = generated_text
            data["province"] = matched_province
            data["province_crop"] = cropped_pil
    return data
# Main app
st.title("Thai License Plate Detection 🚗")

# Load models. load_models() is cached with st.cache_resource; the copies in
# session_state are what the rest of this script reads.
try:
    if not st.session_state.get('models_loaded', False):
        with st.spinner("Loading models... (this may take a minute)"):
            processor, ocr_model, yolo_model = load_models()
        # Never mark the session as loaded with unusable (None) models —
        # otherwise every later upload fails with an opaque NoneType error.
        if processor is None or ocr_model is None or yolo_model is None:
            raise RuntimeError("one or more models could not be loaded")
        st.session_state['models_loaded'] = True
        st.session_state['processor'] = processor
        st.session_state['ocr_model'] = ocr_model
        st.session_state['yolo_model'] = yolo_model
except Exception as e:
    st.error(f"Error loading models: {str(e)}")
    st.stop()

# File uploader
uploaded_file = st.file_uploader("Upload an image of a Thai license plate", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    try:
        # Display the uploaded image
        col1, col2 = st.columns(2)
        with col1:
            st.subheader("Uploaded Image")
            image = Image.open(uploaded_file)
            # NOTE(review): newer Streamlit releases deprecate use_column_width;
            # confirm against the pinned Streamlit version.
            st.image(image, use_column_width=True)

        # Process the image
        with col2:
            st.subheader("Detection Results")
            with st.spinner("Processing image..."):
                results = process_image(
                    image,
                    st.session_state['processor'],
                    st.session_state['ocr_model'],
                    st.session_state['yolo_model'],
                )

            if results["plate_number"]:
                st.success("Detection successful!")
                st.write("📝 License Plate:", results['plate_number'])
                if results['plate_crop'] is not None:
                    st.subheader("Cropped License Plate")
                    st.image(results['plate_crop'], caption="Detected License Plate Region")

                if results['raw_province']:
                    st.write("🔍 Detected Province Text:", results['raw_province'])
                    if results['province']:
                        st.write("🏠 Matched Province:", results['province'])
                    else:
                        st.write("⚠️ No close province match found")
                    if results['province_crop'] is not None:
                        st.subheader("Cropped Province")
                        st.image(results['province_crop'], caption="Detected Province Region")
                else:
                    st.write("⚠️ No province text detected")
            else:
                st.error("No license plate detected in the image.")
    except Exception as e:
        # Top-level UI boundary: report any processing failure to the user.
        st.error(f"An error occurred: {str(e)}")

st.markdown("---")
st.markdown("### Instructions")
st.markdown("""
1. Upload an image containing a Thai license plate
2. Wait for the processing to complete
3. View the detected license plate number and province
""")

# Add footer with GitHub link
# TODO: replace the placeholder author name and repository URL below.
st.markdown("---")
st.markdown("Made with ❤️ by [Your Name/Organization]")
st.markdown("Check out the [GitHub Repository](https://github.com/yourusername/your-repo) for more information")