"""Train the image-captioning model (image features + GPT-2) and log runs to MLflow."""
import os

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

import mlflow
import mlflow.pytorch

from config.config import Config
from data.dataLoader import ImageCaptionDataset
from models.model import ImageCaptioningModel
# TODO: integrate Weights & Biases for experiment tracking, and DVC for data versioning.
def train_model(model, dataLoader, optimizer, loss_fn):
    """Fine-tune the GPT-2 decoder of ``model`` on image-caption batches.

    Hyperparameters, the per-epoch average loss, and the trained model are
    logged to MLflow inside a single run.

    Args:
        model: captioning model exposing ``gpt2_model``,
            ``extract_image_features`` and ``prepare_gpt2_inputs``.
        dataLoader: iterable of ``(images, captions)`` batches.
        optimizer: optimizer over ``model.gpt2_model`` parameters.
        loss_fn: unused — GPT-2 computes the LM loss internally when
            ``labels`` are supplied; kept for backward compatibility.

    Raises:
        ValueError: if the prepared embeddings, ids and attention mask
            disagree on sequence length.
    """
    with mlflow.start_run():
        mlflow.log_params({
            "epochs": Config.EPOCHS,
            "batch_size": Config.BATCH_SIZE,
            "learning_rate": Config.LEARNING_RATE,
            "device": Config.DEVICE,
        })
        model.gpt2_model.train()

        for epoch in range(Config.EPOCHS):
            print(f'Epoch {epoch + 1}/{Config.EPOCHS}')
            epoch_loss = 0.0
            # Wrap the loader (not enumerate) so tqdm knows the total batch
            # count; the bar replaces the old manual "\r" progress print.
            for batch_idx, (images, captions) in enumerate(tqdm(dataLoader)):
                images = images.to(Config.DEVICE)
                captions = list(captions)

                image_features = model.extract_image_features(images)
                input_embeds, input_ids, attention_mask = model.prepare_gpt2_inputs(
                    image_features, captions
                )

                # Explicit check instead of `assert`, which is stripped under -O.
                if not (input_embeds.shape[1] == input_ids.shape[1] == attention_mask.shape[1]):
                    raise ValueError(
                        "inputs_embeds, input_ids and attention_mask disagree on sequence length"
                    )

                optimizer.zero_grad()
                # Passing labels makes the GPT-2 forward return the LM loss.
                outputs = model.gpt2_model(
                    inputs_embeds=input_embeds,
                    labels=input_ids,
                    attention_mask=attention_mask,
                )
                loss = outputs.loss
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()

            avg_loss = epoch_loss / len(dataLoader)
            print(f'Epoch {epoch + 1}, Loss: {avg_loss:.4f}')
            mlflow.log_metric('loss', avg_loss, step=epoch)

        # Persist the trained weights and register them with MLflow.
        model.save('model')
        mlflow.log_artifacts('model')
        mlflow.pytorch.log_model(model.gpt2_model, "models")
if __name__ == '__main__':
    # Build the dataset from the captions CSV and the images directory under
    # Config.DATASET_PATH. os.path.join fixes the original string
    # concatenation, which assumed a trailing slash for one path and
    # inserted its own for the other.
    dataset = ImageCaptionDataset(
        caption_file=os.path.join(Config.DATASET_PATH, 'captions.csv'),
        # Empty final component keeps the trailing separator the loader saw before.
        file_path=os.path.join(Config.DATASET_PATH, 'images', ''),
    )

    # Batched, shuffled loading with worker subprocesses.
    dataloader = DataLoader(
        dataset,
        batch_size=Config.BATCH_SIZE,
        shuffle=True,
        num_workers=4,
    )

    model = ImageCaptioningModel()
    # Only the GPT-2 decoder parameters are optimized here.
    optimizer = torch.optim.Adam(model.gpt2_model.parameters(), lr=Config.LEARNING_RATE)
    # Unused by training (GPT-2 computes its own loss); kept to match the
    # train_model signature.
    loss_fn = torch.nn.CrossEntropyLoss()

    mlflow.set_experiment('ImageCaptioning')
    train_model(model, dataloader, optimizer, loss_fn)