from typing import Any

import torch

from mmpl.registry import MODELS
from .base_pler import BasePLer


@MODELS.register_module()
class MMClsPLer(BasePLer):
    """Lightning wrapper around an mmcls-style classifier that is built from
    a config dict via the MMPL model registry."""

    def __init__(self,
                 whole_model=None,
                 *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.save_hyperparameters()
        # Build the underlying classifier (backbone/neck/head) from its config dict.
        self.whole_model = MODELS.build(whole_model)

    def setup(self, stage: str) -> None:
        super().setup(stage)

    def validation_step(self, batch, batch_idx):
        # Preprocess the raw batch (training=False) and run the wrapped
        # classifier in 'predict' mode to get per-sample data samples.
        data = self.whole_model.data_preprocessor(batch, False)
        batch_data_samples = self.whole_model._run_forward(data, mode='predict')  # type: ignore
        # Gather predicted and ground-truth labels and update the validation metric.
        pred_label = torch.cat([data_sample.pred_label for data_sample in batch_data_samples])
        gt_label = torch.cat([data_sample.gt_label for data_sample in batch_data_samples])
        self.val_evaluator.update(pred_label, gt_label)

    def training_step(self, batch, batch_idx):
        # Preprocess the raw batch (training=True) and compute the loss dict.
        data = self.whole_model.data_preprocessor(batch, True)
        losses = self.whole_model._run_forward(data, mode='loss')  # type: ignore
        # Reduce the loss dict to a single scalar, prefix each component and log it.
        parsed_losses, log_vars = self.parse_losses(losses)
        log_vars = {f'train_{k}': v for k, v in log_vars.items()}
        log_vars['loss'] = parsed_losses
        self.log_dict(log_vars, prog_bar=True)
        return log_vars

    def test_step(self, batch, batch_idx, *args: Any, **kwargs: Any):
        # Same preprocessing/predict path as validation, but the raw batch and
        # the predicted data samples are handed to the test evaluator directly.
        data = self.whole_model.data_preprocessor(batch, False)
        batch_data_samples = self.whole_model._run_forward(data, mode='predict')  # type: ignore
        self.test_evaluator.update(batch, batch_data_samples)
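

# --- Usage sketch (illustrative only) --------------------------------------
# A minimal, hedged example of how this wrapper might be instantiated from a
# config dict. The concrete classifier settings below (ImageClassifier /
# ResNet / LinearClsHead) are assumptions for illustration and are not part
# of this file; swap in whatever mmcls/mmpretrain classifier your project uses.
#
# from mmpl.registry import MODELS
#
# model_cfg = dict(
#     type='MMClsPLer',
#     whole_model=dict(
#         type='ImageClassifier',
#         backbone=dict(type='ResNet', depth=18),
#         neck=dict(type='GlobalAveragePooling'),
#         head=dict(type='LinearClsHead', num_classes=10, in_channels=512),
#     ),
# )
# pler = MODELS.build(model_cfg)  # returns an MMClsPLer ready for a Lightning Trainer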