File size: 7,768 Bytes
12f490a
4bac772
 
9ac8c85
 
4bac772
 
12f490a
4bac772
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e82b84e
 
 
4bac772
 
 
 
 
 
 
 
 
 
 
 
81c7aed
 
4bac772
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9ac8c85
 
4bac772
e82b84e
9ac8c85
e284c36
12f490a
9ac8c85
 
 
4bac772
9ac8c85
4bac772
 
9ac8c85
4bac772
9ac8c85
e82b84e
9ac8c85
4bac772
 
 
 
12f490a
4bac772
e82b84e
81c7aed
4bac772
 
 
 
9ac8c85
4bac772
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e82b84e
9ac8c85
4bac772
9ac8c85
4bac772
e82b84e
 
4bac772
 
e82b84e
 
4bac772
e82b84e
 
 
 
 
 
 
 
 
 
81c7aed
4bac772
e284c36
4bac772
e284c36
4bac772
 
 
b23377c
4bac772
 
e284c36
12f490a
4bac772
12f490a
9ac8c85
e284c36
9ac8c85
 
12f490a
 
 
4bac772
 
 
 
12f490a
 
4bac772
e82b84e
9ac8c85
e284c36
 
4bac772
e82b84e
e284c36
12f490a
4bac772
e82b84e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image, ImageDraw, ImageFont
import numpy as np
from ultralytics import YOLO
import gradio as gr

# Kiểm tra thiết bị
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Sử dụng thiết bị: {device}')

# Định nghĩa kiến trúc mô hình phân loại
class ClassificationModel(nn.Module):
    def __init__(self, num_classes=12):
        super(ClassificationModel, self).__init__()
        self.model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)  # Sử dụng pretrained weights
        num_ftrs = self.model.fc.in_features
        self.model.fc = nn.Linear(num_ftrs, num_classes)  # Thay đổi lớp cuối cùng

    def forward(self, x):
        return self.model(x)

# Khởi tạo và tải mô hình phân loại
classification_model = ClassificationModel(num_classes=12)
classification_model.load_state_dict(torch.load('classification_state_dict.pt', map_location=device))
classification_model.to(device)
classification_model.eval()

# Tải mô hình YOLO
yolo_model = YOLO('best.pt')  # Đảm bảo rằng 'best.pt' nằm trong thư mục hiện tại

# Định nghĩa các lớp mụn
class_labels = [
    'acne_scars', 'blackhead', 'cystic', 'flat_wart', 'folliculitis',
    'keloid', 'milium', 'papular', 'purulent', 'sebo-crystan-conglo',
    'syringoma', 'whitehead'
]

# Ánh xạ tiếng Anh -> tiếng Việt
class_mapping = {
    'acne_scars': 'Sẹo mụn',
    'blackhead': 'Mụn đầu đen',
    'cystic': 'Mụn nang',
    'flat_wart': 'Mụn sần phẳng',
    'folliculitis': 'Viêm nang lông',
    'keloid': 'Mụn sẹo uốn',
    'milium': 'Mụn mili',
    'papular': 'Mụn nhỏ',
    'purulent': 'Mụn mủ',
    'sebo-crystan-conglo': 'Mụn bã đen kết tủa',
    'syringoma': 'Mụn nang mồ hôi',
    'whitehead': 'Mụn đầu trắng'
}

# Định nghĩa các biến đổi dữ liệu cho mô hình phân loại
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

