# RSPrompter/mmpl/models/pler/mmseg_pler.py
# KyanChen's picture
# Upload 159 files
# 1c3eb47
# raw
# history blame
# 2.13 kB
import os
from typing import Any
import mmengine
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from mmpl.registry import MODELS
from ..builder import build_backbone, build_loss, build_neck, build_head
from .base_pler import BasePLer
from mmpl.structures import ClsDataSample
from .base import BaseClassifier
import lightning.pytorch as pl
import torch.nn.functional as F
@MODELS.register_module()
class MMSegPLer(BasePLer):
    """Wrapper that delegates train/val/predict steps to one registry-built model.

    The wrapped model is built from ``whole_model`` through ``MODELS.build``
    and is expected to expose ``data_preprocessor`` and ``_run_forward``
    (both are called by every step below).

    NOTE(review): ``parse_losses``, ``log_dict``, ``save_hyperparameters`` and
    ``val_evaluator`` are not defined in this class; they are presumably
    provided by ``BasePLer`` -- confirm against the base class.
    """

    def __init__(self,
                 whole_model=None,
                 train_cfg=None,
                 test_cfg=None,
                 *args, **kwargs):
        """Build the wrapped model.

        Args:
            whole_model: Config passed to ``MODELS.build`` to create the
                wrapped model.
            train_cfg: Not used directly by this class.
            test_cfg: Not used directly by this class.
            *args: Forwarded to the base class constructor.
            **kwargs: Forwarded to the base class constructor.
        """
        super().__init__(*args, **kwargs)
        self.save_hyperparameters()
        self.whole_model = MODELS.build(whole_model)

    def setup(self, stage: str) -> None:
        """Intentionally do nothing for any ``stage``."""

    def init_weights(self):
        """Intentionally do nothing.

        A leftover ``ipdb.set_trace()`` breakpoint used to live here; it would
        block non-interactive runs (or fail when ``ipdb`` is missing), so it
        has been removed. No weight initialisation is performed by this hook.
        """

    def training_step(self, batch, batch_idx):
        """Run one training step and log the resulting losses.

        Args:
            batch: Raw batch handed to the wrapped model's data preprocessor.
            batch_idx: Index of the batch (unused).

        Returns:
            dict: The parsed log variables with every key prefixed by
            ``train_``, plus the total loss under the key ``'loss'``.
        """
        # Second positional argument ``True`` selects the training path of
        # the preprocessor (validation/predict pass ``False``).
        data = self.whole_model.data_preprocessor(batch, True)
        losses = self.whole_model._run_forward(data, mode='loss')  # type: ignore
        parsed_losses, log_vars = self.parse_losses(losses)
        log_vars = {f'train_{k}': v for k, v in log_vars.items()}
        log_vars['loss'] = parsed_losses
        self.log_dict(log_vars, prog_bar=True)
        return log_vars

    def validation_step(self, batch, batch_idx):
        """Predict on one validation batch and feed the evaluator.

        Args:
            batch: Raw batch handed to the wrapped model's data preprocessor.
            batch_idx: Index of the batch (unused).
        """
        data = self.whole_model.data_preprocessor(batch, False)
        data_samples = self.whole_model._run_forward(data, mode='predict')
        # Gather per-sample prediction / ground-truth tensors and join them
        # along dim 0 before updating the evaluator.
        pred = torch.cat(
            [data_sample.pred_sem_seg.data for data_sample in data_samples],
            dim=0)
        label = torch.cat(
            [data_sample.gt_sem_seg.data for data_sample in data_samples],
            dim=0)
        self.val_evaluator.update(pred, label)

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        """Return the wrapped model's predictions for one batch.

        Args:
            batch: Raw batch handed to the wrapped model's data preprocessor.
            batch_idx: Index of the batch (unused).
            dataloader_idx: Index of the dataloader (unused).

        Returns:
            The data samples produced by ``_run_forward`` in ``'predict'``
            mode.
        """
        data = self.whole_model.data_preprocessor(batch, False)
        return self.whole_model._run_forward(data, mode='predict')