import streamlit as st
import os
import traceback
import numpy as np
import cv2
from PIL import Image
import torch
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from ultralytics import YOLO
import Levenshtein

# Page config
st.set_page_config(
    page_title="Thai License Plate Detection",
    page_icon="🚗",
    layout="centered"
)

# Initialize session state for models
if 'models_loaded' not in st.session_state:
    st.session_state['models_loaded'] = False

def load_ocr_models():
    """Load OCR models with proper error handling"""
    try:
        # Set environment variables to suppress warnings
        os.environ['TOKENIZERS_PARALLELISM'] = 'false'
        
        # Load processor with specific config
        processor = TrOCRProcessor.from_pretrained(
            'openthaigpt/thai-trocr',
            revision='main',
            use_auth_token=False,
            trust_remote_code=True,
            local_files_only=False
        )
        
        # Load OCR model with specific config
        ocr_model = VisionEncoderDecoderModel.from_pretrained(
            'openthaigpt/thai-trocr',
            revision='main',
            use_auth_token=False,
            trust_remote_code=True,
            local_files_only=False
        )
        
        # Move model to CPU explicitly
        ocr_model = ocr_model.to('cpu')
        
        return processor, ocr_model
    except Exception as e:
        st.error(f"Error loading OCR models: {str(e)}")
        st.error("Detailed error information:")
        st.code(traceback.format_exc())
        return None, None

# Load models
@st.cache_resource
def load_models():
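    """Load the YOLO detector and the Thai TrOCR processor/model; returns (None, None, None) on failure."""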
    try:
        # Check if YOLO weights exist
        if not os.path.exists('best.pt'):
            st.error("YOLO model weights (best.pt) not found in the current directory!")
            return None, None, None

        # Load YOLO model
        try:
            yolo_model = YOLO('best.pt', task='detect')
        except Exception as yolo_error:
            st.error(f"Error loading YOLO model: {str(yolo_error)}")
            return None, None, None
        
        # Load OCR models
        processor, ocr_model = load_ocr_models()
        if processor is None or ocr_model is None:
            return None, None, None
            
        return processor, ocr_model, yolo_model
        
    except Exception as e:
        st.error(f"Error in model loading: {str(e)}")
        st.error("Detailed error information:")
        st.code(traceback.format_exc())
        return None, None, None

# Thai provinces list
thai_provinces = [
    "กรุงเทพมหานคร", "กระบี่", "กาญจนบุรี", "กาฬสินธุ์", "กำแพงเพชร", "ขอนแก่น", "จันทบุรี", "ฉะเชิงเทรา",
    "ชลบุรี", "ชัยนาท", "ชัยภูมิ", "ชุมพร", "เชียงราย", "เชียงใหม่", "ตรัง", "ตราด", "ตาก", "นครนายก",
    "นครปฐม", "นครพนม", "นครราชสีมา", "นครศรีธรรมราช", "นครสวรรค์", "นราธิวาส", "น่าน", "บึงกาฬ",
    "บุรีรัมย์", "ปทุมธานี", "ประจวบคีรีขันธ์", "ปราจีนบุรี", "ปัตตานี", "พะเยา", "พังงา", "พัทลุง",
    "พิจิตร", "พิษณุโลก", "เพชรบูรณ์", "เพชรบุรี", "แพร่", "ภูเก็ต", "มหาสารคาม", "มุกดาหาร", "แม่ฮ่องสอน",
    "ยโสธร", "ยะลา", "ร้อยเอ็ด", "ระนอง", "ระยอง", "ราชบุรี", "ลพบุรี", "ลำปาง", "ลำพูน", "เลย",
    "ศรีสะเกษ", "สกลนคร", "สงขลา", "สมุทรปราการ", "สมุทรสงคราม", "สมุทรสาคร", "สระแก้ว", "สระบุรี",
    "สิงห์บุรี", "สุโขทัย", "สุพรรณบุรี", "สุราษฎร์ธานี", "สุรินทร์", "หนองคาย", "หนองบัวลำภู", "อำนาจเจริญ",
    "อุดรธานี", "อุทัยธานี", "อุบลราชธานี", "อ่างทอง"
]

def get_closest_province(input_text, provinces):
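    """Return the province with the smallest Levenshtein distance to the OCR text, plus that distance."""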
    min_distance = float('inf')
    closest_province = None
    for province in provinces:
        distance = Levenshtein.distance(input_text, province)
        if distance < min_distance:
            min_distance = distance
            closest_province = province
    return closest_province, min_distance

def process_image(image, processor, ocr_model, yolo_model):
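    """Run YOLO on the image, OCR each detected region, and return plate/province text and crops."""
    # Minimum detection confidence to keep a YOLO box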
    CONF_THRESHOLD = 0.2
    data = {"plate_number": "", "province": "", "raw_province": "", "plate_crop": None, "province_crop": None}
    
    # Convert PIL Image to cv2 format (force RGB so RGBA or grayscale uploads don't break cvtColor)
    image = np.array(image.convert("RGB"))
    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    
    # Contrast enhancement: apply CLAHE to the L channel in LAB color space
    lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
    l, a, b = cv2.split(lab)
    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8,8))
    cl = clahe.apply(l)
    enhanced = cv2.merge((cl,a,b))
    image = cv2.cvtColor(enhanced, cv2.COLOR_LAB2BGR)

    # YOLO detection
    results = yolo_model(image)
    
    # Process detections
    detections = []
    for result in results:
        for box in result.boxes:
            confidence = float(box.conf)
            class_id = int(box.cls.item())
            if confidence < CONF_THRESHOLD:
                continue
            x1, y1, x2, y2 = map(int, box.xyxy.flatten())
            detections.append((class_id, confidence, (x1, y1, x2, y2)))

    # Sort by class_id so the plate (class 0) is processed before the province (class 1)
    detections.sort(key=lambda x: x[0])
    
    for class_id, confidence, (x1, y1, x2, y2) in detections:
        cropped_image = image[y1:y2, x1:x2]
        if cropped_image.size == 0:
            continue
            
        # Preprocess for OCR
        cropped_image_gray = cv2.cvtColor(cropped_image, cv2.COLOR_BGR2GRAY)
        thresh_image = cv2.adaptiveThreshold(
            cropped_image_gray,
            255,
            cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
            cv2.THRESH_BINARY_INV,
            11,
            2
        )
        
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (2,2))
        thresh_image = cv2.morphologyEx(thresh_image, cv2.MORPH_CLOSE, kernel)
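        # The OCR model works on 3-channel input, so convert the binarised crop back to RGB and resize it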
        cropped_image_3d = cv2.cvtColor(thresh_image, cv2.COLOR_GRAY2RGB)
        resized_image = cv2.resize(cropped_image_3d, (128, 32))
        
        # OCR processing
        pixel_values = processor(resized_image, return_tensors="pt").pixel_values
        generated_ids = ocr_model.generate(pixel_values)
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        
        # Convert crop to PIL for display
        cropped_pil = Image.fromarray(cv2.cvtColor(cropped_image, cv2.COLOR_BGR2RGB))

        if class_id == 0:  # License plate
            data["plate_number"] = generated_text
            data["plate_crop"] = cropped_pil
        elif class_id == 1:  # Province
            generated_province, distance = get_closest_province(generated_text, thai_provinces)
            data["raw_province"] = generated_text
            data["province"] = generated_province
            data["province_crop"] = cropped_pil
            
    return data

# Main app
st.title("Thai License Plate Detection 🚗")

# Load models
try:
    if not st.session_state['models_loaded']:
        with st.spinner("Loading models... (this may take a minute)"):
            processor, ocr_model, yolo_model = load_models()
            # Only mark models as loaded when every model loaded successfully
            if processor is not None and ocr_model is not None and yolo_model is not None:
                st.session_state['models_loaded'] = True
                st.session_state['processor'] = processor
                st.session_state['ocr_model'] = ocr_model
                st.session_state['yolo_model'] = yolo_model
except Exception as e:
    st.error(f"Error loading models: {str(e)}")
    st.stop()

# Halt the app if any model failed to load (errors were already shown above)
if not st.session_state['models_loaded']:
    st.stop()

# File uploader
uploaded_file = st.file_uploader("Upload an image of a Thai license plate", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    try:
        # Display the uploaded image
        col1, col2 = st.columns(2)
        with col1:
            st.subheader("Uploaded Image")
            image = Image.open(uploaded_file)
            st.image(image, use_column_width=True)
        
        # Process the image
        with col2:
            st.subheader("Detection Results")
            with st.spinner("Processing image..."):
                results = process_image(
                    image,
                    st.session_state['processor'],
                    st.session_state['ocr_model'],
                    st.session_state['yolo_model']
                )
                
                if results["plate_number"]:
                    st.success("Detection successful!")
                    st.write("📝 License Plate:", results['plate_number'])
                    
                    if results['plate_crop'] is not None:
                        st.subheader("Cropped License Plate")
                        st.image(results['plate_crop'], caption="Detected License Plate Region")
                    
                    if results['raw_province']:
                        st.write("🔍 Detected Province Text:", results['raw_province'])
                        if results['province']:
                            st.write("🏠 Matched Province:", results['province'])
                        else:
                            st.write("⚠️ No close province match found")
                        
                        if results['province_crop'] is not None:
                            st.subheader("Cropped Province")
                            st.image(results['province_crop'], caption="Detected Province Region")
                    else:
                        st.write("⚠️ No province text detected")
                else:
                    st.error("No license plate detected in the image.")
                    
    except Exception as e:
        st.error(f"An error occurred: {str(e)}")

st.markdown("---")
st.markdown("### Instructions")
st.markdown("""
1. Upload an image containing a Thai license plate
2. Wait for the processing to complete
3. View the detected license plate number and province
""")

# Add footer with GitHub link
st.markdown("---")
st.markdown("Made with ❤️ by [AI Research Group KMUTT]")
st.markdown("Check out the [GitHub Repository](https://github.com/yourusername/your-repo) for more information")