ulichovick commited on
Commit
dd43163
·
verified ·
1 Parent(s): 6f0ea92

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -5,6 +5,7 @@ import gradio as gr
5
 
6
  with open("birds.csv", "r") as r:
7
  birds = list(csv.reader(r, delimiter=","))
 
8
 
9
  config = {"num_channels": 3, "hidden_units": 256, "num_classes": 525, "labels": birds}
10
 
@@ -20,7 +21,7 @@ def predice(imagen):
20
  transforms.Resize(size=(224, 224)),
21
  ])
22
 
23
- target_image = io.read_image(imagen).type(torch.float32)
24
  target_image = target_image / 255.
25
  target_image = usr_img_transform(target_image)
26
 
@@ -29,12 +30,12 @@ def predice(imagen):
29
  target_image = target_image.unsqueeze(dim=0)
30
  target_image_pred = model(target_image)
31
  target_image_pred_label = torch.argmax(target_image_pred, dim=1)
32
- label = config["labels"][target_image_pred_label]
33
- return label
34
 
35
  gr.Interface(
36
  predice,
37
- inputs=gr.Image(label="gimme da bird"),
38
  outputs="label",
39
  title=titulo,
40
  description=desc,
 
5
 
6
  with open("birds.csv", "r") as r:
7
  birds = list(csv.reader(r, delimiter=","))
8
+ birds = birds[0]
9
 
10
  config = {"num_channels": 3, "hidden_units": 256, "num_classes": 525, "labels": birds}
11
 
 
21
  transforms.Resize(size=(224, 224)),
22
  ])
23
 
24
+ target_image = io.read_image(str(imagen)).type(torch.float32)
25
  target_image = target_image / 255.
26
  target_image = usr_img_transform(target_image)
27
 
 
30
  target_image = target_image.unsqueeze(dim=0)
31
  target_image_pred = model(target_image)
32
  target_image_pred_label = torch.argmax(target_image_pred, dim=1)
33
+ label = config["labels"][target_image_pred_label.item()]
34
+ return str(label)
35
 
36
  gr.Interface(
37
  predice,
38
+ inputs=gr.Image(label="gimme da bird", type="filepath"),
39
  outputs="label",
40
  title=titulo,
41
  description=desc,