mayhug commited on
Commit
d0112ee
·
1 Parent(s): d584e3b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -1,9 +1,9 @@
1
  import json
 
2
 
3
- import gradio
4
  import torch
5
  import torch.hub
6
- from gradio import inputs, outputs
7
  from PIL import Image
8
  from torchvision import transforms
9
 
@@ -20,6 +20,7 @@ torch.hub.load_state_dict_from_url = load_state_dict_from_url
20
  model = torch.hub.load(
21
  "RF5/danbooru-pretrained",
22
  "resnet50",
 
23
  )
24
  model.eval()
25
 
@@ -52,6 +53,7 @@ def main(input_image: Image.Image, threshold: float):
52
  tag_confidences = {}
53
  for index in inds[0 : len(results)]:
54
  tag_confidences[tags[index]] = float(probs[index].cpu().numpy())
 
55
  return tag_confidences
56
 
57
 
@@ -62,5 +64,5 @@ threshold = inputs.Slider(
62
 
63
  labels = outputs.Label(type="confidence")
64
 
65
- interface = gradio.Interface(main, inputs=[image, threshold], outputs=[labels])
66
  interface.launch()
 
1
  import json
2
+ from pprint import pprint
3
 
 
4
  import torch
5
  import torch.hub
6
+ from gradio import Interface, inputs, outputs
7
  from PIL import Image
8
  from torchvision import transforms
9
 
 
20
  model = torch.hub.load(
21
  "RF5/danbooru-pretrained",
22
  "resnet50",
23
+ map_location="cpu",
24
  )
25
  model.eval()
26
 
 
53
  tag_confidences = {}
54
  for index in inds[0 : len(results)]:
55
  tag_confidences[tags[index]] = float(probs[index].cpu().numpy())
56
+ pprint(tag_confidences)
57
  return tag_confidences
58
 
59
 
 
64
 
65
  labels = outputs.Label(type="confidence")
66
 
67
+ interface = Interface(main, inputs=[image, threshold], outputs=[labels])
68
  interface.launch()