Spaces:
Runtime error
Runtime error
File size: 1,045 Bytes
1c3eb47 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
from typing import Any
import torch
import torch.nn as nn
from mmpl.registry import MODELS
from ..builder import build_backbone, build_loss
from .base_pler import BasePLer
from mmpl.structures import ClsDataSample
from .base import BaseClassifier
import lightning.pytorch as pl
import torch.nn.functional as F
@MODELS.register_module()
class GPTPLer(BasePLer):
    """Lightning module wrapping a config-built GPT-style backbone.

    Calling the module forwards all arguments to ``self.backbone``. The
    training step passes ``labels`` to the backbone and returns the
    ``'loss'`` entry of its output, so the loss is computed by the backbone
    itself rather than by ``self.loss``.

    Args:
        backbone: Config for the backbone, passed to ``build_backbone``.
        loss: Config passed to ``build_loss``; the result is stored as
            ``self.loss``. When ``None`` (the default), a fresh
            ``dict(type='CrossEntropyLoss', loss_weight=1.0)`` is used.
        *args: Forwarded to ``BasePLer.__init__``.
        **kwargs: Forwarded to ``BasePLer.__init__``.
    """

    def __init__(self,
                 backbone,
                 loss=None,
                 *args, **kwargs):
        super().__init__(*args, **kwargs)
        if loss is None:
            # Build the default config per instance: a dict literal as a
            # default argument would be one object shared by every call.
            loss = dict(type='CrossEntropyLoss', loss_weight=1.0)
        self.backbone = build_backbone(backbone)
        # NOTE(review): self.loss is not read by any method in this class;
        # training_step relies on the loss returned by the backbone.
        # Confirm whether BasePLer or other code uses it.
        self.loss = build_loss(loss)

    def training_step(self, batch, batch_idx):
        """Run one training step and return the backbone-computed loss.

        Args:
            batch: Mapping whose ``'x'`` entry is passed to the backbone as
                ``input_ids`` and whose ``'gt_label'`` entry is passed as
                ``labels``.
            batch_idx: Index of the current batch (unused).

        Returns:
            The ``'loss'`` entry of the backbone output.
        """
        x, gt_label = batch['x'], batch['gt_label']
        outputs = self(input_ids=x, labels=gt_label)
        # NOTE(review): outputs is presumably a HuggingFace-style model
        # output supporting ['loss']; confirm against the backbone.
        return outputs['loss']

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Delegate directly to the backbone with unchanged arguments."""
        return self.backbone(*args, **kwargs)

    def validation_step(self, batch, batch_idx):
        """Placeholder validation step: does nothing and returns ``None``."""
        # TODO: no validation loss or metrics are computed yet.
        pass
|