Spaces:
Sleeping
Sleeping
# -*- coding: utf-8 -*-
# @Time : 2024/8/1 2:47 PM
# @Author : xiaoshun
# @Email : [email protected]
# @File : cdnetv2_module.py
# @Software: PyCharm
from typing import Dict, Tuple

import torch

import src.models.base_module
class CDNetv2LitModule(src.models.base_module.BaseLitModule):
    """CDNetv2 module built on ``BaseLitModule``.

    Only ``model_step`` is specialised here: ``self.forward`` is expected to
    return two outputs (named ``logits`` and ``logits_aux`` below), and both
    are handed to the configured criterion.
    """

    def __init__(self, **kwargs):
        """Forward all keyword arguments unchanged to ``BaseLitModule``."""
        super().__init__(**kwargs)

    def model_step(
        self, batch: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Perform a single model step on a batch of data.

        :param batch: A mapping with the input images under the key ``"img"``
            and the target labels under the key ``"ann"``.
        :return: A tuple containing (in order):
            - The loss returned by ``self.hparams.criterion``.
            - The predictions, i.e. the argmax of the first forward output
              along dimension 1.
            - The target labels, passed through unchanged.
        """
        x, y = batch["img"], batch["ann"]
        # forward returns a pair; the second element is only consumed by the
        # criterion (presumably an auxiliary head - confirm against the model).
        logits, logits_aux = self.forward(x)
        loss = self.hparams.criterion(logits, logits_aux, y)
        # Predictions come from the first output only.
        preds = torch.argmax(logits, dim=1)
        return loss, preds, y