# Web file-viewer residue captured together with the source (not program code):
# mayhug's picture
# Update app.py
# e951858
# raw
# history blame
# 1.73 kB
import json
from pprint import pprint
import torch
import torch.hub
from gradio import Interface, inputs, outputs
from PIL import Image
from torchvision import transforms
# Keep a handle on the stock loader so the wrapper below can delegate to it.
real_load = torch.hub.load_state_dict_from_url


def load_state_dict_from_url(*args, **kwargs):
    """Delegate to the stock torch.hub loader with ``map_location`` pinned to CPU.

    Any ``map_location`` keyword supplied by the caller is replaced, so the
    underlying loader is always asked to map the checkpoint onto the CPU.
    All other positional and keyword arguments are forwarded unchanged.
    """
    forced = dict(kwargs, map_location="cpu")
    return real_load(*args, **forced)


# Install the wrapper so code that looks the loader up on torch.hub gets it.
torch.hub.load_state_dict_from_url = load_state_dict_from_url
# Fetch the "resnet50" entry point of the RF5/danbooru-pretrained hub repo and
# switch it to evaluation mode (Module.eval() returns the module itself).
model = torch.hub.load("RF5/danbooru-pretrained", "resnet50").eval()

# Tag names; main() indexes this with positions taken from the model output.
with open("./tags.json", mode="rt", encoding="utf-8") as f:
    tags = json.load(f)
# Built once at import time: the pipeline holds no per-request state, so
# there is no need to reassemble it on every call.
_PREPROCESS = transforms.Compose(
    [
        transforms.Resize(360),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.7137, 0.6628, 0.6519], std=[0.2970, 0.3017, 0.2979]
        ),
    ]
)


def main(input_image: Image.Image, threshold: float) -> dict:
    """Tag an image and return the tags whose confidence exceeds *threshold*.

    Args:
        input_image: PIL image to classify. It is converted to RGB first, so
            RGBA, palette and grayscale images are accepted as well.
        threshold: Exclusive lower bound on the sigmoid confidence a tag
            needs in order to be reported.

    Returns:
        A dict mapping tag name to confidence (a Python float), inserted in
        descending order of confidence.
    """
    # Normalize() is configured with three channel statistics, so force
    # three channels regardless of the mode of the incoming image.
    rgb_image = input_image.convert("RGB")
    # unsqueeze(0) adds the leading batch dimension the model is fed with.
    input_batch = _PREPROCESS(rgb_image).unsqueeze(0)

    with torch.no_grad():
        # Keep the first entry of the model output (a single image was fed).
        output = model(input_batch)[0]
    probs = torch.sigmoid(output)

    # One descending sort yields both the confidences and the tag indices;
    # every entry above the threshold then sits at the front.
    sorted_probs, sorted_inds = probs.sort(descending=True)
    keep = int((sorted_probs > threshold).sum())
    tag_confidences = dict(
        zip(
            (tags[i] for i in sorted_inds[:keep].tolist()),
            sorted_probs[:keep].tolist(),
        )
    )

    pprint(tag_confidences)  # the result is also echoed to stdout
    return tag_confidences
# Input widgets: the picture to tag and the confidence cut-off passed to main().
image = inputs.Image(type="pil", label="Upload your image here!")
threshold = inputs.Slider(
    minimum=0,
    maximum=1,
    default=0.2,
    label="Hide images confidence under",
)

# Output widget fed with the tag -> confidence mapping that main() returns.
labels = outputs.Label(type="confidences", label="Tags")

# Wire the widgets to main() and start serving the app.
interface = Interface(fn=main, inputs=[image, threshold], outputs=[labels])
interface.launch()