Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
import json
|
|
|
2 |
|
3 |
-
import gradio
|
4 |
import torch
|
5 |
import torch.hub
|
6 |
-
from gradio import inputs, outputs
|
7 |
from PIL import Image
|
8 |
from torchvision import transforms
|
9 |
|
@@ -20,6 +20,7 @@ torch.hub.load_state_dict_from_url = load_state_dict_from_url
|
|
20 |
model = torch.hub.load(
|
21 |
"RF5/danbooru-pretrained",
|
22 |
"resnet50",
|
|
|
23 |
)
|
24 |
model.eval()
|
25 |
|
@@ -52,6 +53,7 @@ def main(input_image: Image.Image, threshold: float):
|
|
52 |
tag_confidences = {}
|
53 |
for index in inds[0 : len(results)]:
|
54 |
tag_confidences[tags[index]] = float(probs[index].cpu().numpy())
|
|
|
55 |
return tag_confidences
|
56 |
|
57 |
|
@@ -62,5 +64,5 @@ threshold = inputs.Slider(
|
|
62 |
|
63 |
labels = outputs.Label(type="confidence")
|
64 |
|
65 |
-
interface =
|
66 |
interface.launch()
|
|
|
1 |
import json
|
2 |
+
from pprint import pprint
|
3 |
|
|
|
4 |
import torch
|
5 |
import torch.hub
|
6 |
+
from gradio import Interface, inputs, outputs
|
7 |
from PIL import Image
|
8 |
from torchvision import transforms
|
9 |
|
|
|
20 |
model = torch.hub.load(
|
21 |
"RF5/danbooru-pretrained",
|
22 |
"resnet50",
|
23 |
+
map_location="cpu",
|
24 |
)
|
25 |
model.eval()
|
26 |
|
|
|
53 |
tag_confidences = {}
|
54 |
for index in inds[0 : len(results)]:
|
55 |
tag_confidences[tags[index]] = float(probs[index].cpu().numpy())
|
56 |
+
pprint(tag_confidences)
|
57 |
return tag_confidences
|
58 |
|
59 |
|
|
|
64 |
|
65 |
labels = outputs.Label(type="confidence")
|
66 |
|
67 |
+
interface = Interface(main, inputs=[image, threshold], outputs=[labels])
|
68 |
interface.launch()
|