File size: 7,160 Bytes
97daae4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
import os

import torch
import wandb
from tqdm.auto import tqdm

# NOTE(security): a credential is hard-coded below. Prefer supplying it through
# the WANDB_API_KEY environment variable; the literal is kept only as a fallback
# so existing setups keep working, and it should be rotated and removed.
API_KEY = os.environ.get("WANDB_API_KEY") or "881252af31786a1cf813449b9b4124955f54703e"


def train_loop(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader, loss_fn: torch.nn.Module,
               optimizer: torch.optim.Optimizer, accuracy_fn, device: torch.device):

    """
    Train the model for one full pass over the dataloader.

    Args:
        model: A pytorch model you want to train; it is put in training mode
        dataloader: An iterable of (inputs, targets) batches that supports len()
        loss_fn: A loss function called as loss_fn(logits, targets)
        optimizer: The optimizer that updates the model parameters after each batch
        accuracy_fn: An accuracy function called as accuracy_fn(predicted_labels, targets),
            where predicted_labels is the argmax of the logits over dim 1
        device: A device the batches are moved to i.e.: "cuda" or "cpu"

    Return:
        A tuple (model, train_loss, train_acc); loss and accuracy are the sums of the
        per-batch values divided by the number of batches

    Raises:
        ZeroDivisionError: if the dataloader yields no batches (len(dataloader) == 0)

    Example usage:
        train_loop(model = mymodel, dataloader = train_dataloader, loss_fn = loss_fn,
                    optimizer = optimizer, accuracy_fn = accuracy_fn, device = device)
    """

    train_loss, train_acc = 0, 0

    model.train()

    for x_train, y_train in dataloader:
        x_train, y_train = x_train.to(device), y_train.to(device)

        # 1. Forward pass
        logits = model(x_train)

        # 2. Loss
        loss = loss_fn(logits, y_train)

        # 3. Reset gradients left over from the previous batch
        optimizer.zero_grad()

        # 4. Backward
        loss.backward()

        # 5. Optimizer step
        optimizer.step()

        train_acc += accuracy_fn(torch.argmax(logits, dim=1), y_train)

        # Accumulate a detached loss: summing the graph-attached tensor would keep
        # every batch's autograd node referenced and return a value that requires grad.
        train_loss += loss.detach()

    train_loss /= len(dataloader)
    train_acc /= len(dataloader)

    return model, train_loss, train_acc


def test_loop(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader, loss_fn: torch.nn.Module,
              accuracy_fn, device: torch.device):

    """
    Evaluate the model on every batch of a dataloader.

    Args:
        model: The model to evaluate; it is switched to eval mode
        dataloader: An iterable of (inputs, targets) batches that supports len()
        loss_fn: A loss function called as loss_fn(logits, targets)
        accuracy_fn: An accuracy function called as accuracy_fn(predicted_labels, targets),
            where predicted_labels is the argmax of the logits over dim 1
        device: A device the batches are moved to i.e.: "cuda" or "cpu"

    Return:
        A tuple (test_loss, test_acc), each the sum of the per-batch values
        divided by the number of batches

    Example Usage:
        test_loop(model = mymodel, dataloader = test_dataloader, loss_fn = loss_fn,
                  accuracy_fn = accuracy_fn, device = device)
    """

    model.eval()

    running_loss = 0
    running_acc = 0

    with torch.inference_mode():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            # Forward pass, then score the batch
            outputs = model(inputs)
            batch_loss = loss_fn(outputs, targets)
            batch_acc = accuracy_fn(torch.argmax(outputs, dim=1), targets)

            running_loss += batch_loss
            running_acc += batch_acc

        # Averaging stays inside inference mode, where the accumulators were created
        num_batches = len(dataloader)
        running_acc /= num_batches
        running_loss /= num_batches

    return running_loss, running_acc


def log_image_table(images, predicted, labels, probs):
    """
    Log a wandb.Table with (img, pred, target, scores)

    Args:
        images: Batch of image tensors; only index 0 along the first per-image
            dimension is logged (presumably the first channel - TODO confirm)
        predicted: One predicted value per image
        labels: One target value per image
        probs: 2-D tensor of per-class scores, one row per image; its second
            dimension sets how many score_<i> columns the table gets

    Return:
        None. Logs in to wandb with API_KEY and logs the table under
        "predictions_table" with commit=False.
    """
    wandb.login(key=API_KEY)
    print("[LOG]: Login Succesfull.")

    # One score column per class, taken from probs instead of a fixed 10, so the
    # row width always matches the number of values unpacked from each prob row.
    num_classes = probs.shape[1]
    score_columns = [f"score_{i}" for i in range(num_classes)]
    table = wandb.Table(columns=["image", "pred", "target"] + score_columns)

    for img, pred, targ, prob in zip(images.to("cpu"), predicted.to("cpu"), labels.to("cpu"), probs.to("cpu")):
        table.add_data(wandb.Image(img[0].numpy()*255), pred, targ, *prob.numpy())
    wandb.log({"predictions_table":table}, commit=False)



