File size: 1,045 Bytes
1c3eb47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from typing import Any

import torch
import torch.nn as nn
from mmpl.registry import MODELS
from ..builder import build_backbone, build_loss
from .base_pler import BasePLer
from mmpl.structures import ClsDataSample
from .base import BaseClassifier
import lightning.pytorch as pl
import torch.nn.functional as F


@MODELS.register_module()
class GPTPLer(BasePLer):
    def __init__(self,
                 backbone,
                 loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
                 *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.backbone = build_backbone(backbone)
        self.loss = build_loss(loss)

    def training_step(self, batch, batch_idx):
        x, gt_label = batch['x'], batch['gt_label']
        outputs = self(input_ids=x, labels=gt_label)
        loss, logits = outputs['loss'], outputs['logits']
        return loss

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        return self.backbone(*args, **kwargs)

    def validation_step(self, batch, batch_idx):
        pass