|
import functools
import json
import os
from io import BytesIO

import gradio as gr
import matplotlib.pyplot as plt
import torch
from PIL import Image
from torchvision import transforms
|
|
|
from vit_model import vit_base_patch16_224_in21k as create_model |
|
|
|
# Trained artifacts. Raw strings keep the Windows backslashes literal; the
# previous non-raw literals only worked because "\m", "\V", "\p" happen to be
# unrecognized escapes (which newer Python versions warn about).
CLASS_INDICES_PATH = r"F:\mushroom_project\VIT\class_indices.json"
MODEL_WEIGHT_PATH = r"F:\mushroom_project\VIT\pretrain_30_weights\best_model.pth"
NUM_CLASSES = 370
TOP_K = 5


@functools.lru_cache(maxsize=None)
def _load_class_indices(json_path=CLASS_INDICES_PATH):
    """Load the class-index -> class-name mapping once and reuse it.

    Raises:
        FileNotFoundError: if ``json_path`` does not exist.
    """
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if not os.path.exists(json_path):
        raise FileNotFoundError("file: '{}' does not exist.".format(json_path))
    with open(json_path, "r") as f:
        return json.load(f)


@functools.lru_cache(maxsize=None)
def _load_model(device):
    """Build the ViT, load the trained weights onto *device*, set eval mode.

    Cached per device so the checkpoint is read from disk once per process
    instead of on every Gradio request.
    """
    model = create_model(num_classes=NUM_CLASSES, has_logits=False).to(device)
    model.load_state_dict(torch.load(MODEL_WEIGHT_PATH, map_location=device))
    model.eval()
    return model


def _preprocess(img):
    """Turn a PIL image into a normalized (1, 3, 224, 224) float tensor."""
    data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
    # Force 3 channels: RGBA / palette / grayscale uploads would otherwise
    # produce a 4- or 1-channel tensor the 3-channel Normalize cannot handle.
    return torch.unsqueeze(data_transform(img.convert("RGB")), dim=0)


def _format_top_k(probs, class_indict, k=TOP_K):
    """Render the k most probable classes, one "class: ... prob: ..." line each.

    Labels are looked up by tensor index (``str(idx)``), not by the position of
    keys in the JSON dict, so the result stays correct even when the file's
    keys are not stored in index order.
    """
    k = min(k, probs.numel())  # fewer classes than k -> report them all
    values, indices = torch.topk(probs, k)
    return "".join(
        "class: {:10} prob: {:.3}\n".format(class_indict[str(idx)], prob)
        for prob, idx in zip(values.tolist(), indices.tolist())
    )


def classify_image(img):
    """Classify a mushroom photo and report the most likely classes.

    Args:
        img: PIL image supplied by the Gradio ``Image(type='pil')`` input,
            or None when the user submits without uploading.

    Returns:
        str: up to ``TOP_K`` lines of the form "class: <name> prob: <p>",
        ordered from most to least probable (each line newline-terminated).

    Raises:
        FileNotFoundError: if the class-index JSON file is missing.
    """
    if img is None:
        return "Please upload an image."

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    class_indict = _load_class_indices()
    model = _load_model(device)

    with torch.no_grad():
        # Squeeze the batch dimension so softmax runs over the class axis
        # (assumes the model returns (1, NUM_CLASSES) logits — TODO confirm).
        output = torch.squeeze(model(_preprocess(img).to(device))).cpu()
        predict = torch.softmax(output, dim=0)

    return _format_top_k(predict, class_indict)
|
|
|
|
|
# Gradio UI wrapping classify_image: one PIL image in, top-k text out.
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type='pil'),
    outputs=gr.Textbox(),
    title="Mushroom Image Classification",
    description="Upload a mushroom image to classify."
)

if __name__ == "__main__":
    # Guarded so importing this module (e.g. to reuse classify_image or
    # iface) does not start a web server as a side effect.
    iface.launch()
|
|