import json
from pprint import pprint

import torch
import torch.hub
from gradio import Interface, inputs, outputs
from PIL import Image
from torchvision import transforms

# Patch torch.hub so downloaded checkpoints are always mapped to the CPU.
real_load = torch.hub.load_state_dict_from_url


def load_state_dict_from_url(*args, **kwargs):
    kwargs["map_location"] = "cpu"
    return real_load(*args, **kwargs)


torch.hub.load_state_dict_from_url = load_state_dict_from_url

model = torch.hub.load("RF5/danbooru-pretrained", "resnet50")
model.eval()

with open("./tags.json", "rt", encoding="utf-8") as f:
    tags = json.load(f)


def main(input_image: Image.Image, threshold: float):
    preprocess = transforms.Compose(
        [
            transforms.Resize(360),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.7137, 0.6628, 0.6519], std=[0.2970, 0.3017, 0.2979]
            ),
        ]
    )
    input_tensor = preprocess(input_image)
    # Create a mini-batch of size one, as expected by the model.
    input_batch = input_tensor.unsqueeze(0)

    with torch.no_grad():
        output, *_ = model(input_batch)  # unpack the single item in the batch
    probs = torch.sigmoid(output)

    # Keep only the tags whose probability exceeds the threshold, highest first.
    results = probs[probs > threshold]
    inds = probs.argsort(descending=True)
    tag_confidences = {}
    for index in inds[0 : len(results)]:
        tag_confidences[tags[index]] = float(probs[index].cpu().numpy())
    pprint(tag_confidences)
    return tag_confidences


image = inputs.Image(label="Upload your image here!", type="pil")
threshold = inputs.Slider(
    label="Hide tags with confidence under", maximum=1, minimum=0, default=0.2
)
labels = outputs.Label(label="Tags", type="confidences")

interface = Interface(main, inputs=[image, threshold], outputs=[labels])
interface.launch()