mayhug's picture
Update app.py
6d168eb
raw
history blame
1.45 kB
import json
import gradio
import torch
from gradio import inputs, outputs
from PIL import Image
from torchvision import transforms
# Pretrained Danbooru ResNet-50 fetched through torch.hub, mapped onto the CPU.
model = torch.hub.load("RF5/danbooru-pretrained", "resnet50", map_location="cpu")
# Inference only: put the network in evaluation mode (affects layers such as
# BatchNorm/Dropout that behave differently while training).
model.eval()

# Tag vocabulary; `main` looks a tag name up by the model's output index.
with open("./tags.json", "rt", encoding="utf-8") as tags_file:
    tags = json.load(tags_file)
# The preprocessing pipeline is identical for every request, so build it once
# at import time rather than on every call to `main`.
# Resize(360) scales the shorter image edge to 360 px; the mean/std values are
# presumably the statistics the pretrained model expects -- TODO confirm.
_PREPROCESS = transforms.Compose(
    [
        transforms.Resize(360),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.7137, 0.6628, 0.6519], std=[0.2970, 0.3017, 0.2979]
        ),
    ]
)


def main(input_image: Image.Image, threshold: float):
    """Predict Danbooru tags for one image.

    Args:
        input_image: PIL image supplied by the Gradio image input.
        threshold: tags whose sigmoid probability is not strictly greater
            than this value are dropped.

    Returns:
        Dict mapping tag name to probability (Python float) for every tag
        above ``threshold``, inserted in descending order of probability.
    """
    # ToTensor keeps the source mode's channel count (RGBA -> 4, L/P -> 1),
    # which cannot pass the 3-channel Normalize; force RGB first.
    input_tensor = _PREPROCESS(input_image.convert("RGB"))
    input_batch = input_tensor.unsqueeze(0)  # the model expects a mini-batch
    with torch.no_grad():
        # Assumes the model returns a (1, num_tags) tensor: unpacking over the
        # batch dimension keeps the single image's score row -- confirm.
        output, *_ = model(input_batch)
    probs = torch.sigmoid(output)
    # Every score above the threshold sorts ahead of every score at or below
    # it, so the first `n_kept` indices of the descending sort are exactly the
    # kept tags, already in display order.
    n_kept = int((probs > threshold).sum())
    top_indices = probs.argsort(descending=True)[:n_kept].tolist()
    return {tags[i]: probs[i].item() for i in top_indices}
# Gradio UI: an image upload plus a confidence threshold feeding `main`,
# whose tag -> probability dict is shown by a Label output.
image = inputs.Image(label="Upload your image here!", type="pil")
threshold = inputs.Slider(
    label="Hide tags with confidence under", maximum=1, minimum=0, default=0.2
)
# The legacy Label component documents type values "auto", "label" and
# "confidences"; the previous "confidence" (singular) matched none of them.
# "auto" renders a dict return as label -> confidence scores.
labels = outputs.Label(type="auto")
interface = gradio.Interface(main, inputs=[image, threshold], outputs=[labels])
interface.launch()