from typing import Any

import torch

from mmpl.registry import MODELS
from .base_pler import BasePLer


@MODELS.register_module()
class MMDetPLer(BasePLer):
    """Lightning module that wraps a complete MMDetection detector.

    The detector is built from the ``whole_model`` config through the MODELS
    registry, and its own data preprocessor and forward modes ('loss' /
    'predict') are reused inside the Lightning train/val/test/predict hooks.
    """

    def __init__(self,
                 whole_model=None,
                 *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.save_hyperparameters()
        # Build the full detector (data preprocessor, backbone, neck, heads).
        self.whole_model = MODELS.build(whole_model)

    def setup(self, stage: str) -> None:
        super().setup(stage)

    def validation_step(self, batch, batch_idx):
        # Run the detector's data preprocessor in eval mode (training=False).
        data = self.whole_model.data_preprocessor(batch, False)
        # Forward in 'predict' mode to get post-processed data samples, one per image.
        batch_data_samples = self.whole_model._run_forward(data, mode='predict')  # type: ignore
        # The evaluator consumes the raw batch together with the predicted samples;
        # the manual torchmetrics-style parsing kept in test_step is not needed here.
        self.val_evaluator.update(batch, batch_data_samples)

    def training_step(self, batch, batch_idx):
        # Preprocess with training=True so train-time batch augmentations apply.
        data = self.whole_model.data_preprocessor(batch, True)
        # Forward in 'loss' mode: returns a dict of per-branch loss tensors.
        losses = self.whole_model._run_forward(data, mode='loss')  # type: ignore
        # Collapse the loss dict into a single scalar plus loggable values.
        parsed_losses, log_vars = self.parse_losses(losses)
        log_vars = {f'train_{k}': v for k, v in log_vars.items()}
        log_vars['loss'] = parsed_losses
        self.log_dict(log_vars, prog_bar=True)
        # Lightning uses the 'loss' entry of the returned dict for backprop.
        return log_vars

    def test_step(self, batch, batch_idx, *args: Any, **kwargs: Any):
        data = self.whole_model.data_preprocessor(batch, False)
        batch_data_samples = self.whole_model._run_forward(data, mode='predict')  # type: ignore
        # Optionally convert predictions and ground truth into the dict format
        # expected by torchmetrics-style detection metrics (boxes/scores/labels[/masks]).
        preds = []
        targets = []
        for data_sample in batch_data_samples:
            # Predicted instances.
            result = dict()
            pred = data_sample.pred_instances
            result['boxes'] = pred['bboxes']
            result['scores'] = pred['scores']
            result['labels'] = pred['labels']
            if 'masks' in pred:
                result['masks'] = pred['masks']
            preds.append(result)
            # Ground-truth instances.
            gt = dict()
            gt_data = data_sample.get('gt_instances', None)
            gt['boxes'] = gt_data['bboxes']
            gt['labels'] = gt_data['labels']
            if 'masks' in pred:
                gt['masks'] = gt_data['masks'].to_tensor(dtype=torch.bool, device=result['masks'].device)
            targets.append(gt)

        # Either feed the parsed dicts to a torchmetrics-style evaluator ...
        # self.test_evaluator.update(preds, targets)
        # ... or, as in validation, pass the raw batch and predicted samples.
        self.test_evaluator.update(batch, batch_data_samples)

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        data = self.whole_model.data_preprocessor(batch, False)
        batch_data_samples = self.whole_model._run_forward(data, mode='predict')  # type: ignore
        # Return the post-processed data samples so callers can dump or visualize them.
        return batch_data_samples
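
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): one way this wrapper is typically wired
# into a Lightning run. The detector config, datamodule and trainer arguments
# below are hypothetical placeholders rather than configs shipped with this
# repo; the exact fields depend on the chosen MMDetection detector and on how
# BasePLer expects its evaluator/optimizer settings to be passed.
#
#   import lightning.pytorch as pl
#   from mmpl.registry import MODELS
#
#   detector_cfg = dict(
#       type='RetinaNet',                        # any registered mmdet detector
#       data_preprocessor=dict(type='DetDataPreprocessor'),
#       backbone=dict(...),                      # filled in from a real mmdet config
#       neck=dict(...),
#       bbox_head=dict(...),
#   )
#   model = MODELS.build(dict(type='MMDetPLer', whole_model=detector_cfg))
#
#   trainer = pl.Trainer(max_epochs=12, devices=1)
#   trainer.fit(model, datamodule=my_det_datamodule)  # hypothetical datamodule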