danhsf commited on
Commit
2be9628
·
1 Parent(s): 4e6d303

fixing app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -10
app.py CHANGED
@@ -6,27 +6,27 @@ from timm import create_model
6
  from timm.data import resolve_data_config
7
  from timm.data.transforms_factory import create_transform
8
 
9
- IMAGENET_1K_URL = "https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt"
10
- LABELS = requests.get(IMAGENET_1K_URL).text.strip().split('\n')
11
 
12
  model = create_model('resnet50', pretrained=True)
13
 
14
  transform = create_transform(
15
- **resolve_data_config({},model=model)
16
  )
17
  model.eval()
18
 
19
  def predict_fn(img):
20
  img = img.convert('RGB')
21
- img = transform(img).unsqueze(0)
22
 
23
- with torch._nograd():
24
  out = model(img)
 
 
25
 
26
- probabilities = torch.nn.functional.softmax(out[0], dim=0)
27
 
28
- values, indices = torch.topk(probabilities,k=5)
29
 
30
- return {LABELS[i]: v.item() for i,v in zip(indices,values)}
31
-
32
- gr.Interface(predict_fn,gr.inputs.Image(type='pil'), outputs='label').launch()
 
6
  from timm.data import resolve_data_config
7
  from timm.data.transforms_factory import create_transform
8
 
9
+ IMAGENET_1k_URL = "https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt"
10
+ LABELS = requests.get(IMAGENET_1k_URL).text.strip().split('\n')
11
 
12
  model = create_model('resnet50', pretrained=True)
13
 
14
  transform = create_transform(
15
+ **resolve_data_config({}, model=model)
16
  )
17
  model.eval()
18
 
19
  def predict_fn(img):
20
  img = img.convert('RGB')
21
+ img = transform(img).unsqueeze(0)
22
 
23
+ with torch.no_grad():
24
  out = model(img)
25
+
26
+ probabilites = torch.nn.functional.softmax(out[0], dim=0)
27
 
28
+ values, indices = torch.topk(probabilites, k=5)
29
 
30
+ return {LABELS[i]: v.item() for i, v in zip(indices, values)}
31
 
32
+ gr.Interface(predict_fn, gr.inputs.Image(type='pil'), outputs='label').launch()