|
import torch |
|
import torch.nn as nn |
|
from torchvision import models, transforms |
|
from PIL import Image, ImageDraw, ImageFont |
|
import numpy as np |
|
from ultralytics import YOLO |
|
import gradio as gr |
|
|
|
|
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
|
print(f'Sử dụng thiết bị: {device}') |
|
|
|
|
|
class ClassificationModel(nn.Module): |
|
def __init__(self, num_classes=12): |
|
super(ClassificationModel, self).__init__() |
|
self.model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT) |
|
num_ftrs = self.model.fc.in_features |
|
self.model.fc = nn.Linear(num_ftrs, num_classes) |
|
|
|
def forward(self, x): |
|
return self.model(x) |
|
|
|
|
|
classification_model = ClassificationModel(num_classes=12) |
|
classification_model.load_state_dict(torch.load('classification_state_dict.pt', map_location=device)) |
|
classification_model.to(device) |
|
classification_model.eval() |
|
|
|
|
|
yolo_model = YOLO('best.pt') |
|
|
|
|
|
class_labels = [ |
|
'acne_scars', 'blackhead', 'cystic', 'flat_wart', 'folliculitis', |
|
'keloid', 'milium', 'papular', 'purulent', 'sebo-crystan-conglo', |
|
'syringoma', 'whitehead' |
|
] |
|
|
|
|
|
class_mapping = { |
|
'acne_scars': 'Sẹo mụn', |
|
'blackhead': 'Mụn đầu đen', |
|
'cystic': 'Mụn nang', |
|
'flat_wart': 'Mụn sần phẳng', |
|
'folliculitis': 'Viêm nang lông', |
|
'keloid': 'Mụn sẹo uốn', |
|
'milium': 'Mụn mili', |
|
'papular': 'Mụn nhỏ', |
|
'purulent': 'Mụn mủ', |
|
'sebo-crystan-conglo': 'Mụn bã đen kết tủa', |
|
'syringoma': 'Mụn nang mồ hôi', |
|
'whitehead': 'Mụn đầu trắng' |
|
} |
|
|
|
|
|
transform = transforms.Compose([ |
|
transforms.Resize((224, 224)), |
|
transforms.ToTensor(), |
|
transforms.Normalize([0.485, 0.456, 0.406], |
|
[0.229, 0.224, 0.225]) |
|
]) |
|
|
|
def detect_and_classify(image, image_size=640, conf_threshold=0.4, iou_threshold=0.5): |
|
""" |
|
Hàm này nhận vào một ảnh, phát hiện các vùng mụn bằng YOLO, |
|
phân loại từng vùng mụn bằng mô hình ResNet18, và trả về ảnh đã được |
|
annotate cùng với các thông tin liên quan. |
|
""" |
|
|
|
pil_image = Image.open(image).convert("RGB") |
|
|
|
|
|
results = yolo_model.predict(pil_image, conf=conf_threshold, iou=iou_threshold, imgsz=image_size) |
|
boxes = results[0].boxes |
|
num_boxes = len(boxes) |
|
|
|
if num_boxes == 0: |
|
severity = "Tốt" |
|
recommendation = "Làn da bạn khá ổn! Tiếp tục duy trì thói quen chăm sóc da." |
|
return pil_image, f"Tình trạng mụn: {severity}", recommendation, "Không phát hiện mụn." |
|
|
|
|
|
xyxy = boxes.xyxy.detach().cpu().numpy().astype(int) |
|
confidences = boxes.conf.detach().cpu().numpy() |
|
class_ids = boxes.cls.detach().cpu().numpy().astype(int) |
|
|
|
|
|
draw = ImageDraw.Draw(pil_image) |
|
try: |
|
font = ImageFont.truetype("arial.ttf", 15) |
|
except IOError: |
|
font = ImageFont.load_default() |
|
|
|
|
|
classified_results = [] |
|
|
|
for i, (box, cls_id, conf) in enumerate(zip(xyxy, class_ids, confidences), start=1): |
|
x1, y1, x2, y2 = box |
|
class_name_en = class_labels[cls_id] |
|
class_name_vn = class_mapping.get(class_name_en, class_name_en) |
|
|
|
|
|
crop = pil_image.crop((x1, y1, x2, y2)) |
|
img_transformed = transform(crop).unsqueeze(0).to(device) |
|
|
|
with torch.no_grad(): |
|
output = classification_model(img_transformed) |
|
probabilities = torch.softmax(output, dim=1) |
|
top_prob, top_class = probabilities.topk(1, dim=1) |
|
top_prob = top_prob.item() |
|
top_class = top_class.item() |
|
class_name_en = class_labels[top_class] |
|
class_name_vn = class_mapping.get(class_name_en, class_name_en) |
|
|
|
|
|
label = f"#{i}: {class_name_en} ({class_name_vn}) ({top_prob:.2f})" |
|
|
|
text_bbox = draw.textbbox((0, 0), label, font=font) |
|
text_w = text_bbox[2] - text_bbox[0] |
|
text_h = text_bbox[3] - text_bbox[1] |
|
draw.rectangle([x1, y1, x2, y2], outline="red", width=2) |
|
draw.rectangle([x1, y1 - text_h, x1 + text_w, y1], fill="red") |
|
draw.text((x1, y1 - text_h), label, fill="white", font=font) |
|
|
|
|
|
classified_results.append((i, class_name_en, class_name_vn)) |
|
|
|
|
|
if num_boxes > 20: |
|
severity = "Nặng" |
|
recommendation = "Bạn nên đến gặp bác sĩ da liễu và sử dụng liệu trình trị mụn chuyên sâu." |
|
elif 10 <= num_boxes <= 20: |
|
severity = "Trung bình" |
|
recommendation = "Duy trì skincare đều đặn với sữa rửa mặt dịu nhẹ và dưỡng ẩm." |
|
else: |
|
severity = "Tốt" |
|
recommendation = "Làn da bạn khá ổn! Tiếp tục duy trì thói quen chăm sóc da." |
|
|
|
|
|
acne_types_str = "Danh sách mụn phát hiện:\n" |
|
for idx, cname_en, cname_vn in classified_results: |
|
acne_types_str += f"Mụn #{idx}: {cname_en} ({cname_vn})\n" |
|
|
|
return pil_image, f"Tình trạng mụn: {severity}", recommendation, acne_types_str |
|
|
|
|
|
description_md = """ |
|
## Ứng Dụng Nhận Diện và Phân Loại Mụn Bằng YOLO và ResNet18 |
|
|
|
1. **Phát hiện mụn:** Sử dụng mô hình YOLO để phát hiện các vùng mụn trên khuôn mặt. |
|
2. **Phân loại mụn:** Sử dụng mô hình ResNet18 đã được huấn luyện để phân loại từng vùng mụn thành 12 loại khác nhau bao gồm: Sẹo mụn trứng cá, Mụn đầu đen, Mụn nang, Mụn phẳng, Viêm nang lông, Sẹo mùm, Mụn mili, Mụn sần nhỏ, Mụn mủ, Mụn bã đen kết tủa, Mụn nang mồ hôi, Mụn đầu trắng |
|
3. **Hiển thị kết quả:** Ảnh sau khi xử lý sẽ hiển thị các bounding boxes, nhãn tiếng Anh và tiếng Việt của loại mụn, cùng với độ chính xác của mỗi phân loại. |
|
4. **Đánh giá tình trạng da:** Cung cấp đánh giá tổng quát về tình trạng da và khuyến nghị tương ứng dựa trên số lượng mụn được phát hiện. |
|
""" |
|
|
|
|
|
inputs = [ |
|
gr.Image(type="filepath", label="Ảnh Khuôn Mặt"), |
|
gr.Slider(minimum=320, maximum=1280, step=32, value=640, label="Kích thước ảnh (Image Size)"), |
|
gr.Slider(minimum=0, maximum=1, step=0.05, value=0.4, label="Ngưỡng Confidence"), |
|
gr.Slider(minimum=0, maximum=1, step=0.05, value=0.5, label="Ngưỡng IOU") |
|
] |
|
|
|
outputs = [ |
|
gr.Image(type="pil", label="Ảnh Sau Khi Xử Lý"), |
|
gr.Textbox(label="Tình Trạng Mụn"), |
|
gr.Textbox(label="Khuyến Nghị"), |
|
gr.Textbox(label="Loại Mụn Phát Hiện") |
|
] |
|
|
|
|
|
app = gr.Interface( |
|
fn=detect_and_classify, |
|
inputs=inputs, |
|
outputs=outputs, |
|
title="YOLO + ResNet18 Phát Hiện và Phân Loại Mụn", |
|
description=description_md |
|
) |
|
|
|
|
|
app.launch(share=True) |
|
|