import os

import torch
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader

from data.dataLoader import ImageCaptionDataset
from config.config import Config
from models.model import ImageCaptioningModel

import mlflow
import mlflow.pytorch

# TODO: add Weights & Biases for experiment tracking, and DVC for data versioning.


def train_model(model, dataLoader, optimizer, loss_fn):
    """Train the GPT-2 head of an image-captioning model, logging to MLflow.

    Runs ``Config.EPOCHS`` passes over ``dataLoader``; for each batch the
    image features are extracted, converted to GPT-2 inputs, and the model is
    optimized on its internally computed language-modeling loss.  Parameters,
    per-epoch loss, and the final model artifacts are logged to MLflow.

    Parameters
    ----------
    model : ImageCaptioningModel
        Must expose ``extract_image_features``, ``prepare_gpt2_inputs``,
        ``gpt2_model`` (a HuggingFace causal LM), and ``save``.
    dataLoader : torch.utils.data.DataLoader
        Yields ``(images, captions)`` batches; ``images`` is a tensor,
        ``captions`` a sequence of strings.
    optimizer : torch.optim.Optimizer
        Optimizer over ``model.gpt2_model`` parameters.
    loss_fn : callable
        Unused — the GPT-2 model computes cross-entropy internally when
        ``labels`` is supplied.  Kept for backward-compatible signature.
    """
    with mlflow.start_run():
        mlflow.log_params({
            "epochs": Config.EPOCHS,
            "batch_size": Config.BATCH_SIZE,
            "learning_rate": Config.LEARNING_RATE,
            "device": Config.DEVICE,
        })

        # Only the GPT-2 head is put in train mode; the feature extractor's
        # mode is assumed to be managed by ImageCaptioningModel itself
        # (presumably frozen — TODO confirm against models/model.py).
        model.gpt2_model.train()
        num_batches = len(dataLoader)

        for epoch in range(Config.EPOCHS):
            print(f'Epoch {epoch + 1}/{Config.EPOCHS}')
            epoch_loss = 0.0

            # Single progress bar over batches; the original nested a second
            # tqdm over epochs, which garbled the console output.
            for batch_idx, (images, captions) in tqdm(enumerate(dataLoader),
                                                      total=num_batches):
                images = images.to(Config.DEVICE)
                captions = list(captions)

                image_features = model.extract_image_features(images)
                input_embeds, input_ids, attention_mask = model.prepare_gpt2_inputs(
                    image_features, captions)

                # All three tensors must agree on sequence length; raise an
                # explicit error instead of `assert` (stripped under -O) or
                # letting GPT-2 fail with an opaque shape mismatch.
                if not (input_embeds.shape[1] == input_ids.shape[1]
                        == attention_mask.shape[1]):
                    raise ValueError(
                        f'Sequence-length mismatch: embeds {input_embeds.shape[1]}, '
                        f'ids {input_ids.shape[1]}, mask {attention_mask.shape[1]}')

                optimizer.zero_grad()
                # Passing `labels` makes the HuggingFace model compute the
                # shifted cross-entropy loss itself; `loss_fn` is not needed.
                outputs = model.gpt2_model(inputs_embeds=input_embeds,
                                           labels=input_ids,
                                           attention_mask=attention_mask)
                loss = outputs.loss
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                # BUGFIX: the original printed the running average *before*
                # accumulating the current batch, so the first batch always
                # showed 0.0000 and every reading lagged one batch behind.
                print(f'\rBatch {batch_idx + 1}/{num_batches} , '
                      f'Loss : {epoch_loss / (batch_idx + 1):.4f}\t', end='')

            print(f'Epoch {epoch + 1}, Loss: {epoch_loss / num_batches:.4f}')
            mlflow.log_metric('loss', epoch_loss / num_batches, step=epoch)

        # Save the model locally, then push both the raw artifacts and the
        # GPT-2 weights to the MLflow run.
        model.save('model')
        mlflow.log_artifacts('model')
        mlflow.pytorch.log_model(model.gpt2_model, "models")


if __name__ == '__main__':
    # BUGFIX: the original mixed two path conventions — `DATASET_PATH +
    # 'captions.csv'` assumed a trailing slash while `DATASET_PATH +
    # '/images/'` assumed none, so one of them was broken.  os.path.join
    # is correct regardless of whether DATASET_PATH ends with a separator.
    dataset = ImageCaptionDataset(
        caption_file=os.path.join(Config.DATASET_PATH, 'captions.csv'),
        file_path=os.path.join(Config.DATASET_PATH, 'images/'),
    )

    dataloader = DataLoader(
        dataset,
        batch_size=Config.BATCH_SIZE,
        shuffle=True,        # shuffle samples each epoch
        num_workers=4,       # subprocesses for data loading
    )

    model = ImageCaptioningModel()
    optimizer = torch.optim.Adam(model.gpt2_model.parameters(),
                                 lr=Config.LEARNING_RATE)
    # Retained only because train_model's signature expects it; the loss is
    # computed inside the GPT-2 forward pass (see train_model docstring).
    loss_fn = torch.nn.CrossEntropyLoss()

    mlflow.set_experiment('ImageCaptioning')
    train_model(model, dataloader, optimizer, loss_fn)