edgilr commited on
Commit
6548833
·
verified ·
1 Parent(s): e370ad2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -41
app.py CHANGED
@@ -1,47 +1,19 @@
1
- from gradio.outputs import Label
2
  from icevision.all import *
3
- from icevision.models.checkpoint import *
4
  import PIL
5
  import gradio as gr
6
- import os
7
 
8
- # Load model
9
- checkpoint_path = "fasterRCNNKangaroo.pth"
10
- checkpoint_and_model = model_from_checkpoint(checkpoint_path)
11
- model = checkpoint_and_model["model"]
12
- model_type = checkpoint_and_model["model_type"]
13
- class_map = checkpoint_and_model["class_map"]
14
 
15
- # Transforms
16
- img_size = checkpoint_and_model["img_size"]
17
- valid_tfms = tfms.A.Adapter([*tfms.A.resize_and_pad(img_size), tfms.A.Normalize()])
 
 
 
18
 
19
- # Populate examples in Gradio interface
20
- examples = [
21
- ['00004.jpg'],
22
- ['00083.jpg'],
23
- ['00119.jpg']
24
- ]
25
-
26
- def show_preds(input_image):
27
- img = PIL.Image.fromarray(input_image, "RGB")
28
- pred_dict = model_type.end2end_detect(img, valid_tfms, model,
29
- class_map=class_map,
30
- detection_threshold=0.5,
31
- display_label=False,
32
- display_bbox=True,
33
- return_img=True,
34
- font_size=16,
35
- label_color="#FF59D6")
36
- return pred_dict["img"]
37
-
38
- gr_interface = gr.Interface(
39
- fn=show_preds,
40
- inputs=["image"],
41
- outputs=[gr.outputs.Image(type="pil", label="FasterRCNN Inference")],
42
- title="Kangaroo Object Detector",
43
- description="",
44
- examples=examples,
45
- )
46
- gr_interface.launch(inline=False, share=False, debug=True)
47
-
 
 
1
  from icevision.all import *
 
2
  import PIL
3
  import gradio as gr
 
4
 
5
+ class_map = ClassMap(['kangaroo'])
6
+ model = models.torchvision.faster_rcnn.model(backbone=models.torchvision.faster_rcnn.backbones.resnet50_fpn,
7
+ num_classes=len(class_map))
8
+ state_dict = torch.load('fasterRCNNKangaroo.pth')
9
+ model.load_state_dict(state_dict)
 
10
 
11
+ infer_tfms = tfms.A.Adapter([*tfms.A.resize_and_pad(size),tfms.A.Normalize()])
12
+ size = 384
13
+ def predict(img):
14
+ img = PILImage.create(img)
15
+ pred_dict = models.torchvision.faster_rcnn.end2end_detect(img, infer_tfms, model.to("cpu"), class_map=class_map, detection_threshold=0.5)
16
+ return pred_dict['img']
17
 
18
+ # Creamos la interfaz y la lanzamos.
19
+ gr.Interface(fn=predict, inputs=gr.inputs.Image(shape=(128, 128)), outputs=gr.outputs.Label(num_top_classes=3),examples=['00004.jpg','00083.jpg', '00119.jpg']).launch(share=False)