RSPrompter / mmpl /models /pler /mmcls_pler.py
KyanChen's picture
Upload 159 files
1c3eb47
raw
history blame
2.12 kB
import os
from typing import Any
import mmengine
import torch
import torch.nn as nn
from einops import rearrange
from mmdet.models.utils import samplelist_boxtype2tensor
from mmdet.structures import SampleList
from mmdet.utils import InstanceList
from mmpl.registry import MODELS
from ..builder import build_backbone, build_loss, build_neck, build_head
from .base_pler import BasePLer
from mmpl.structures import ClsDataSample
from .base import BaseClassifier
import lightning.pytorch as pl
import torch.nn.functional as F
@MODELS.register_module()
class MMClsPLer(BasePLer):
    """Training/evaluation wrapper around a single registry-built model.

    The wrapped model is built from the ``whole_model`` config via
    ``MODELS.build`` and is expected to expose ``data_preprocessor`` and
    ``_run_forward`` (both are called by the step methods below).
    """

    def __init__(self, whole_model=None, *args, **kwargs):
        """Build the wrapped model and record the constructor arguments.

        Args:
            whole_model: Config handed to ``MODELS.build`` to create the
                wrapped model.
            *args: Forwarded unchanged to ``BasePLer.__init__``.
            **kwargs: Forwarded unchanged to ``BasePLer.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.save_hyperparameters()
        self.whole_model = MODELS.build(whole_model)

    def setup(self, stage: str) -> None:
        """Delegate stage setup to the parent class without extra work."""
        super().setup(stage)

    def validation_step(self, batch, batch_idx):
        """Predict on one validation batch and update ``val_evaluator``.

        The per-sample ``pred_label`` and ``gt_label`` tensors are
        concatenated across the batch before being passed to the evaluator.
        """
        model = self.whole_model
        # Second positional argument is ``False`` here and ``True`` in
        # ``training_step``.
        inputs = model.data_preprocessor(batch, False)
        samples = model._run_forward(inputs, mode='predict')  # type: ignore

        predictions = torch.cat([sample.pred_label for sample in samples])
        targets = torch.cat([sample.gt_label for sample in samples])
        self.val_evaluator.update(predictions, targets)

    def training_step(self, batch, batch_idx):
        """Compute the losses for one training batch and log them.

        Returns:
            dict: The parsed log variables with every key prefixed by
            ``train_``, plus a ``loss`` entry holding the parsed total loss.
        """
        model = self.whole_model
        inputs = model.data_preprocessor(batch, True)
        raw_losses = model._run_forward(inputs, mode='loss')  # type: ignore

        total_loss, scalars = self.parse_losses(raw_losses)
        outputs = {f'train_{name}': value for name, value in scalars.items()}
        outputs['loss'] = total_loss

        self.log_dict(outputs, prog_bar=True)
        return outputs

    def test_step(self, batch, batch_idx, *args: Any, **kwargs: Any):
        """Predict on one test batch and update ``test_evaluator``."""
        model = self.whole_model
        inputs = model.data_preprocessor(batch, False)
        samples = model._run_forward(inputs, mode='predict')  # type: ignore

        # NOTE(review): validation_step feeds the evaluator concatenated
        # (pred_label, gt_label) tensors, whereas this passes the raw batch
        # and data samples. Confirm that test_evaluator really expects this
        # different call signature.
        self.test_evaluator.update(batch, samples)