Alesteba commited on
Commit
865fc80
·
1 Parent(s): e2626bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -5
app.py CHANGED
@@ -11,9 +11,11 @@ class_map = ClassMap(['raccoon','banana'])
11
 
12
  size = 384
13
 
14
- model_2 = models.torchvision.retinanet.model(
15
- backbone=models.torchvision.retinanet.backbones.resnet18_fpn (pretrained=True),
16
- num_classes=len(class_map),
 
 
17
  )
18
 
19
  # load from model_repo:
@@ -21,7 +23,7 @@ model_2 = models.torchvision.retinanet.model(
21
  # from huggingface_hub import hf_hub_download
22
  # hf_hub_download(repo_id="Alesteba/deep_model_02", filename="retinanet_racoon.pth")
23
 
24
- state_dict = torch.load('./retinanet_racoon.pth', map_location=torch.device('cpu'))
25
 
26
  model_2.load_state_dict(state_dict)
27
 
@@ -33,7 +35,7 @@ def predict(img):
33
 
34
  img = PIL.Image.fromarray(img, "RGB")
35
 
36
- pred_dict_2 = models.torchvision.retinanet.fastai.end2end_detect(
37
 
38
  img,
39
  infer_tfms,
 
11
 
12
  size = 384
13
 
14
+ model_type = models.mmdet.retinanet
15
+
16
+ model_2 = model_type.model(
17
+ backbone= model_type.backbones.swin_t_p4_w7_fpn_1x_coco (pretrained=True),
18
+ num_classes=len(class_map)
19
  )
20
 
21
  # load from model_repo:
 
23
  # from huggingface_hub import hf_hub_download
24
  # hf_hub_download(repo_id="Alesteba/deep_model_02", filename="retinanet_racoon.pth")
25
 
26
+ state_dict = torch.load('./mmdet_racoon.pth', map_location=torch.device('cpu'))
27
 
28
  model_2.load_state_dict(state_dict)
29
 
 
35
 
36
  img = PIL.Image.fromarray(img, "RGB")
37
 
38
+ pred_dict_2 = model_type.fastai.end2end_detect(
39
 
40
  img,
41
  infer_tfms,