File size: 2,098 Bytes
a578142
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import torch
import torch.nn as nn
from tqdm import tqdm
from typing import *
from src import logger

def predict_mask(
        data: Any,
        device: Any,
        model: nn.Module,
        inference: bool,
        valid_loader=None,
        criterion=None,
        threshold: Optional[float] = None,
) -> Union[torch.Tensor, Tuple[float, float]]:
    """
    Predict a binary mask for ``data``, or evaluate ``model`` on a validation set.

    When ``inference`` is truthy, ``data`` is cast to float32, moved to
    ``device`` and passed through ``model``; the sigmoid of the output is
    thresholded into a {0., 1.} float mask that is returned on the CPU.

    Otherwise ``data`` is ignored and every ``(x, y)`` batch of
    ``valid_loader`` is scored: the loss ``criterion(model(x), y)`` and a Dice
    score (computed over the whole batch) are averaged over the batches,
    logged, and returned.

    In both modes gradients are disabled and the model is put in eval mode
    (so dropout / batch-norm behave deterministically); the model's previous
    train/eval mode is restored before returning.

    Args:
        data (Any): image tensor to predict on (used only when ``inference``).
        device (int | str | torch.device): target device, e.g. 0, 'cuda' or 'cpu'.
        model (nn.Module): segmentation model; its raw output is treated as
            logits (a sigmoid is applied before thresholding).
        inference (bool): True to predict a mask, False to evaluate.
        valid_loader (optional): sized iterable of ``(image, target)`` tensor
            batches, e.g. a ``DataLoader``; required when ``inference`` is False.
        criterion (optional): loss callable ``criterion(output, target)``
            returning a scalar tensor; required when ``inference`` is False.
        threshold (float, optional): sigmoid probability above which a pixel
            is foreground. Defaults to 0.6 in inference mode and 0.5 in
            evaluation mode (the values this function has always used).

    Returns:
        torch.Tensor: the float mask on the CPU, when ``inference`` is True.
        tuple[float, float]: ``(mean loss, mean Dice score)`` otherwise.

    Raises:
        ValueError: in evaluation mode, if ``valid_loader`` or ``criterion``
            is missing, or ``valid_loader`` is empty.

    Example::

        mask = predict_mask(data=image, device='cuda', model=model, inference=True)
        loss, dice = predict_mask(data=None, device='cuda', model=model,
                                  inference=False, valid_loader=test_loader,
                                  criterion=fn_loss)
    """
    if not inference:
        # Fail early with a clear message instead of a TypeError deep in the
        # loop or a ZeroDivisionError when averaging.
        if valid_loader is None or criterion is None:
            raise ValueError(
                "valid_loader and criterion are required when inference is False"
            )
        if len(valid_loader) == 0:
            raise ValueError("valid_loader is empty; cannot compute validation metrics")

    if threshold is None:
        # NOTE: historical defaults differ between the two modes.
        threshold = 0.6 if inference else 0.5

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            if inference:
                return _infer_mask(data, device, model, threshold)
            return _evaluate(device, model, valid_loader, criterion, threshold)
    finally:
        # Hand the model back in whatever mode the caller left it.
        model.train(was_training)


def _infer_mask(data, device, model: nn.Module, threshold: float) -> torch.Tensor:
    """Return the thresholded sigmoid of ``model(data)`` as a float mask on the CPU."""
    image = data.to(device, dtype=torch.float32)
    model.to(device)  # nn.Module.to moves parameters in place
    probs = torch.sigmoid(model(image))
    return (probs > threshold).float().cpu().detach()


def _evaluate(device, model: nn.Module, valid_loader, criterion,
              threshold: float) -> Tuple[float, float]:
    """Average loss and Dice score of ``model`` over ``valid_loader``; log and return them."""
    n_batches = len(valid_loader)  # caller guarantees n_batches > 0
    model.to(device)
    total_loss = 0.0
    total_dice = 0.0
    for x, y in tqdm(valid_loader):
        # Device-agnostic cast: works for CPU and any CUDA device alike.
        x = x.to(device, dtype=torch.float32)
        y = y.to(device, dtype=torch.float32)

        logits = model(x)
        total_loss += criterion(logits, y).item()

        pred = (torch.sigmoid(logits) > threshold).float()
        # Dice over the whole batch; the 1e-8 terms avoid 0/0 on empty masks.
        dice = (2 * (y * pred).sum() + 1e-8) / ((y + pred).sum() + 1e-8)
        total_dice += dice.item()

    val_Loss = total_loss / n_batches
    val_Dicescore = total_dice / n_batches
    logger.info(f"Test Loss: {val_Loss}  - Dice Score: {val_Dicescore}")
    return val_Loss, val_Dicescore