Update app.py
Browse files
app.py
CHANGED
@@ -11,9 +11,11 @@ class_map = ClassMap(['raccoon','banana'])
|
|
11 |
|
12 |
size = 384
|
13 |
|
14 |
-
|
15 |
-
|
16 |
-
|
|
|
|
|
17 |
)
|
18 |
|
19 |
# load from model_repo:
|
@@ -21,7 +23,7 @@ model_2 = models.torchvision.retinanet.model(
|
|
21 |
# from huggingface_hub import hf_hub_download
|
22 |
# hf_hub_download(repo_id="Alesteba/deep_model_02", filename="retinanet_racoon.pth")
|
23 |
|
24 |
-
state_dict = torch.load('./
|
25 |
|
26 |
model_2.load_state_dict(state_dict)
|
27 |
|
@@ -33,7 +35,7 @@ def predict(img):
|
|
33 |
|
34 |
img = PIL.Image.fromarray(img, "RGB")
|
35 |
|
36 |
-
pred_dict_2 =
|
37 |
|
38 |
img,
|
39 |
infer_tfms,
|
|
|
11 |
|
12 |
size = 384
|
13 |
|
14 |
+
model_type = models.mmdet.retinanet
|
15 |
+
|
16 |
+
model_2 = model_type.model(
|
17 |
+
backbone= model_type.backbones.swin_t_p4_w7_fpn_1x_coco (pretrained=True),
|
18 |
+
num_classes=len(class_map)
|
19 |
)
|
20 |
|
21 |
# load from model_repo:
|
|
|
23 |
# from huggingface_hub import hf_hub_download
|
24 |
# hf_hub_download(repo_id="Alesteba/deep_model_02", filename="retinanet_racoon.pth")
|
25 |
|
26 |
+
state_dict = torch.load('./mmdet_racoon.pth', map_location=torch.device('cpu'))
|
27 |
|
28 |
model_2.load_state_dict(state_dict)
|
29 |
|
|
|
35 |
|
36 |
img = PIL.Image.fromarray(img, "RGB")
|
37 |
|
38 |
+
pred_dict_2 = model_type.fastai.end2end_detect(
|
39 |
|
40 |
img,
|
41 |
infer_tfms,
|