import torch
from tqdm import tqdm
from torch.utils.data import DataLoader

from data.dataLoader import ImageCaptionDataset
from config.config import Config
from models.model import ImageCaptioningModel


def train_model(model, dataloader, optimizer):
    # GPT-2 computes its own cross-entropy loss when `labels` is supplied,
    # so no external loss function is needed.
    model.gpt2_model.train()
    for epoch in range(Config.EPOCHS):
        epoch_loss = 0.0
        progress = tqdm(dataloader, desc=f'Epoch {epoch + 1}/{Config.EPOCHS}')
        for batch_idx, (images, captions) in enumerate(progress):
            images = images.to(Config.DEVICE)
            captions = list(captions)

            # Extract image features, then build GPT-2 inputs from them
            image_features = model.extract_image_features(images)
            input_embeds, input_ids, attention_mask = model.prepare_gpt2_inputs(
                image_features, captions
            )

            # Embeddings, labels, and attention mask must share one sequence length
            assert input_embeds.shape[1] == input_ids.shape[1] == attention_mask.shape[1]

            optimizer.zero_grad()
            outputs = model.gpt2_model(
                inputs_embeds=input_embeds,
                labels=input_ids,
                attention_mask=attention_mask,
            )
            loss = outputs.loss
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            progress.set_postfix(loss=f'{epoch_loss / (batch_idx + 1):.4f}')

        print(f'Epoch {epoch + 1}, Avg Loss: {epoch_loss / len(dataloader):.4f}')

    # Save the trained model
    model.save('model')


if __name__ == '__main__':
    # Initialize the dataset from the captions CSV and the images folder
    dataset = ImageCaptionDataset(
        caption_file=Config.DATASET_PATH + 'captions.csv',  # path to captions CSV file
        file_path=Config.DATASET_PATH + 'images/',          # path to images folder
    )

    # Create a DataLoader for batched, shuffled loading
    dataloader = DataLoader(
        dataset,
        batch_size=Config.BATCH_SIZE,
        shuffle=True,   # shuffle the data each epoch
        num_workers=4,  # number of subprocesses for data loading
    )

    # Initialize the model and optimizer; only the GPT-2 parameters are optimized.
    # ImageCaptioningModel is assumed to place its submodules on Config.DEVICE;
    # if it does not, move them explicitly before training.
    model = ImageCaptioningModel()
    optimizer = torch.optim.Adam(model.gpt2_model.parameters(), lr=Config.LEARNING_RATE)

    train_model(model, dataloader, optimizer)
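

# ---------------------------------------------------------------------------
# Reference sketch #1 (an assumption, not the actual ImageCaptioningModel code).
# extract_image_features is assumed to run a frozen CNN encoder over the image
# batch and return one pooled feature vector per image. The backbone choice
# (torchvision ResNet-50) and the 2048-dim output are hypothetical stand-ins
# for whatever encoder the real model wraps.
# ---------------------------------------------------------------------------
import torchvision


class ImageEncoderSketch(torch.nn.Module):
    """Hypothetical encoder: frozen ResNet-50 minus its classification head."""

    def __init__(self):
        super().__init__()
        backbone = torchvision.models.resnet50(weights='IMAGENET1K_V2')
        # Keep everything up to global average pooling; drop the final fc layer
        self.features = torch.nn.Sequential(*list(backbone.children())[:-1])
        for p in self.features.parameters():
            p.requires_grad = False  # the encoder stays frozen during training

    @torch.no_grad()
    def forward(self, images):     # images: (B, 3, H, W)
        x = self.features(images)  # (B, 2048, 1, 1) after average pooling
        return x.flatten(1)        # (B, 2048): one feature vector per image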
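

# ---------------------------------------------------------------------------
# Reference sketch #2 (an assumption, not the actual ImageCaptioningModel code).
# prepare_gpt2_inputs is assumed to project each image feature vector into
# GPT-2's embedding space and prepend it to the caption token embeddings,
# which is what makes the shape assert in train_model hold. The `tokenizer`
# and `projection` arguments (e.g. torch.nn.Linear(2048, gpt2_model.config.n_embd))
# are illustrative assumptions, not part of the real model's API.
# ---------------------------------------------------------------------------
def prepare_gpt2_inputs_sketch(gpt2_model, tokenizer, projection,
                               image_features, captions, device):
    """Hypothetical sketch returning (input_embeds, input_ids, attention_mask)."""
    # GPT-2's tokenizer has no pad token by default; reuse EOS for padding
    tokenizer.pad_token = tokenizer.eos_token
    encoded = tokenizer(captions, return_tensors='pt', padding=True)
    input_ids = encoded['input_ids'].to(device)            # (B, T)
    attention_mask = encoded['attention_mask'].to(device)  # (B, T)

    # Caption token embeddings from GPT-2's embedding table: (B, T, H)
    token_embeds = gpt2_model.get_input_embeddings()(input_ids)

    # Project image features to GPT-2's hidden size; add a sequence axis: (B, 1, H)
    image_embeds = projection(image_features).unsqueeze(1)

    # Prepend the visual "token" to the caption embeddings: (B, 1 + T, H)
    input_embeds = torch.cat([image_embeds, token_embeds], dim=1)

    # train_model passes the returned ids as `labels=`, so mask padded caption
    # positions and the visual slot with -100 (Hugging Face's ignore index)
    input_ids = input_ids.masked_fill(attention_mask == 0, -100)
    visual_slot = torch.full((input_ids.size(0), 1), -100,
                             dtype=input_ids.dtype, device=device)
    input_ids = torch.cat([visual_slot, input_ids], dim=1)
    # The visual position is attended to, hence a column of ones
    attention_mask = torch.cat([torch.ones_like(visual_slot), attention_mask], dim=1)

    return input_embeds, input_ids, attention_mask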