import torch
from tqdm import tqdm

from torch.utils.data import DataLoader
from data.dataLoader import ImageCaptionDataset
from config.config import Config
from models.model import ImageCaptioningModel

import mlflow
import mlflow.pytorch

# TODO: integrate Weights & Biases for experiment tracking and evaluation; TODO: add DVC for data versioning
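
# A minimal sketch of what the wandb half of that TODO could look like
# (hypothetical: wandb is not yet a dependency of this project):
#
#   import wandb
#   wandb.init(project="image-captioning",
#              config={"epochs": Config.EPOCHS, "batch_size": Config.BATCH_SIZE})
#   # ...then, inside the epoch loop:
#   wandb.log({"loss": epoch_loss / len(dataLoader)}, step=epoch)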

def train_model(model, dataLoader, optimizer, loss_fn):

    # Run the whole training loop inside a single MLflow run so the per-epoch
    # metrics and final artifacts are attached to the same run as the logged
    # hyperparameters (otherwise the run closes right after log_params).
    with mlflow.start_run():
        mlflow.log_params({
            "epochs": Config.EPOCHS,
            "batch_size": Config.BATCH_SIZE,
            "learning_rate": Config.LEARNING_RATE,
            "device": Config.DEVICE
        })

        model.gpt2_model.train()
        for epoch in range(Config.EPOCHS):
            epoch_loss = 0
            progress = tqdm(dataLoader, desc=f'Epoch {epoch + 1}/{Config.EPOCHS}')
            for batch_idx, (images, captions) in enumerate(progress):
                images = images.to(Config.DEVICE)
                captions = list(captions)

                # Extract image features and build the GPT-2 inputs from them
                image_features = model.extract_image_features(images)
                input_embeds, input_ids, attention_mask = model.prepare_gpt2_inputs(image_features, captions)

                # Embeddings, token ids, and attention mask must share one sequence length
                assert input_embeds.shape[1] == input_ids.shape[1] == attention_mask.shape[1]

                optimizer.zero_grad()
                # Passing labels makes GPT-2 compute the language-modeling loss
                # internally, so the external loss_fn is not used here
                outputs = model.gpt2_model(inputs_embeds=input_embeds, labels=input_ids, attention_mask=attention_mask)

                loss = outputs.loss
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                progress.set_postfix(loss=f'{epoch_loss/(batch_idx+1):.4f}')

            print(f'Epoch {epoch + 1}, Loss: {epoch_loss/len(dataLoader):.4f}')
            mlflow.log_metric('loss', epoch_loss/len(dataLoader), step=epoch)

        # Save the model and attach the saved files to the run as artifacts
        model.save('model')
        mlflow.log_artifacts('model')
        mlflow.pytorch.log_model(model.gpt2_model, "models")
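
# Sketch of how the logged model could be reloaded later for inference
# (the run ID below is a placeholder, taken from the MLflow UI or API):
#
#   loaded_gpt2 = mlflow.pytorch.load_model("runs:/<run_id>/models")
#   loaded_gpt2.eval()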


if __name__ == '__main__':
    # Initialize dataset using the CSV file
    dataset = ImageCaptionDataset(
        caption_file=Config.DATASET_PATH + 'captions.csv',  # Path to the captions CSV file
        file_path=Config.DATASET_PATH + 'images/',          # Path to the images folder
    )

    # Create DataLoader for batch processing
    dataloader = DataLoader(
        dataset, 
        batch_size=Config.BATCH_SIZE, # Specify the batch size
        shuffle=True, # Shuffle the data
        num_workers=4 # Number of subprocesses for data loading
    )

    # Example: iterate over the dataloader to inspect a few batches
    # for batch_idx, (images, captions) in enumerate(dataloader):
    #     print(f'Batch {batch_idx + 1}:')
    #     print(f'Images shape: {images.shape}')
    #     print(f'Captions: {captions}')
    #     # Pass 'images' and 'captions' to the model for training/validation

    # Initialize the ImageCaptioningModel
    model = ImageCaptioningModel()
    optimizer = torch.optim.Adam(model.gpt2_model.parameters(), lr=Config.LEARNING_RATE)
    loss_fn = torch.nn.CrossEntropyLoss()  # unused by train_model; GPT-2 computes its LM loss from the labels
    mlflow.set_experiment('ImageCaptioning')
    train_model(model, dataloader, optimizer, loss_fn)
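
    # The logged params, metrics, and artifacts can then be browsed locally
    # with the MLflow UI (serves http://127.0.0.1:5000 by default):
    #   $ mlflow ui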