def validation(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader, loss_fn: torch.nn.Module,
               accuracy_fn, log_images: bool, device: torch.device, batch_idx=0):
    """
    Evaluate the model on a validation dataloader, optionally logging one batch to wandb

    Args:
        model: The model to evaluate; it is switched to eval mode
        dataloader: An iterable of (inputs, targets) batches that supports len()
        loss_fn: A loss function called as loss_fn(logits, targets)
        accuracy_fn: An accuracy function called as accuracy_fn(predicted_labels, targets),
            where predicted_labels is the argmax of the logits over dim 1
        log_images: When True, the batch at position batch_idx is sent to log_image_table
        device: A device the batches are moved to i.e.: "cuda" or "cpu"
        batch_idx: Index of the batch to log when log_images is True (default 0)

    Return:
        A tuple (val_loss, val_acc), each the sum of the per-batch values
        divided by the number of batches

    Example usage:
        validation(model = mymodel, dataloader = valid_dataloader, loss_fn = loss_fn,
                   accuracy_fn = accuracy_fn, log_images = False, device = device)
    """

    val_loss, val_acc = 0, 0

    model.eval()
    with torch.inference_mode():
        for i, (x_val, y_val) in enumerate(dataloader):
            x_val, y_val = x_val.to(device), y_val.to(device)

            logits = model(x_val)
            predicted = torch.argmax(logits, dim=1)

            val_loss += loss_fn(logits, y_val)

            val_acc += accuracy_fn(predicted, y_val)

            # Log one batch of images to the dashboard, always same batch_idx.
            # Pass the predicted class indices: torch.max(...)[0] would hand over
            # the maximum logit values instead of the labels for the "pred" column.
            if i == batch_idx and log_images:
                log_image_table(x_val, predicted, y_val, logits.softmax(dim=1))

        val_loss /= len(dataloader)
        val_acc /= len(dataloader)

    return val_loss, val_acc

def _as_float(value):
    """Return a metric as a plain Python number, whether it is a 0-dim tensor or already a number."""
    return value.item() if hasattr(value, "item") else float(value)


def train(model: torch.nn.Module, train_dataloader: torch.utils.data.DataLoader, test_dataloader: torch.utils.data.DataLoader,
          optimizer: torch.optim.Optimizer, loss_fn: torch.nn.Module, accuracy_fn, epochs: int, device: torch.device):

    """
    A function to train and test the pytorch model

    Args:
        model: A model to train and test
        train_dataloader: A dataloader instance passed to train_loop each epoch
        test_dataloader: A dataloader instance passed to test_loop each epoch
        optimizer: An optimizer used by train_loop to update the model
        loss_fn: A loss function to calculate model loss
        accuracy_fn: An accuracy function to calculate model performance
        epochs: number of times train_loop and test_loop are run
        device: A device on which model run i.e.: "cuda" or "cpu"

    Return:
        A tuple (train_losses, test_losses, train_accs, test_accs, train_model):
        four lists with one entry per epoch, plus the model returned by the last
        train_loop call (the input model itself when epochs is 0)

    Example usage:
        train(model = mymodel, train_dataloader = train_dataloader, test_dataloader = test_dataloader, optimizer = optimizer,
              loss_fn = loss_fn, accuracy_fn = accuracy_fn, epochs = epochs, device = device)
    """

    train_losses, test_losses = [], []
    train_accs, test_accs = [], []

    # Bound up front so the return below is valid even when no epoch runs.
    train_model = model

    for epoch in tqdm(range(epochs)):

        print(f"\nEpoch: {epoch+1}")

        train_model, train_loss, train_acc = train_loop(model = model, dataloader = train_dataloader,
                                                        loss_fn = loss_fn, optimizer = optimizer,
                                                        accuracy_fn = accuracy_fn, device = device)

        test_loss, test_acc = test_loop(model = model, dataloader = test_dataloader, loss_fn = loss_fn,
                                        accuracy_fn = accuracy_fn, device = device)

        print(f"Train Loss: {train_loss:.5f} | Test Loss: {test_loss:.5f} || Train Accuracy: {train_acc:.5f} | Test Accuracy: {test_acc:.5f}")

        # Metrics may be 0-dim tensors or plain numbers depending on accuracy_fn.
        train_losses.append(_as_float(train_loss))
        test_losses.append(_as_float(test_loss))
        train_accs.append(_as_float(train_acc))
        test_accs.append(_as_float(test_acc))

    return train_losses, test_losses, train_accs, test_accs, train_model