import os
from typing import Any

import mmengine
import torch
import torch.nn as nn
from einops import rearrange
from mmdet.models.utils import samplelist_boxtype2tensor
from mmdet.structures import SampleList
from mmdet.utils import InstanceList
from mmpl.registry import MODELS
from ..builder import build_backbone, build_loss, build_neck, build_head
from .base_pler import BasePLer
from mmpl.structures import ClsDataSample
from .base import BaseClassifier
import lightning.pytorch as pl
import torch.nn.functional as F


@MODELS.register_module()
class MMDetPLer(BasePLer):
    """Lightning wrapper that drives a complete registry-built model.

    The wrapped model is built from the ``whole_model`` config through the
    ``MODELS`` registry. Every step delegates to that model's
    ``data_preprocessor`` and ``_run_forward`` methods.
    """

    def __init__(self, whole_model=None, *args, **kwargs):
        """Build the wrapped model.

        Args:
            whole_model: Config passed to ``MODELS.build`` to construct the
                wrapped model. NOTE(review): the default ``None`` is handed
                to ``MODELS.build`` unchanged; presumably a real config is
                always supplied -- confirm with callers.
            *args: Forwarded to ``BasePLer.__init__``.
            **kwargs: Forwarded to ``BasePLer.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.save_hyperparameters()
        self.whole_model = MODELS.build(whole_model)

    def setup(self, stage: str) -> None:
        """Delegate stage setup to the base class unchanged."""
        super().setup(stage)

    def _run_predict(self, batch):
        """Preprocess ``batch`` and run the wrapped model in predict mode.

        Shared by the validation, test and predict steps so that all three
        follow exactly the same inference path.

        Args:
            batch: Raw batch as delivered by the dataloader.

        Returns:
            Whatever ``whole_model._run_forward(..., mode='predict')``
            returns for the preprocessed batch.
        """
        # Second argument is False here and True in ``training_step``;
        # presumably the preprocessor's training flag -- confirm.
        data = self.whole_model.data_preprocessor(batch, False)
        return self.whole_model._run_forward(data, mode='predict')  # type: ignore

    def validation_step(self, batch, batch_idx):
        """Run inference on one validation batch and feed the evaluator.

        Args:
            batch: Raw validation batch from the dataloader.
            batch_idx: Index of the batch (unused).
        """
        batch_data_samples = self._run_predict(batch)
        # The evaluator receives the raw batch together with the predictions.
        self.val_evaluator.update(batch, batch_data_samples)

    def training_step(self, batch, batch_idx):
        """Compute and log the losses for one training batch.

        Args:
            batch: Raw training batch from the dataloader.
            batch_idx: Index of the batch (unused).

        Returns:
            dict: The parsed log variables with every key prefixed by
            ``train_``, plus a ``loss`` entry holding the parsed total loss.
        """
        data = self.whole_model.data_preprocessor(batch, True)
        losses = self.whole_model._run_forward(data, mode='loss')  # type: ignore
        parsed_losses, log_vars = self.parse_losses(losses)
        log_vars = {f'train_{k}': v for k, v in log_vars.items()}
        log_vars['loss'] = parsed_losses
        self.log_dict(log_vars, prog_bar=True)
        return log_vars

    def test_step(self, batch, batch_idx, *args: Any, **kwargs: Any):
        """Run inference on one test batch and feed the evaluator.

        Args:
            batch: Raw test batch from the dataloader.
            batch_idx: Index of the batch (unused).
            *args: Accepted for interface compatibility (unused).
            **kwargs: Accepted for interface compatibility (unused).
        """
        batch_data_samples = self._run_predict(batch)
        # No per-sample pred/gt dicts are built here: the evaluator only
        # takes the raw batch and the data samples, and indexing
        # ``gt_instances`` would fail on samples that carry no ground truth.
        self.test_evaluator.update(batch, batch_data_samples)

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        """Run inference on one batch and return the predictions.

        Args:
            batch: Raw batch from the dataloader.
            batch_idx: Index of the batch (unused).
            dataloader_idx: Index of the originating dataloader (unused).

        Returns:
            The data samples produced by the wrapped model in predict mode.
        """
        return self._run_predict(batch)