Spaces:
Runtime error
Runtime error
File size: 3,519 Bytes
1c3eb47 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
import torch
from mmengine.structures import InstanceData
from typing import List, Any
from mmpl.registry import MODELS
from mmseg.utils import SampleList
from .base_pler import BasePLer
import torch.nn.functional as F
from modules.sam import sam_model_registry
@MODELS.register_module()
class SegSAMAnchorPLer(BasePLer):
    """Lightning module pairing a SAM backbone with a panoptic head.

    The backbone is looked up in ``sam_model_registry``; the optional neck
    and the panoptic head are built through the ``MODELS`` registry.
    Features come from ``backbone.image_encoder`` under ``torch.no_grad()``,
    and the backbone itself is handed to the head on every ``loss`` /
    ``predict`` call.
    """

    def __init__(self,
                 backbone,
                 neck=None,
                 panoptic_head=None,
                 need_train_names=None,
                 train_cfg=None,
                 test_cfg=None,
                 *args, **kwargs):
        """Build the backbone, optional neck and panoptic head.

        Args:
            backbone: Mapping with a ``'type'`` key naming an entry of
                ``sam_model_registry``; the remaining items are passed to
                that entry as keyword arguments.
            neck: Optional config for ``MODELS.build``. When ``None``, no
                ``neck`` attribute is set.
            panoptic_head: Config for ``MODELS.build``.
            need_train_names: Forwarded to ``_set_grad`` in ``setup`` and to
                ``_set_train_module`` in ``train``; ``None`` disables both.
            train_cfg: Stored as ``self.train_cfg``.
            test_cfg: Stored as ``self.test_cfg``.
            *args: Forwarded to the base class.
            **kwargs: Forwarded to the base class.
        """
        super().__init__(*args, **kwargs)
        self.save_hyperparameters()
        self.need_train_names = need_train_names

        # Work on a shallow copy so the caller's config (and anything else
        # holding a reference to it) keeps its 'type' key.
        backbone_cfg = dict(backbone)
        backbone_type = backbone_cfg.pop('type')
        self.backbone = sam_model_registry[backbone_type](**backbone_cfg)

        if neck is not None:
            self.neck = MODELS.build(neck)

        self.panoptic_head = MODELS.build(panoptic_head)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

    def setup(self, stage: str) -> None:
        """Run the base setup, then apply ``_set_grad`` if names were given."""
        super().setup(stage)
        if self.need_train_names is not None:
            self._set_grad(self.need_train_names, noneed_train_names=[])

    def init_weights(self):
        """Do nothing; no weight initialisation is performed here."""
        # A leftover interactive debugger breakpoint was removed from here:
        # it would block any non-interactive run.

    def train(self, mode=True):
        """Switch train/eval mode.

        When ``need_train_names`` is set, the switch is delegated to
        ``_set_train_module`` and its result is returned; otherwise the
        default behaviour applies and ``self`` is returned.
        """
        if self.need_train_names is not None:
            return self._set_train_module(mode, self.need_train_names)
        super().train(mode)
        return self

    @torch.no_grad()
    def extract_feat(self, batch_inputs):
        """Run the backbone image encoder without gradient tracking.

        Returns:
            The two values produced by ``backbone.image_encoder``, as a
            ``(feat, inter_features)`` tuple.
        """
        feat, inter_features = self.backbone.image_encoder(batch_inputs)
        return feat, inter_features

    def _run_panoptic_predict(self, batch):
        """Preprocess ``batch`` in non-training mode and run head prediction.

        Shared by ``validation_step`` and ``predict_step``.
        """
        data = self.data_preprocessor(batch, False)
        batch_inputs = data['inputs']
        batch_data_samples = data['data_samples']
        x = self.extract_feat(batch_inputs)
        return self.panoptic_head.predict(
            x, batch_data_samples, self.backbone)

    def validation_step(self, batch, batch_idx):
        """Predict on ``batch`` and feed the results to ``val_evaluator``."""
        results = self._run_panoptic_predict(batch)
        # The evaluator receives the raw batch, not the preprocessed one.
        self.val_evaluator.update(batch, results)

    def training_step(self, batch, batch_idx):
        """Compute head losses for ``batch`` and log them.

        Returns:
            dict: The values from ``parse_losses`` with keys prefixed by
            ``train_``, plus the parsed total under ``'loss'``.
        """
        data = self.data_preprocessor(batch, True)
        batch_inputs = data['inputs']
        batch_data_samples = data['data_samples']
        x = self.extract_feat(batch_inputs)
        losses = self.panoptic_head.loss(x, batch_data_samples, self.backbone)
        parsed_losses, log_vars = self.parse_losses(losses)
        log_vars = {f'train_{k}': v for k, v in log_vars.items()}
        log_vars['loss'] = parsed_losses
        self.log_dict(log_vars, prog_bar=True)
        return log_vars

    def on_before_optimizer_step(self, optimizer) -> None:
        """Call ``log_grad`` on the panoptic head before each optimizer step."""
        self.log_grad(module=self.panoptic_head)

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        """Return the panoptic head's predictions for ``batch``."""
        return self._run_panoptic_predict(batch)
|