# cloudseg / src/models/cdnetv2_module.py
# Author: caixiaoshun
# Commit fa7be76 (verified): "使用huggingface hub尝试更新" (attempted update via the Hugging Face Hub)
# -*- coding: utf-8 -*-
# @Time : 2024/8/1 下午2:47
# @Author : xiaoshun
# @Email : [email protected]
# @File : cdnetv2_module.py
# @Software: PyCharm
from typing import Dict, Tuple

import torch

import src.models.base_module
class CDNetv2LitModule(src.models.base_module.BaseLitModule):
    """Lightning module for CDNetv2, whose network returns main and auxiliary logits.

    All construction is delegated to ``BaseLitModule``. This subclass provides
    :meth:`model_step`, which passes both network outputs to the configured
    criterion.
    """

    def __init__(self, **kwargs):
        """Forward every keyword argument unchanged to ``BaseLitModule``."""
        super().__init__(**kwargs)

    def model_step(
        self, batch: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Perform a single model step on a batch of data.

        :param batch: A mapping that holds the input images under the ``"img"``
            key and the target labels under the ``"ann"`` key.
        :return: A tuple containing (in order):
            - The loss computed by ``self.hparams.criterion`` from the main
              logits, the auxiliary logits and the targets.
            - The predictions: ``argmax`` of the main logits over dim 1.
            - The target labels, returned unchanged.
        """
        x, y = batch["img"], batch["ann"]
        # The network is expected to return exactly two outputs: the main
        # logits and the auxiliary logits.
        logits, logits_aux = self.forward(x)
        # The criterion receives both heads. How they are combined (e.g. the
        # weight of the auxiliary term) is defined by the criterion itself.
        loss = self.hparams.criterion(logits, logits_aux, y)
        # Only the main logits drive predictions. Dim 1 is presumably the class
        # axis, i.e. logits shaped (N, C, H, W) — TODO confirm.
        preds = torch.argmax(logits, dim=1)
        return loss, preds, y