File size: 4,164 Bytes
a56642d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from argparse import ArgumentParser

from mmdet.apis import inference_detector, init_detector, show_result_pyplot
from mmdet.core.mask.utils import encode_mask_results
import numpy as np
import mmcv
import torch 
from imantics import Polygons, Mask
import json
import os
import cv2, glob

class detections():
    """Wrap an MMDetection model: run inference, draw results, and collect
    per-image detection records that can be exported as JSON.

    ``run_on_image`` stores the raw model output in ``self.result``;
    ``process_output`` converts that latest result into a record appended to
    ``self.record_dict['results']``; ``write_json`` dumps ``self.record_dict``.
    """

    def __init__(self, cfg_path, device, model_path = 'data/models/walt_vehicle.pth', threshold=0.85):
        """Build the detector.

        Args:
            cfg_path: path to the MMDetection config; also stored as
                ``record_dict['model']``.
            device: device string passed to ``init_detector``
                (``main`` uses ``'cpu'`` or ``'cuda:0'``).
            model_path: checkpoint path passed to ``init_detector``.
            threshold: score threshold used when drawing in ``run_on_image``.
        """
        self.model = init_detector(cfg_path, model_path, device=device)
        self.all_preds = []
        self.all_scores = []
        self.index = []
        self.score_thr = threshold
        self.result = []  # raw output of the most recent inference_detector call
        self.record_dict = {'model': cfg_path,'results': []}
        self.detect_count = []  # class ids of all detections from the last process_output call


    def run_on_image(self, image):
        """Run inference on ``image`` and return a copy with results drawn.

        The raw result is kept in ``self.result`` for ``process_output``.
        Only detections scoring at least ``self.score_thr`` are drawn.
        """
        self.result = inference_detector(self.model, image)
        image_labelled = self.model.show_result(image, self.result, score_thr=self.score_thr)
        return image_labelled

    def process_output(self, count, score_thr=0.3):
        """Convert ``self.result`` into a JSON-serialisable record and append it
        to ``self.record_dict['results']``.

        Args:
            count: value stored under the record's ``'url'`` key.
            score_thr: detections scoring below this are not recorded
                (default 0.3, the value previously hard-coded).

        Side effects:
            Appends one dict to ``self.record_dict['results']`` and sets
            ``self.detect_count`` to the class ids of *all* detections
            (before score filtering).
        """
        result = self.result
        infer_result = {'url': count,
                        'boxes': [],
                        'scores': [],
                        'keypoints': [],
                        'segmentation': [],
                        'label_ids': [],
                        'track': [],
                        'labels': []}

        if isinstance(result, tuple):
            bbox_result, segm_result = result
            if isinstance(segm_result, tuple):
                segm_result = segm_result[0]  # ms rcnn
        else:
            # Bbox-only model: the result is just the per-class box arrays.
            bbox_result, segm_result = result, None

        # One (N_i, 5) array per class -> stacked boxes plus matching class ids.
        bboxes = np.vstack(bbox_result)
        labels = np.concatenate(
            [np.full(bbox.shape[0], i, dtype=np.int32) for i, bbox in enumerate(bbox_result)])

        segms = None
        if segm_result is not None and len(labels) > 0:  # non empty
            segms = mmcv.concat_list(segm_result)
            if isinstance(segms[0], torch.Tensor):
                segms = torch.stack(segms, dim=0).detach().cpu().numpy()
            else:
                segms = np.stack(segms, axis=0)
        if segms is None:
            # No masks available: keep zip() aligned instead of zipping over None.
            segms = [None] * len(labels)

        for bbox, label, segm in zip(bboxes, labels, segms):
            score = bbox[-1].item()
            if score < score_thr:
                continue
            infer_result['boxes'].append(bbox[:4].tolist())
            infer_result['segmentation'].append(
                [] if segm is None else Mask(segm).polygons().segmentation)
            infer_result['scores'].append(score)
            infer_result['labels'].append(self.model.CLASSES[label])
            # Plain int: np.int32 is not JSON serialisable.
            infer_result['label_ids'].append(int(label))
        self.record_dict['results'].append(infer_result)
        self.detect_count = labels

    def write_json(self, filename):
        """Write ``self.record_dict`` to ``filename + '.json'``."""
        with open(filename + '.json', 'w') as f:
            json.dump(self.record_dict, f)

 
def main():
    """Run the WALT people and vehicle detectors over ``demo/images/*``.

    For every readable image, both models' drawn detections are written to the
    mirrored path under ``demo/results/``. The vehicle detections are also
    appended to ``detect.record_dict`` via ``process_output``. Processing stops
    after 30000 successfully processed images. Per-image failures are reported
    and that image is skipped.
    """
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    detect_people = detections('configs/walt/walt_people.py', device, model_path='data/models/walt_people.pth')
    detect = detections('configs/walt/walt_vehicle.py', device, model_path='data/models/walt_vehicle.pth')
    max_images = 30000
    count = 0
    for filename in sorted(glob.glob('demo/images/*')):
        img = cv2.imread(filename)
        if img is None:
            # cv2.imread returns None for unreadable or non-image files.
            print('skipping unreadable file: {}'.format(filename))
            continue
        try:
            vis = detect_people.run_on_image(img)
            # Run the vehicle model on the clean frame, not on the people
            # overlay, then draw its detections on top of the people drawing.
            detect.result = inference_detector(detect.model, img)
            vis = detect.model.show_result(vis, detect.result, score_thr=detect.score_thr)
        except Exception as err:
            print('inference failed on {}: {}'.format(filename, err))
            continue
        count += 1

        # demo/images/x.jpg -> demo/results/images/x.jpg
        out_path = os.path.join('demo/results', os.path.relpath(filename, 'demo'))
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        cv2.imwrite(out_path, vis)

        try:
            detect.process_output(count)
        except Exception as err:
            print('could not record detections for {}: {}'.format(filename, err))
        if count == max_images:
            break
    # NOTE: records are collected in detect.record_dict but not persisted;
    # call detect.write_json('<name>') here to save them as <name>.json.


if __name__ == "__main__":
    main()