ducdatit2002's picture
Update app.py
b23377c verified
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image, ImageDraw, ImageFont
import numpy as np
from ultralytics import YOLO
import gradio as gr
# Kiểm tra thiết bị
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Sử dụng thiết bị: {device}')
# Định nghĩa kiến trúc mô hình phân loại
class ClassificationModel(nn.Module):
def __init__(self, num_classes=12):
super(ClassificationModel, self).__init__()
self.model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT) # Sử dụng pretrained weights
num_ftrs = self.model.fc.in_features
self.model.fc = nn.Linear(num_ftrs, num_classes) # Thay đổi lớp cuối cùng
def forward(self, x):
return self.model(x)
# Khởi tạo và tải mô hình phân loại
classification_model = ClassificationModel(num_classes=12)
classification_model.load_state_dict(torch.load('classification_state_dict.pt', map_location=device))
classification_model.to(device)
classification_model.eval()
# Tải mô hình YOLO
yolo_model = YOLO('best.pt') # Đảm bảo rằng 'best.pt' nằm trong thư mục hiện tại
# Định nghĩa các lớp mụn
class_labels = [
'acne_scars', 'blackhead', 'cystic', 'flat_wart', 'folliculitis',
'keloid', 'milium', 'papular', 'purulent', 'sebo-crystan-conglo',
'syringoma', 'whitehead'
]
# Ánh xạ tiếng Anh -> tiếng Việt
class_mapping = {
'acne_scars': 'Sẹo mụn',
'blackhead': 'Mụn đầu đen',
'cystic': 'Mụn nang',
'flat_wart': 'Mụn sần phẳng',
'folliculitis': 'Viêm nang lông',
'keloid': 'Mụn sẹo uốn',
'milium': 'Mụn mili',
'papular': 'Mụn nhỏ',
'purulent': 'Mụn mủ',
'sebo-crystan-conglo': 'Mụn bã đen kết tủa',
'syringoma': 'Mụn nang mồ hôi',
'whitehead': 'Mụn đầu trắng'
}
# Định nghĩa các biến đổi dữ liệu cho mô hình phân loại
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])
])
def detect_and_classify(image, image_size=640, conf_threshold=0.4, iou_threshold=0.5):
"""
Hàm này nhận vào một ảnh, phát hiện các vùng mụn bằng YOLO,
phân loại từng vùng mụn bằng mô hình ResNet18, và trả về ảnh đã được
annotate cùng với các thông tin liên quan.
"""
# Mở ảnh và chuyển đổi sang RGB
pil_image = Image.open(image).convert("RGB")
# Dự đoán bằng YOLO
results = yolo_model.predict(pil_image, conf=conf_threshold, iou=iou_threshold, imgsz=image_size)
boxes = results[0].boxes
num_boxes = len(boxes)
if num_boxes == 0:
severity = "Tốt"
recommendation = "Làn da bạn khá ổn! Tiếp tục duy trì thói quen chăm sóc da."
return pil_image, f"Tình trạng mụn: {severity}", recommendation, "Không phát hiện mụn."
# Lấy thông tin bounding boxes
xyxy = boxes.xyxy.detach().cpu().numpy().astype(int) # Toạ độ bounding box
confidences = boxes.conf.detach().cpu().numpy()
class_ids = boxes.cls.detach().cpu().numpy().astype(int)
# Chuẩn bị vẽ
draw = ImageDraw.Draw(pil_image)
try:
font = ImageFont.truetype("arial.ttf", 15)
except IOError:
font = ImageFont.load_default()
# Danh sách để lưu kết quả phân loại
classified_results = []
for i, (box, cls_id, conf) in enumerate(zip(xyxy, class_ids, confidences), start=1):
x1, y1, x2, y2 = box
class_name_en = class_labels[cls_id]
class_name_vn = class_mapping.get(class_name_en, class_name_en)
# Cắt crop vùng mụn
crop = pil_image.crop((x1, y1, x2, y2))
img_transformed = transform(crop).unsqueeze(0).to(device)
with torch.no_grad():
output = classification_model(img_transformed)
probabilities = torch.softmax(output, dim=1)
top_prob, top_class = probabilities.topk(1, dim=1)
top_prob = top_prob.item()
top_class = top_class.item()
class_name_en = class_labels[top_class]
class_name_vn = class_mapping.get(class_name_en, class_name_en)
# Vẽ bounding box và nhãn
label = f"#{i}: {class_name_en} ({class_name_vn}) ({top_prob:.2f})"
# Sử dụng textbbox thay vì textsize
text_bbox = draw.textbbox((0, 0), label, font=font)
text_w = text_bbox[2] - text_bbox[0]
text_h = text_bbox[3] - text_bbox[1]
draw.rectangle([x1, y1, x2, y2], outline="red", width=2)
draw.rectangle([x1, y1 - text_h, x1 + text_w, y1], fill="red")
draw.text((x1, y1 - text_h), label, fill="white", font=font)
# Thêm kết quả phân loại vào danh sách
classified_results.append((i, class_name_en, class_name_vn))
# Đánh giá tình trạng da dựa trên số lượng mụn
if num_boxes > 20:
severity = "Nặng"
recommendation = "Bạn nên đến gặp bác sĩ da liễu và sử dụng liệu trình trị mụn chuyên sâu."
elif 10 <= num_boxes <= 20:
severity = "Trung bình"
recommendation = "Duy trì skincare đều đặn với sữa rửa mặt dịu nhẹ và dưỡng ẩm."
else:
severity = "Tốt"
recommendation = "Làn da bạn khá ổn! Tiếp tục duy trì thói quen chăm sóc da."
# Liệt kê loại mụn
acne_types_str = "Danh sách mụn phát hiện:\n"
for idx, cname_en, cname_vn in classified_results:
acne_types_str += f"Mụn #{idx}: {cname_en} ({cname_vn})\n"
return pil_image, f"Tình trạng mụn: {severity}", recommendation, acne_types_str
# Mô tả ứng dụng
description_md = """
## Ứng Dụng Nhận Diện và Phân Loại Mụn Bằng YOLO và ResNet18
1. **Phát hiện mụn:** Sử dụng mô hình YOLO để phát hiện các vùng mụn trên khuôn mặt.
2. **Phân loại mụn:** Sử dụng mô hình ResNet18 đã được huấn luyện để phân loại từng vùng mụn thành 12 loại khác nhau bao gồm: Sẹo mụn trứng cá, Mụn đầu đen, Mụn nang, Mụn phẳng, Viêm nang lông, Sẹo mùm, Mụn mili, Mụn sần nhỏ, Mụn mủ, Mụn bã đen kết tủa, Mụn nang mồ hôi, Mụn đầu trắng
3. **Hiển thị kết quả:** Ảnh sau khi xử lý sẽ hiển thị các bounding boxes, nhãn tiếng Anh và tiếng Việt của loại mụn, cùng với độ chính xác của mỗi phân loại.
4. **Đánh giá tình trạng da:** Cung cấp đánh giá tổng quát về tình trạng da và khuyến nghị tương ứng dựa trên số lượng mụn được phát hiện.
"""
# Định nghĩa giao diện Gradio
inputs = [
gr.Image(type="filepath", label="Ảnh Khuôn Mặt"),
gr.Slider(minimum=320, maximum=1280, step=32, value=640, label="Kích thước ảnh (Image Size)"),
gr.Slider(minimum=0, maximum=1, step=0.05, value=0.4, label="Ngưỡng Confidence"),
gr.Slider(minimum=0, maximum=1, step=0.05, value=0.5, label="Ngưỡng IOU")
]
outputs = [
gr.Image(type="pil", label="Ảnh Sau Khi Xử Lý"),
gr.Textbox(label="Tình Trạng Mụn"),
gr.Textbox(label="Khuyến Nghị"),
gr.Textbox(label="Loại Mụn Phát Hiện")
]
# Tạo giao diện Gradio
app = gr.Interface(
fn=detect_and_classify,
inputs=inputs,
outputs=outputs,
title="YOLO + ResNet18 Phát Hiện và Phân Loại Mụn",
description=description_md
)
# Khởi chạy ứng dụng
app.launch(share=True)