import json
from pprint import pprint
import torch
import torch.hub
import gradio as gr  # gradio>=3 removed the gradio.inputs/gradio.outputs namespaces
from PIL import Image
from torchvision import transforms
# torch.hub downloads checkpoints with their saved device mapping; patch the
# loader so weights are always remapped to CPU (no GPU is assumed here).
real_load = torch.hub.load_state_dict_from_url


def load_state_dict_from_url(*args, **kwargs):
    kwargs["map_location"] = "cpu"
    return real_load(*args, **kwargs)


torch.hub.load_state_dict_from_url = load_state_dict_from_url
# Pretrained Danbooru tagger: a ResNet-50 trained for multi-label tag
# prediction, fetched via torch.hub.
model = torch.hub.load("RF5/danbooru-pretrained", "resnet50")
model.eval()

# Index-aligned tag names: tags[i] is the label for output logit i.
with open("./tags.json", "rt", encoding="utf-8") as f:
    tags = json.load(f)
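
# Optional sanity check (an assumption, not part of the original app): the
# published RF5/danbooru-pretrained resnet50 predicts one logit per entry in
# tags.json, reportedly 6000 tags; verify against your copy of the file:
# assert len(tags) == 6000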
def main(input_image: Image.Image, threshold: float):
    # Preprocessing from the model card: resize, tensorize, and normalize
    # with the Danbooru dataset statistics.
    preprocess = transforms.Compose(
        [
            transforms.Resize(360),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.7137, 0.6628, 0.6519], std=[0.2970, 0.3017, 0.2979]
            ),
        ]
    )
    input_tensor = preprocess(input_image)
    input_batch = input_tensor.unsqueeze(
        0
    )  # create a mini-batch as expected by the model

    with torch.no_grad():
        # The batch holds a single image, so unpack its lone row of logits.
        output, *_ = model(input_batch)

    # Multi-label task: apply an independent sigmoid per tag rather than a
    # softmax over all tags.
    probs = torch.sigmoid(output)

    # Count how many tags clear the threshold, then take that many of the
    # highest-confidence indices.
    results = probs[probs > threshold]
    inds = probs.argsort(descending=True)
    tag_confidences = {}
    for index in inds[0 : len(results)]:
        tag_confidences[tags[index]] = probs[index].item()
    pprint(tag_confidences)
    return tag_confidences
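
# Example (hypothetical file path, not part of the original app): the tagger
# can be smoke-tested without the UI:
#   from PIL import Image
#   print(main(Image.open("sample.png").convert("RGB"), threshold=0.5))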
image = gr.Image(label="Upload your image here!", type="pil")
threshold = gr.Slider(
    label="Hide tags with confidence under", minimum=0, maximum=1, value=0.2
)
labels = gr.Label(label="Tags")
interface = gr.Interface(main, inputs=[image, threshold], outputs=[labels])
interface.launch()
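
# Note: on Hugging Face Spaces the server host/port are supplied by the
# platform; when running locally, launch(share=True) prints a temporary
# public URL if you want to share the demo.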