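# Gradio demo: classifies scanned book page images as front cover, back cover, or other page,
# using a timm Vision Transformer fine-tuned on the categories listed in labels.txt.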
import torch
from torchvision import transforms
import gradio as gr
import timm
# Read the categories
with open("labels.txt", "r") as f:
    categories = [s.strip() for s in f.readlines()]
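# Assumption: 'best_cpu.pt' holds a fine-tuned state_dict for this ViT with len(categories)
# output classes, saved so it can be loaded on CPU (hence map_location below).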
model_ft = timm.create_model('vit_base_patch16_224_in21k', pretrained=True, num_classes=len(categories))
model_path = 'best_cpu.pt'
model_ft.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
model_ft.eval()
# Download example page images (Genji scans served via the University of Tokyo IIIF server).
# Commented out: the files are assumed to already be present under examples/.
# torch.hub.download_url_to_file("https://iiif.dl.itc.u-tokyo.ac.jp/iiif/genji/TIFF/A00_6587/01/01_0001.tif/full/,400/0/default.jpg", "examples/other.jpg")
# torch.hub.download_url_to_file("https://iiif.dl.itc.u-tokyo.ac.jp/iiif/genji/TIFF/A00_6587/01/01_0002.tif/full/,400/0/default.jpg", "examples/front.jpg")
# torch.hub.download_url_to_file("https://iiif.dl.itc.u-tokyo.ac.jp/iiif/genji/TIFF/A00_6587/01/01_0003.tif/full/,400/0/default.jpg", "examples/page.jpg")
# torch.hub.download_url_to_file("https://iiif.dl.itc.u-tokyo.ac.jp/iiif/genji/TIFF/A00_6587/01/01_0009.tif/full/,400/0/default.jpg", "examples/page2.jpg")
# torch.hub.download_url_to_file("https://iiif.dl.itc.u-tokyo.ac.jp/iiif/genji/TIFF/A00_6587/01/01_0032.tif/full/,400/0/default.jpg", "examples/back.jpg")
def inference(input_image):
    preprocess = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    input_tensor = preprocess(input_image)
    input_batch = input_tensor.unsqueeze(0)  # create a mini-batch as expected by the model

    # Move the input and model to GPU for speed if available
    if torch.cuda.is_available():
        input_batch = input_batch.to('cuda')
        model_ft.to('cuda')

    with torch.no_grad():
        output = model_ft(input_batch)

    # The output has unnormalized scores. To get probabilities, run a softmax on it.
    probabilities = torch.nn.functional.softmax(output[0], dim=0)

    # Return a confidence for every category (topk over all classes)
    topk_prob, topk_catid = torch.topk(probabilities, len(categories))
    result = {}
    for i in range(topk_prob.size(0)):
        result[categories[topk_catid[i]]] = topk_prob[i].item()
    return result
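# Example direct call (assumes one of the example files below exists locally):
#   from PIL import Image
#   print(inference(Image.open("examples/front.jpg").convert("RGB")))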
inputs = gr.Image(type='pil')
outputs = gr.Label(num_top_classes=len(categories))
# UI text (Japanese): title = "Classification of front covers, back covers, and other pages";
# description = "A classification model for front covers, back covers, and other pages using a Vision Transformer.";
# article = "We used the following dataset." (the link href and label are still placeholders).
title = "表紙・裏表紙・その他のページの分類"
description = "Vision Transformerを用いた表紙・裏表紙・その他のページの分類モデルです。"
article = "<p style='text-align: center'>次のデータセットを使用しました。<a href='' target='_blank'>あああ</a></p>"
examples = [
    ['examples/other.jpg'],
    ['examples/front.jpg'],
    ['examples/page.jpg'],
    ['examples/page2.jpg'],
    ['examples/back.jpg'],
]
gr.Interface(inference, inputs, outputs, title=title, description=description, article=article, examples=examples, analytics_enabled=False).launch()
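# analytics_enabled=False turns off Gradio usage telemetry; launch() starts the web app
# (on Hugging Face Spaces this script is served as the Space's entry point).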