File size: 2,858 Bytes
3138612
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import torch
import numpy as np
from tqdm import tqdm

from torch.utils.data import DataLoader
from data.dataLoader import ImageCaptionDataset
from config.config import Config
from models.model import ImageCaptioningModel

from torchsummary import summary



def train_model(model, dataLoader, optimizer, loss_fn):
    """Fine-tune the captioning model's GPT-2 decoder, then save the model.

    For each of ``Config.EPOCHS`` epochs, every ``(images, captions)`` batch
    from ``dataLoader`` is passed through ``model.extract_image_features`` and
    ``model.prepare_gpt2_inputs``. The loss returned by ``model.gpt2_model`` is
    back-propagated and ``optimizer`` takes one step per batch. A running mean
    of the loss is shown on the progress bar, and the mean loss of each epoch
    is printed. After the last epoch the model is written with
    ``model.save('model')``.

    Args:
        model: Object exposing ``gpt2_model``, ``extract_image_features``,
            ``prepare_gpt2_inputs`` and ``save`` (an ``ImageCaptioningModel``).
        dataLoader: Iterable of ``(images, captions)`` batches supporting
            ``len()``; ``images`` must support ``.to(device)``.
        optimizer: Optimizer over the parameters to update.
        loss_fn: Currently unused. The loss is taken from ``outputs.loss``,
            computed by the GPT-2 model from ``labels``. Kept for API
            compatibility with existing callers.

    Raises:
        ValueError: If the prepared input embeddings, input ids and attention
            mask disagree on sequence length (dimension 1).
    """
    # Only the GPT-2 decoder is put into training mode here; the mode of
    # whatever extract_image_features uses is left unchanged.
    model.gpt2_model.train()
    total_batches = len(dataLoader)

    for epoch in range(Config.EPOCHS):
        epoch_loss = 0.0
        batches_seen = 0
        # Wrap the loader itself (not enumerate(...)) so tqdm knows the total;
        # loss is reported via set_postfix instead of a competing '\r' print.
        progress = tqdm(
            dataLoader,
            total=total_batches,
            desc=f'Epoch {epoch + 1}/{Config.EPOCHS}',
        )
        for images, captions in progress:
            images = images.to(Config.DEVICE)
            # Captions may arrive as a tuple (e.g. from default collation);
            # hand a plain list to prepare_gpt2_inputs, as before.
            captions = list(captions)

            image_features = model.extract_image_features(images)
            input_embeds, input_ids, attention_mask = model.prepare_gpt2_inputs(
                image_features, captions
            )

            # Explicit check rather than `assert`, which `python -O` strips.
            seq_lens = (
                input_embeds.shape[1],
                input_ids.shape[1],
                attention_mask.shape[1],
            )
            if len(set(seq_lens)) != 1:
                raise ValueError(
                    'Sequence length mismatch between input_embeds, input_ids '
                    f'and attention_mask: {seq_lens}'
                )

            optimizer.zero_grad()
            # NOTE(review): labels=input_ids means padded positions are
            # scored too, unless prepare_gpt2_inputs already sets them to
            # -100 (the index Hugging Face LM heads ignore) — confirm.
            outputs = model.gpt2_model(
                inputs_embeds=input_embeds,
                labels=input_ids,
                attention_mask=attention_mask,
            )
            loss = outputs.loss
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            batches_seen += 1
            # Running mean over all batches so far, *including* this one.
            progress.set_postfix(loss=f'{epoch_loss / batches_seen:.4f}')

        # Report the mean (not the sum) so the value is comparable across
        # epochs and dataset sizes; guard against an empty loader.
        mean_loss = epoch_loss / batches_seen if batches_seen else 0.0
        print(f'Epoch {epoch + 1}, Loss: {mean_loss:.4f}')

    model.save('model')


if __name__ == '__main__':
    import os

    # Build dataset paths with os.path.join so Config.DATASET_PATH works with
    # or without a trailing separator. Plain `+` concatenation produced
    # "<root>captions.csv" when the slash was missing, and "<root>//images/"
    # when it was present.
    # The images directory keeps a trailing separator, as the original
    # '/images/' literal did. NOTE(review): presumably ImageCaptionDataset
    # concatenates file names onto file_path — confirm.
    dataset = ImageCaptionDataset(
        caption_file=os.path.join(Config.DATASET_PATH, 'captions.csv'),
        file_path=os.path.join(Config.DATASET_PATH, 'images', ''),
    )

    # Batched, shuffled loading with 4 worker subprocesses.
    dataloader = DataLoader(
        dataset,
        batch_size=Config.BATCH_SIZE,
        shuffle=True,
        num_workers=4,
    )

    model = ImageCaptioningModel()
    # Only the GPT-2 decoder's parameters are handed to the optimizer.
    optimizer = torch.optim.Adam(
        model.gpt2_model.parameters(), lr=Config.LEARNING_RATE
    )
    # Passed for interface compatibility; train_model trains on the loss
    # that the GPT-2 model returns (outputs.loss), not on this criterion.
    loss_fn = torch.nn.CrossEntropyLoss()
    train_model(model, dataloader, optimizer, loss_fn)