File size: 3,493 Bytes
a578142
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import torch
from tqdm import tqdm
import torch.nn as nn
from src import logger
from typing import *
import warnings
# Suppresses every warning for the whole process as soon as this module is imported.
# NOTE(review): this also hides PyTorch deprecation/user warnings triggered by
# code elsewhere; consider scoping it with warnings.catch_warnings() instead.
warnings.filterwarnings('ignore')
def _to_device(tensor: torch.Tensor, device: Any) -> torch.Tensor:
    """Cast ``tensor`` to float32 and move it to ``device`` in a single call."""
    return tensor.to(device=device, dtype=torch.float32)


def _dice_score(logits: torch.Tensor, target: torch.Tensor) -> float:
    """Return the batch-pooled Dice coefficient of thresholded predictions.

    ``logits`` go through a sigmoid and are binarised at 0.5 before being
    compared with ``target``; the whole batch is summed into one score. The
    1e-8 terms keep the ratio defined when both masks are empty.
    """
    # Metric only: it must never extend the autograd graph.
    with torch.no_grad():
        predicted = (torch.sigmoid(logits) > 0.5).float()
        intersection = (target * predicted).sum()
        total = (target + predicted).sum()
        dice = (2 * intersection + 1e-8) / (total + 1e-8)
    return dice.item()


def _run_epoch(
    model: nn.Module,
    loader: Any,
    device: Any,
    criterion: nn.Module,
    optimizer: Optional[torch.optim.Optimizer] = None,
):
    """Run one pass over ``loader`` and return ``(mean_loss, mean_dice)``.

    With an ``optimizer`` the model is put in train mode and updated after
    every batch; without one it is put in eval mode and run under
    ``torch.no_grad()``.

    Raises:
        ValueError: if ``loader`` yields no batches (the means are undefined).
    """
    training = optimizer is not None
    model.train(training)
    loss_sum = 0.0
    dice_sum = 0.0
    n_batches = 0
    grad_context = torch.enable_grad() if training else torch.no_grad()
    with grad_context:
        for x, y in tqdm(loader):
            x = _to_device(x, device)
            y = _to_device(y, device)

            predict = model(x)
            loss = criterion(predict, y)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            loss_sum += loss.item()
            dice_sum += _dice_score(predict, y)
            # Count batches actually seen instead of trusting len(loader),
            # which may be missing or inaccurate for iterable datasets.
            n_batches += 1

    if n_batches == 0:
        raise ValueError("data loader yielded no batches")
    return loss_sum / n_batches, dice_sum / n_batches


def model_fit(
    epochs: int,
    model: nn.Module,
    device: Any,
    train_loader: Any,
    valid_loader: Any,
    criterion: nn.Module,
    optimizer: torch.optim.Optimizer,
    PATH: str
):
    """Train ``model`` and keep the checkpoint with the best validation Dice.

    Each epoch makes one pass over ``train_loader`` with gradient updates and
    one pass over ``valid_loader`` under ``torch.no_grad()``. Model outputs
    are treated as logits: the loss is computed on them directly, and the
    Dice score uses ``sigmoid(output) > 0.5``. Whenever the mean validation
    Dice exceeds the best value seen so far, the whole model object is
    written to ``PATH`` with ``torch.save``.

    Args:
        epochs (int): number of epochs to train.
        model (nn.Module): model to train; moved to ``device`` in place.
        device (Union[int, str, torch.device]): device index or name.
        train_loader (Iterable): yields ``(x, y)`` batches for training.
        valid_loader (Iterable): yields ``(x, y)`` batches for validation.
        criterion (nn.Module): loss applied to ``(model(x), y)``.
        optimizer (torch.optim.Optimizer): optimizer over ``model``'s parameters.
        PATH (str): file path the best model is saved to.

    Returns:
        dict: per-epoch history with keys ``train_loss``, ``train_dice``,
        ``valid_loss`` and ``valid_dice``; each value is a list holding one
        float per epoch.

    Raises:
        ValueError: if either loader yields no batches.

    Example::

        summary = model_fit(
            epochs=25,
            model=model,
            device='cuda',  # or 0 / 'cpu'
            train_loader=train_loader,
            valid_loader=valid_loader,
            criterion=fn_loss,
            optimizer=optimizer,
            PATH='best_model.pt')
    """
    best_dice = 0
    model.to(device)
    summary = {
        'train_loss': [],
        'train_dice': [],
        'valid_loss': [],
        'valid_dice': [],
    }

    for epoch in range(epochs):
        # Report epochs 1-based so the last one reads "EPOCH n/n".
        logger.info(f"EPOCH {epoch + 1}/{epochs}")

        train_Loss, train_Dicescore = _run_epoch(
            model, train_loader, device, criterion, optimizer
        )
        val_Loss, val_Dicescore = _run_epoch(
            model, valid_loader, device, criterion
        )

        logger.info(f"Loss: {train_Loss}  - Dice Score: {train_Dicescore} - Validation Loss: {val_Loss} - Validation Dice Score: {val_Dicescore}")

        if val_Dicescore > best_dice:
            best_dice = val_Dicescore
            # NOTE: this pickles the whole module object, so loading it later
            # needs the model's class importable under the same path.
            torch.save(model, PATH)

        # Append (not overwrite) so the caller gets a per-epoch history.
        summary['train_loss'].append(train_Loss)
        summary['train_dice'].append(train_Dicescore)
        summary['valid_loss'].append(val_Loss)
        summary['valid_dice'].append(val_Dicescore)

    return summary