def detect_and_classify(image, image_size=640, conf_threshold=0.4, iou_threshold=0.5):
    """
    Hàm này nhận vào một ảnh, phát hiện các vùng mụn bằng YOLO,
    phân loại từng vùng mụn bằng mô hình ResNet18, và trả về ảnh đã được
    annotate cùng với các thông tin liên quan.
    """
    # Mở ảnh và chuyển đổi sang RGB
    pil_image = Image.open(image).convert("RGB")

    # Dự đoán bằng YOLO
    results = yolo_model.predict(pil_image, conf=conf_threshold, iou=iou_threshold, imgsz=image_size)
    boxes = results[0].boxes
    num_boxes = len(boxes)

    if num_boxes == 0:
        severity = "Tốt"
        recommendation = "Làn da bạn khá ổn! Tiếp tục duy trì thói quen chăm sóc da."
        return pil_image, f"Tình trạng mụn: {severity}", recommendation, "Không phát hiện mụn."

    # Lấy thông tin bounding boxes
    xyxy = boxes.xyxy.detach().cpu().numpy().astype(int)  # Toạ độ bounding box
    confidences = boxes.conf.detach().cpu().numpy()
    class_ids = boxes.cls.detach().cpu().numpy().astype(int)

    # Chuẩn bị vẽ
    draw = ImageDraw.Draw(pil_image)
    try:
        font = ImageFont.truetype("arial.ttf", 15)
    except IOError:
        font = ImageFont.load_default()

    # Danh sách để lưu kết quả phân loại
    classified_results = []

    for i, (box, cls_id, conf) in enumerate(zip(xyxy, class_ids, confidences), start=1):
        x1, y1, x2, y2 = box
        class_name_en = class_labels[cls_id]
        class_name_vn = class_mapping.get(class_name_en, class_name_en)

        # Cắt crop vùng mụn
        crop = pil_image.crop((x1, y1, x2, y2))
        img_transformed = transform(crop).unsqueeze(0).to(device)

        with torch.no_grad():
            output = classification_model(img_transformed)
            probabilities = torch.softmax(output, dim=1)
            top_prob, top_class = probabilities.topk(1, dim=1)
            top_prob = top_prob.item()
            top_class = top_class.item()
            class_name_en = class_labels[top_class]
            class_name_vn = class_mapping.get(class_name_en, class_name_en)

        # Vẽ bounding box và nhãn
        label = f"#{i}: {class_name_en} ({class_name_vn}) ({top_prob:.2f})"
        # Sử dụng textbbox thay vì textsize
        text_bbox = draw.textbbox((0, 0), label, font=font)
        text_w = text_bbox[2] - text_bbox[0]
        text_h = text_bbox[3] - text_bbox[1]
        draw.rectangle([x1, y1, x2, y2], outline="red", width=2)
        draw.rectangle([x1, y1 - text_h, x1 + text_w, y1], fill="red")
        draw.text((x1, y1 - text_h), label, fill="white", font=font)

        # Thêm kết quả phân loại vào danh sách
        classified_results.append((i, class_name_en, class_name_vn))

    # Đánh giá tình trạng da dựa trên số lượng mụn
    if num_boxes > 20:
        severity = "Nặng"
        recommendation = "Bạn nên đến gặp bác sĩ da liễu và sử dụng liệu trình trị mụn chuyên sâu."
    elif 10 <= num_boxes <= 20:
        severity = "Trung bình"
        recommendation = "Duy trì skincare đều đặn với sữa rửa mặt dịu nhẹ và dưỡng ẩm."
    else:
        severity = "Tốt"
        recommendation = "Làn da bạn khá ổn! Tiếp tục duy trì thói quen chăm sóc da."

    # Liệt kê loại mụn
    acne_types_str = "Danh sách mụn phát hiện:\n"
    for idx, cname_en, cname_vn in classified_results:
        acne_types_str += f"Mụn #{idx}: {cname_en} ({cname_vn})\n"

    return pil_image, f"Tình trạng mụn: {severity}", recommendation, acne_types_str

# Mô tả ứng dụng
description_md = """
## Ứng Dụng Nhận Diện và Phân Loại Mụn Bằng YOLO và ResNet18

1. **Phát hiện mụn:** Sử dụng mô hình YOLO để phát hiện các vùng mụn trên khuôn mặt.
2. **Phân loại mụn:** Sử dụng mô hình ResNet18 đã được huấn luyện để phân loại từng vùng mụn thành 12 loại khác nhau bao gồm: Sẹo mụn trứng cá, Mụn đầu đen, Mụn nang, Mụn phẳng, Viêm nang lông, Sẹo mùm, Mụn mili, Mụn sần nhỏ, Mụn mủ, Mụn bã đen kết tủa, Mụn nang mồ hôi, Mụn đầu trắng
3. **Hiển thị kết quả:** Ảnh sau khi xử lý sẽ hiển thị các bounding boxes, nhãn tiếng Anh và tiếng Việt của loại mụn, cùng với độ chính xác của mỗi phân loại.
4. **Đánh giá tình trạng da:** Cung cấp đánh giá tổng quát về tình trạng da và khuyến nghị tương ứng dựa trên số lượng mụn được phát hiện.
"""

# Định nghĩa giao diện Gradio
inputs = [
    gr.Image(type="filepath", label="Ảnh Khuôn Mặt"),
    gr.Slider(minimum=320, maximum=1280, step=32, value=640, label="Kích thước ảnh (Image Size)"),
    gr.Slider(minimum=0, maximum=1, step=0.05, value=0.4, label="Ngưỡng Confidence"),
    gr.Slider(minimum=0, maximum=1, step=0.05, value=0.5, label="Ngưỡng IOU")
]

outputs = [
    gr.Image(type="pil", label="Ảnh Sau Khi Xử Lý"),
    gr.Textbox(label="Tình Trạng Mụn"),
    gr.Textbox(label="Khuyến Nghị"),
    gr.Textbox(label="Loại Mụn Phát Hiện")
]

# Tạo giao diện Gradio
app = gr.Interface(
    fn=detect_and_classify,
    inputs=inputs,
    outputs=outputs,
    title="YOLO + ResNet18 Phát Hiện và Phân Loại Mụn",
    description=description_md
)

# Khởi chạy ứng dụng
app.launch(share=True)