Spaces:
Runtime error
Runtime error
import os | |
from typing import Any | |
import mmengine | |
import torch | |
import torch.nn as nn | |
from einops import rearrange | |
from mmdet.models.utils import samplelist_boxtype2tensor | |
from mmdet.structures import SampleList | |
from mmdet.utils import InstanceList | |
from mmpl.registry import MODELS | |
from ..builder import build_backbone, build_loss, build_neck, build_head | |
from .base_pler import BasePLer | |
from mmpl.structures import ClsDataSample | |
from .base import BaseClassifier | |
import lightning.pytorch as pl | |
import torch.nn.functional as F | |
class MMClsPLer(BasePLer):
    """Lightning wrapper around a classifier built from config via ``MODELS``.

    The wrapped model (``self.whole_model``) supplies its own
    ``data_preprocessor`` and ``_run_forward``; this class only wires those
    into the training / validation / test steps and feeds the results to the
    evaluators and the logger.

    Args:
        whole_model: Config passed to ``MODELS.build`` to construct the
            wrapped classifier.
        *args, **kwargs: Forwarded unchanged to ``BasePLer.__init__``.
    """

    def __init__(self,
                 whole_model=None,
                 *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Record the constructor arguments as hyperparameters (assumes
        # BasePLer derives from lightning's LightningModule — not shown here).
        self.save_hyperparameters()
        self.whole_model = MODELS.build(whole_model)

    def setup(self, stage: str) -> None:
        """Delegate to ``BasePLer.setup`` (no extra work for this class)."""
        super().setup(stage)

    def _predict_labels(self, batch):
        """Run the wrapped model in predict mode on ``batch``.

        Returns:
            tuple: ``(pred_label, gt_label)``, each the concatenation over the
            batch of the per-sample ``pred_label`` / ``gt_label`` tensors.
        """
        # The second argument is True in training_step and False here;
        # presumably the preprocessor's ``training`` flag — TODO confirm.
        data = self.whole_model.data_preprocessor(batch, False)
        batch_data_samples = self.whole_model._run_forward(data, mode='predict')  # type: ignore
        # NOTE(review): assumes each per-sample label is a tensor that can be
        # concatenated along dim 0 (e.g. shape (1,)) — confirm for the data
        # sample type in use.
        pred_label = torch.cat([sample.pred_label for sample in batch_data_samples])
        gt_label = torch.cat([sample.gt_label for sample in batch_data_samples])
        return pred_label, gt_label

    def validation_step(self, batch, batch_idx):
        """Predict on ``batch`` and update ``self.val_evaluator``."""
        pred_label, gt_label = self._predict_labels(batch)
        self.val_evaluator.update(pred_label, gt_label)

    def training_step(self, batch, batch_idx):
        """Compute, log and return the training losses for ``batch``.

        Returns:
            dict: ``parse_losses``' log values prefixed with ``train_``, plus
            the total loss under ``'loss'`` (the key Lightning optimizes).
        """
        data = self.whole_model.data_preprocessor(batch, True)
        losses = self.whole_model._run_forward(data, mode='loss')  # type: ignore
        # parse_losses is inherited (not shown); it returns the total loss and
        # a dict of per-term values to log.
        parsed_losses, log_vars = self.parse_losses(losses)
        log_vars = {f'train_{k}': v for k, v in log_vars.items()}
        log_vars['loss'] = parsed_losses
        self.log_dict(log_vars, prog_bar=True)
        return log_vars

    def test_step(self, batch, batch_idx, *args: Any, **kwargs: Any):
        """Predict on ``batch`` and update ``self.test_evaluator``.

        Feeds the evaluator the same ``(pred_label, gt_label)`` tensors as
        ``validation_step`` does, instead of the raw batch and data samples.
        """
        pred_label, gt_label = self._predict_labels(batch)
        self.test_evaluator.update(pred_label, gt_label)