# Source: Hugging Face upload by TopdownAI ("Upload 3 files", commit 453c0b9, verified).
import os
import torch
import torch.nn as nn
import timm
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
import gradio as gr
from PIL import Image
import torch.nn.functional as F
# Global configuration shared by the preprocessing pipeline.
CFG = dict(
    IMG_SIZE=224,  # side length of the square images fed to the model
)
class MultiLabelClassificationModel(nn.Module):
    """ConvNeXt-Base backbone (created via timm) followed by a linear head.

    The backbone is created without overriding its classifier, so it keeps its
    own output layer. Its per-image score vector (``self.cnn.num_classes``
    wide) is mapped by ``classification_head`` to ``num_labels`` logits.

    Args:
        num_labels: Number of logits produced per image.
    """

    def __init__(self, num_labels):
        super().__init__()
        # Image feature extractor; pretrained=True downloads the timm weights.
        self.cnn = timm.create_model(
            "timm/convnext_base.clip_laion2b_augreg_ft_in12k_in1k",
            pretrained=True,
            drop_rate=0.05,
            drop_path_rate=0.05,
            in_chans=3,
        )
        # Classification head. Its input width is read from the backbone instead
        # of being hard-coded (previously 1000), so the two cannot drift apart.
        self.classification_head = nn.Linear(self.cnn.num_classes, num_labels)

    def forward(self, images):
        """Return raw logits of shape ``(batch, num_labels)``.

        No activation is applied; callers choose softmax/sigmoid themselves.

        Args:
            images: Batch accepted by the backbone, presumably a float tensor of
                shape ``(batch, 3, H, W)`` (TODO confirm against callers).
        """
        features = self.cnn(images)
        # flatten(start_dim=1) collapses everything but the batch axis. Unlike
        # .view(), it also works when the backbone output is non-contiguous.
        features_flat = torch.flatten(features, start_dim=1)
        return self.classification_head(features_flat)
# Inference-time preprocessing, applied to every input image:
#   1. bicubic resize to a CFG['IMG_SIZE'] x CFG['IMG_SIZE'] square,
#   2. conversion of the PIL image to a float tensor,
#   3. per-channel normalization with the ImageNet mean / std below.
# NOTE(review): the backbone name mentions CLIP pretraining, whose timm config may
# use different statistics; keep these only if they match how checkpoint.tar was trained.
test_transform = transforms.Compose(
    [
        transforms.Resize(
            size=(CFG['IMG_SIZE'],) * 2,
            interpolation=InterpolationMode.BICUBIC,
        ),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Build the network and restore the fine-tuned weights from the local checkpoint.
model = MultiLabelClassificationModel(num_labels=13)  # must equal len(labels) below
# map_location='cpu' lets a checkpoint that was saved from a GPU process load on a
# CPU-only host, instead of failing while deserializing CUDA tensors.
# NOTE(review): torch.load unpickles the file; only ever load trusted checkpoints.
_checkpoint = torch.load('checkpoint.tar', map_location='cpu')
model.load_state_dict(_checkpoint['model_state_dict'])
del _checkpoint  # release any extra state (e.g. optimizer) the checkpoint may hold
model.eval()  # put the model in evaluation mode (disables dropout / drop-path)
# Predefined class names. Position i names the model's i-th output logit.
labels = [
    'Mold',
    'blight',
    'greening',
    'healthy',
    'measles',
    'mildew',
    'mite',
    'rot',
    'rust',
    'scab',
    'scorch',
    'spot',
    'virus',
]
def predict(image_path):
    """Classify a single image file and return one score per label.

    Args:
        image_path: Path to an image file (the Gradio input uses
            ``type='filepath'``). ``None``, i.e. nothing submitted, yields an
            empty result instead of raising.

    Returns:
        dict mapping each name in ``labels`` to its softmax probability as a float.
    """
    if image_path is None:
        return {}
    # The context manager closes the underlying file handle. convert('RGB') maps
    # grayscale, palette and RGBA uploads to the 3 channels that the 3-value
    # Normalize step and the in_chans=3 backbone require.
    with Image.open(image_path) as img:
        batch = test_transform(img.convert('RGB')).unsqueeze(0)  # add batch axis
    with torch.no_grad():
        logits = model(batch)
    # Softmax turns the logits into probabilities that sum to 1 across labels.
    # NOTE(review): this treats labels as mutually exclusive although the model and
    # UI are titled "multi-label"; switch to torch.sigmoid only if the checkpoint
    # was trained with a per-label (e.g. BCE) loss.
    probs = F.softmax(logits, dim=1)[0].tolist()
    return {label: probs[i] for i, label in enumerate(labels)}
# Gradio UI: an image upload (handed to predict as a file path) shown as a
# ranked label/score display.
app = gr.Interface(
    fn=predict,
    inputs=gr.Image(type='filepath'),
    outputs=gr.Label(),
    title='Multi-Label Image Classification',
    description='Automatically classify images into the following categories: ' + ', '.join(labels) + '.',
)

if __name__ == '__main__':
    # Launch only when executed as a script, so importing this module (tests,
    # Gradio reload mode) does not start a second server.
    # share=True asks Gradio for a public share link in addition to the local URL.
    app.launch(share=True)