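"""Streamlit app for Thai license plate recognition.

A YOLO model (best.pt) localizes the plate and province regions, and the
openthaigpt/thai-trocr model reads each cropped region; the province string
is then fuzzy-matched against the official Thai province list.
"""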
import streamlit as st
import os
import numpy as np
import cv2
from PIL import Image
import torch
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from ultralytics import YOLO
import Levenshtein
# Page config
st.set_page_config(
    page_title="Thai License Plate Detection",
    page_icon="🚗",
    layout="centered"
)
# Initialize session state for models
if 'models_loaded' not in st.session_state:
    st.session_state['models_loaded'] = False
def load_ocr_models():
    """Load OCR models with proper error handling"""
    try:
        # Set environment variables to suppress warnings
        os.environ['TOKENIZERS_PARALLELISM'] = 'false'
        # Load processor with specific config
        processor = TrOCRProcessor.from_pretrained(
            'openthaigpt/thai-trocr',
            revision='main',
            use_auth_token=False,
            trust_remote_code=True,
            local_files_only=False
        )
        # Load OCR model with specific config
        ocr_model = VisionEncoderDecoderModel.from_pretrained(
            'openthaigpt/thai-trocr',
            revision='main',
            use_auth_token=False,
            trust_remote_code=True,
            local_files_only=False
        )
        # Move model to CPU explicitly
        ocr_model = ocr_model.to('cpu')
        return processor, ocr_model
    except Exception as e:
        st.error(f"Error loading OCR models: {str(e)}")
        st.error("Detailed error information:")
        import traceback
        st.code(traceback.format_exc())
        return None, None
# Load models
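# Cached with st.cache_resource so the weights are loaded once and reused across reruns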
@st.cache_resource
def load_models():
    try:
        # Check if YOLO weights exist
        if not os.path.exists('best.pt'):
            st.error("YOLO model weights (best.pt) not found in the current directory!")
            return None, None, None
        # Load YOLO model
        try:
            yolo_model = YOLO('best.pt', task='detect')
        except Exception as yolo_error:
            st.error(f"Error loading YOLO model: {str(yolo_error)}")
            return None, None, None
        # Load OCR models
        processor, ocr_model = load_ocr_models()
        if processor is None or ocr_model is None:
            return None, None, None
        return processor, ocr_model, yolo_model
    except Exception as e:
        st.error(f"Error in model loading: {str(e)}")
        st.error("Detailed error information:")
        import traceback
        st.code(traceback.format_exc())
        return None, None, None
# Thai provinces list
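# Used to fuzzy-match the OCR output of the province region (class_id 1)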
thai_provinces = [
"กรุงเทพมหานคร", "กระบี่", "กาญจนบุรี", "กาฬสินธุ์", "กำแพงเพชร", "ขอนแก่น", "จันทบุรี", "ฉะเชิงเทรา",
"ชลบุรี", "ชัยนาท", "ชัยภูมิ", "ชุมพร", "เชียงราย", "เชียงใหม่", "ตรัง", "ตราด", "ตาก", "นครนายก",
"นครปฐม", "นครพนม", "นครราชสีมา", "นครศรีธรรมราช", "นครสวรรค์", "นราธิวาส", "น่าน", "บึงกาฬ",
"บุรีรัมย์", "ปทุมธานี", "ประจวบคีรีขันธ์", "ปราจีนบุรี", "ปัตตานี", "พะเยา", "พังงา", "พัทลุง",
"พิจิตร", "พิษณุโลก", "เพชรบูรณ์", "เพชรบุรี", "แพร่", "ภูเก็ต", "มหาสารคาม", "มุกดาหาร", "แม่ฮ่องสอน",
"ยโสธร", "ยะลา", "ร้อยเอ็ด", "ระนอง", "ระยอง", "ราชบุรี", "ลพบุรี", "ลำปาง", "ลำพูน", "เลย",
"ศรีสะเกษ", "สกลนคร", "สงขลา", "สมุทรปราการ", "สมุทรสงคราม", "สมุทรสาคร", "สระแก้ว", "สระบุรี",
"สิงห์บุรี", "สุโขทัย", "สุพรรณบุรี", "สุราษฎร์ธานี", "สุรินทร์", "หนองคาย", "หนองบัวลำภู", "อำนาจเจริญ",
"อุดรธานี", "อุทัยธานี", "อุบลราชธานี", "อ่างทอง"
]
def get_closest_province(input_text, provinces):
    """Return the province with the smallest Levenshtein distance to the OCR output."""
    min_distance = float('inf')
    closest_province = None
    for province in provinces:
        distance = Levenshtein.distance(input_text, province)
        if distance < min_distance:
            min_distance = distance
            closest_province = province
    return closest_province, min_distance
def process_image(image, processor, ocr_model, yolo_model):
    """Detect plate and province regions with YOLO, preprocess each crop, and OCR it with TrOCR."""
    CONF_THRESHOLD = 0.2
    data = {"plate_number": "", "province": "", "raw_province": "", "plate_crop": None, "province_crop": None}
    # Convert PIL Image to cv2 format
    image = np.array(image)
    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    # Image enhancement: CLAHE on the lightness channel to improve contrast
    lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
    l, a, b = cv2.split(lab)
    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
    cl = clahe.apply(l)
    enhanced = cv2.merge((cl, a, b))
    image = cv2.cvtColor(enhanced, cv2.COLOR_LAB2BGR)
    # YOLO detection
    results = yolo_model(image)
    # Process detections above the confidence threshold
    detections = []
    for result in results:
        for box in result.boxes:
            confidence = float(box.conf)
            class_id = int(box.cls.item())
            if confidence < CONF_THRESHOLD:
                continue
            x1, y1, x2, y2 = map(int, box.xyxy.flatten())
            detections.append((class_id, confidence, (x1, y1, x2, y2)))
    # Sort by class_id
    detections.sort(key=lambda x: x[0])
    for class_id, confidence, (x1, y1, x2, y2) in detections:
        cropped_image = image[y1:y2, x1:x2]
        if cropped_image.size == 0:
            continue
        # Preprocess for OCR: grayscale, adaptive threshold, then close small gaps
        cropped_image_gray = cv2.cvtColor(cropped_image, cv2.COLOR_BGR2GRAY)
        thresh_image = cv2.adaptiveThreshold(
            cropped_image_gray,
            255,
            cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
            cv2.THRESH_BINARY_INV,
            11,
            2
        )
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
        thresh_image = cv2.morphologyEx(thresh_image, cv2.MORPH_CLOSE, kernel)
        cropped_image_3d = cv2.cvtColor(thresh_image, cv2.COLOR_GRAY2RGB)
        resized_image = cv2.resize(cropped_image_3d, (128, 32))
        # OCR processing
        pixel_values = processor(resized_image, return_tensors="pt").pixel_values
        generated_ids = ocr_model.generate(pixel_values)
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        # Convert crop to PIL for display
        cropped_pil = Image.fromarray(cv2.cvtColor(cropped_image, cv2.COLOR_BGR2RGB))
        if class_id == 0:  # License plate
            data["plate_number"] = generated_text
            data["plate_crop"] = cropped_pil
        elif class_id == 1:  # Province
            generated_province, distance = get_closest_province(generated_text, thai_provinces)
            data["raw_province"] = generated_text
            data["province"] = generated_province
            data["province_crop"] = cropped_pil
    return data
# Main app
st.title("Thai License Plate Detection 🚗")
# Load models
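# Models are loaded once and kept in session_state so later reruns reuse them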
try:
    if not st.session_state['models_loaded']:
        with st.spinner("Loading models... (this may take a minute)"):
            processor, ocr_model, yolo_model = load_models()
            # Do not mark the models as loaded if any of them failed to load
            if processor is None or ocr_model is None or yolo_model is None:
                raise RuntimeError("One or more models failed to load")
            st.session_state['models_loaded'] = True
            st.session_state['processor'] = processor
            st.session_state['ocr_model'] = ocr_model
            st.session_state['yolo_model'] = yolo_model
except Exception as e:
    st.error(f"Error loading models: {str(e)}")
    st.stop()
# File uploader
uploaded_file = st.file_uploader("Upload an image of a Thai license plate", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    try:
        # Display the uploaded image
        col1, col2 = st.columns(2)
        with col1:
            st.subheader("Uploaded Image")
            image = Image.open(uploaded_file)
            st.image(image, use_column_width=True)
        # Process the image
        with col2:
            st.subheader("Detection Results")
            with st.spinner("Processing image..."):
                results = process_image(
                    image,
                    st.session_state['processor'],
                    st.session_state['ocr_model'],
                    st.session_state['yolo_model']
                )
            if results["plate_number"]:
                st.success("Detection successful!")
                st.write("📝 License Plate:", results['plate_number'])
                if results['plate_crop'] is not None:
                    st.subheader("Cropped License Plate")
                    st.image(results['plate_crop'], caption="Detected License Plate Region")
                if results['raw_province']:
                    st.write("🔍 Detected Province Text:", results['raw_province'])
                    if results['province']:
                        st.write("🏠 Matched Province:", results['province'])
                    else:
                        st.write("⚠️ No close province match found")
                    if results['province_crop'] is not None:
                        st.subheader("Cropped Province")
                        st.image(results['province_crop'], caption="Detected Province Region")
                else:
                    st.write("⚠️ No province text detected")
            else:
                st.error("No license plate detected in the image.")
    except Exception as e:
        st.error(f"An error occurred: {str(e)}")
st.markdown("---")
st.markdown("### Instructions")
st.markdown("""
1. Upload an image containing a Thai license plate
2. Wait for the processing to complete
3. View the detected license plate number and province
""")
# Add footer with GitHub link
st.markdown("---")
st.markdown("Made with ❤️ by [AI Research Group KMUTT]")
st.markdown("Check out the [GitHub Repository](https://github.com/yourusername/your-repo) for more information") |