import json

import gradio
import torch
from PIL import Image
from torchvision import transforms

# Load the pretrained Danbooru tagging model (ResNet-50) from torch.hub.
model = torch.hub.load(
    "RF5/danbooru-pretrained",
    "resnet50",
    map_location="cpu",
)
model.eval()

# Tag names, indexed to match the model's output vector.
with open("./tags.json", "rt", encoding="utf-8") as f:
    tags = json.load(f)

def main(input_image: Image.Image, threshold: float):
    # Resize and normalize with the statistics the model was trained with.
    preprocess = transforms.Compose(
        [
            transforms.Resize(360),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.7137, 0.6628, 0.6519], std=[0.2970, 0.3017, 0.2979]
            ),
        ]
    )
    input_tensor = preprocess(input_image)
    input_batch = input_tensor.unsqueeze(
        0
    )  # create a mini-batch as expected by the model

    with torch.no_grad():
        output, *_ = model(input_batch)

    # Multi-label prediction: apply a per-tag sigmoid and keep only the tags
    # whose confidence exceeds the user-selected threshold.
    probs = torch.sigmoid(output)
    results = probs[probs > threshold]
    inds = probs.argsort(descending=True)
    tag_confidences = {}
    for index in inds[0 : len(results)]:
        tag_confidences[tags[index]] = float(probs[index].cpu().numpy())
    return tag_confidences

# The gradio.inputs / gradio.outputs modules were removed in current Gradio
# releases, so the components are used directly.
image = gradio.Image(label="Upload your image here!", type="pil")
threshold = gradio.Slider(
    label="Hide tags with confidence below", minimum=0, maximum=1, value=0.2
)
labels = gradio.Label(label="Predicted tags")

interface = gradio.Interface(main, inputs=[image, threshold], outputs=[labels])
interface.launch()
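
# --- Optional local smoke test (a sketch, not part of the original Space) ---
# Assumes a local image file named "sample.jpg"; the filename and threshold are
# placeholders. Calling main() directly exercises the model pipeline without
# the Gradio UI. Since interface.launch() above blocks the script, the example
# is left commented out:
#
#     test_image = Image.open("sample.jpg").convert("RGB")
#     predictions = main(test_image, threshold=0.2)
#     for tag, confidence in sorted(predictions.items(), key=lambda kv: -kv[1]):
#         print(f"{tag}: {confidence:.3f}")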