File size: 1,728 Bytes
26d2f56
d0112ee
26d2f56
 
1fe5e1e
d0112ee
26d2f56
 
 
1fe5e1e
 
 
 
 
 
 
 
 
 
86ee19a
26d2f56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d0112ee
26d2f56
 
 
 
 
 
 
 
e951858
26d2f56
d0112ee
6d168eb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import json
from pprint import pprint

import torch
import torch.hub
from gradio import Interface, inputs, outputs
from PIL import Image
from torchvision import transforms

real_load = torch.hub.load_state_dict_from_url


def load_state_dict_from_url(*args, **kwargs):
    """Delegate to the original ``torch.hub.load_state_dict_from_url`` on the CPU.

    Whatever ``map_location`` the caller passed (if any) is replaced with
    ``"cpu"``, so downloaded checkpoints are mapped onto the CPU device.
    All other positional and keyword arguments are forwarded unchanged.
    """
    forwarded = {**kwargs, "map_location": "cpu"}
    return real_load(*args, **forwarded)


# Install the wrapper globally; presumably the hub entry point used by
# torch.hub.load fetches its weights through this function.
torch.hub.load_state_dict_from_url = load_state_dict_from_url

# Pretrained ResNet-50 tagger fetched through torch.hub; put into evaluation
# mode once, since main() only ever runs it for inference.
model = torch.hub.load("RF5/danbooru-pretrained", "resnet50")
model.eval()

# Tag vocabulary; main() looks tag names up by the model's output indices.
with open("./tags.json", encoding="utf-8") as tags_file:
    tags = json.load(tags_file)


# Preprocessing pipeline for the tagger, built once at import instead of on
# every request. Normalize takes exactly three channel statistics, so main()
# converts inputs to RGB before applying it.
_PREPROCESS = transforms.Compose(
    [
        transforms.Resize(360),
        transforms.ToTensor(),
        # Hard-coded per-channel statistics; presumably the mean/std the
        # checkpoint was trained with -- confirm against the model's source.
        transforms.Normalize(
            mean=[0.7137, 0.6628, 0.6519], std=[0.2970, 0.3017, 0.2979]
        ),
    ]
)


def main(input_image: Image.Image, threshold: float):
    """Return the tags whose predicted probability exceeds ``threshold``.

    Args:
        input_image: Image from the Gradio upload widget. ``None`` (nothing
            uploaded) yields an empty result instead of raising.
        threshold: A tag is kept only if its sigmoid probability is strictly
            greater than this value.

    Returns:
        Dict mapping tag name (from ``tags.json``) to probability, inserted
        in descending-probability order. The dict is also pretty-printed to
        stdout.
    """
    if input_image is None:
        return {}

    # RGBA, grayscale ("L") and palette ("P") images would give ToTensor a
    # channel count that the 3-channel Normalize cannot handle.
    rgb_image = input_image.convert("RGB")
    input_batch = _PREPROCESS(rgb_image).unsqueeze(0)  # add a batch dim of 1

    with torch.no_grad():
        # Unpacking iterates over dim 0, so ``output`` is the single batch row
        # (presumably per-tag logits -- TODO confirm the model's return shape).
        output, *_ = model(input_batch)

    probs = torch.sigmoid(output)

    # Sort once: the entries above the threshold form a prefix of the
    # descending order, so masking keeps exactly those, highest first.
    sorted_probs, sorted_indices = probs.sort(descending=True)
    keep = sorted_probs > threshold
    tag_confidences = {
        tags[index]: confidence
        for index, confidence in zip(
            sorted_indices[keep].tolist(), sorted_probs[keep].tolist()
        )
    }
    pprint(tag_confidences)
    return tag_confidences


# Gradio UI: an image upload and a probability cut-off slider feed main(),
# whose tag -> probability dict is shown as a confidence label list.
image = inputs.Image(label="Upload your image here!", type="pil")
# The slider filters tags (not images): main() drops tags at or below it.
threshold = inputs.Slider(
    label="Hide tags with confidence under", maximum=1, minimum=0, default=0.2
)

labels = outputs.Label(label="Tags", type="confidences")

interface = Interface(main, inputs=[image, threshold], outputs=[labels])
# Launch the app as soon as this module runs (script entry point).
interface.launch()