from io import BytesIO
from icevision import *
import collections
import PIL
import torch
import numpy as np
import torchvision

from classifier import transform_image

import icevision.models.ross.efficientdet

MODEL_TYPE = icevision.models.ross.efficientdet
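
# Two-stage pipeline: the EfficientDet-D0 detector configured below proposes
# 'Waste' bounding boxes, and the separate classifier (transform_image comes
# from the local classifier module) assigns a class to each cropped box.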

def get_model(checkpoint_path):
    extra_args = {}
    backbone = MODEL_TYPE.backbones.d0
    # The efficientdet model requires an img_size parameter
    extra_args['img_size'] = 512

    # num_classes=2 covers the single 'Waste' class plus background; the
    # pretrained backbone weights are overwritten by the checkpoint below.
    model = MODEL_TYPE.model(backbone=backbone(pretrained=True),
                             num_classes=2,
                             **extra_args)

    ckpt = get_checkpoint(checkpoint_path)
    model.load_state_dict(ckpt)

    return model

def get_checkpoint(checkpoint_path):
    # Load on CPU so inference also works on machines without a GPU.
    ckpt = torch.load(checkpoint_path, map_location=torch.device('cpu'))

    # The training wrapper saves weights under keys prefixed with 'model.';
    # strip that prefix so the keys match the bare efficientdet module.
    fixed_state_dict = collections.OrderedDict()

    for k, v in ckpt['state_dict'].items():
        new_k = k[6:]  # drop the leading 'model.' (6 characters)
        fixed_state_dict[new_k] = v

    return fixed_state_dict

def predict(model, image, detection_threshold):
    # Accept either a path on disk or the raw bytes of an uploaded image.
    if isinstance(image, str):
        img = PIL.Image.open(image)
    else:
        img = PIL.Image.open(BytesIO(image))

    # Round-trip through numpy to normalize the input into a plain image.
    img = PIL.Image.fromarray(np.array(img))

    class_map = ClassMap(classes=['Waste'])
    transforms = tfms.A.Adapter([
                    *tfms.A.resize_and_pad(512),
                    tfms.A.Normalize()
                ])

    # Run detection; return_img=True also returns the image so later stages
    # can crop the detected regions from it.
    pred_dict = MODEL_TYPE.end2end_detect(img,
                                          transforms,
                                          model,
                                          class_map=class_map,
                                          detection_threshold=detection_threshold,
                                          return_as_pil_img=False,
                                          return_img=True,
                                          display_bbox=False,
                                          display_score=False,
                                          display_label=False)

    return pred_dict

def prepare_prediction(pred_dict, threshold):
    image = np.array(pred_dict['img'])

    bboxes = pred_dict['detection']['bboxes']
    if not bboxes:
        # Nothing was detected above the detection threshold.
        return torch.empty((0, 4)), image

    boxes = torch.stack([box.to_tensor() for box in bboxes])
    scores = torch.as_tensor(pred_dict['detection']['scores'])
    labels = torch.as_tensor(pred_dict['detection']['label_ids'])

    # Per-class non-maximum suppression; here `threshold` is the IoU value
    # above which overlapping boxes are suppressed, not the detection threshold.
    keep = torchvision.ops.batched_nms(boxes, scores, labels, threshold)
    boxes = boxes[keep, :]

    return boxes, image

def predict_class(model, image, bboxes):
    preds = []

    for bbox in bboxes:
        # Crop the detected region out of the full image.
        img = image.copy()
        bbox = np.array(bbox).astype(int)
        cropped_img = PIL.Image.fromarray(img).crop(tuple(bbox))
        cropped_img = np.array(cropped_img)

        # Resize/normalize the crop, then convert HWC -> CHW and add a batch
        # dimension before feeding it to the classifier.
        tran_image = transform_image(cropped_img, 224)
        tran_image = tran_image.transpose(2, 0, 1)
        tran_image = torch.as_tensor(tran_image, dtype=torch.float).unsqueeze(0)

        with torch.no_grad():
            y_preds = model(tran_image)
        preds.append(y_preds.softmax(1).numpy())

    # Keep the highest-scoring class index for each crop.
    preds = np.concatenate(preds).argmax(1)

    return preds
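
# Usage sketch: a minimal end-to-end run. 'checkpoint.ckpt' and 'test.jpg'
# are illustrative names, not files shipped with this module; the classifier
# model consumed by predict_class is loaded elsewhere.
if __name__ == '__main__':
    det_model = get_model('checkpoint.ckpt')
    det_model.eval()

    pred_dict = predict(det_model, 'test.jpg', detection_threshold=0.5)
    boxes, image = prepare_prediction(pred_dict, threshold=0.5)
    print(f'{len(boxes)} boxes kept after NMS')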