Spaces:
Sleeping
Sleeping
import torch | |
import numpy as np | |
from tqdm import tqdm | |
from torch.utils.data import DataLoader | |
from data.dataLoader import ImageCaptionDataset | |
from config.config import Config | |
from models.model import ImageCaptioningModel | |
import mlflow | |
import mlflow.pytorch | |
# TODO: Integrate Weights & Biases for experiment tracking and evaluation.
# TODO: Add DVC for data versioning.
def train_model(model, dataLoader, optimizer, loss_fn):
    """Train ``model.gpt2_model`` on image/caption batches inside an MLflow run.

    For every batch the images are moved to ``Config.DEVICE``, turned into
    features with ``model.extract_image_features`` and combined with the
    captions by ``model.prepare_gpt2_inputs``. The resulting embeddings are
    fed to ``model.gpt2_model`` with ``input_ids`` as labels, and the loss
    reported by the model output is back-propagated.

    The mean loss of each epoch is logged to MLflow under the metric name
    ``loss``. After the last epoch ``model.save('model')`` is called, the
    ``'model'`` path is logged as run artifacts, and ``model.gpt2_model`` is
    logged with ``mlflow.pytorch.log_model``.

    Args:
        model: Object exposing ``gpt2_model``, ``extract_image_features``,
            ``prepare_gpt2_inputs`` and ``save``.
        dataLoader: Sized iterable yielding ``(images, captions)`` batches.
        optimizer: Optimizer whose ``zero_grad``/``step`` drive the updates.
        loss_fn: Unused. The loss comes from ``outputs.loss``; the parameter
            is kept so existing callers keep working.

    Raises:
        ValueError: If ``dataLoader`` contains no batches.
        AssertionError: If the sequence lengths of the embeddings, ids and
            attention mask returned by ``prepare_gpt2_inputs`` differ.
    """
    num_batches = len(dataLoader)
    if num_batches == 0:
        # Fail before opening an MLflow run; otherwise the per-epoch mean
        # below would divide by zero after an empty run was already created.
        raise ValueError("dataLoader is empty; there is nothing to train on")

    with mlflow.start_run():
        mlflow.log_params({
            "epochs": Config.EPOCHS,
            "batch_size": Config.BATCH_SIZE,
            "learning_rate": Config.LEARNING_RATE,
            "device": Config.DEVICE,
        })
        model.gpt2_model.train()

        for epoch in tqdm(range(Config.EPOCHS)):
            print(f'Epoch {epoch + 1}/{Config.EPOCHS}')
            epoch_loss = 0

            # `total` is passed explicitly because enumerate() hides the
            # loader's length from tqdm.
            for batch_idx, (images, captions) in tqdm(enumerate(dataLoader), total=num_batches):
                images = images.to(Config.DEVICE)
                captions = list(captions)

                # Extract image features and build the GPT-2 inputs.
                image_features = model.extract_image_features(images)
                input_embeds, input_ids, attention_mask = model.prepare_gpt2_inputs(image_features, captions)

                # Embeddings, ids and mask must agree on sequence length.
                assert input_embeds.shape[1] == input_ids.shape[1] == attention_mask.shape[1]

                optimizer.zero_grad()
                # NOTE(review): labels are the raw input_ids, so any padded
                # positions also contribute to the loss - confirm intended.
                outputs = model.gpt2_model(inputs_embeds=input_embeds, labels=input_ids, attention_mask=attention_mask)
                loss = outputs.loss
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()

                # Report the running mean only after this batch's loss has
                # been added, so the average covers batch_idx + 1 batches.
                print(f'\rBatch {batch_idx + 1}/{num_batches} , Loss : {epoch_loss/(batch_idx+1):.4f}\t', end='')

            mean_loss = epoch_loss / num_batches
            print(f'Epoch {epoch + 1}, Loss: {mean_loss:.4f}')
            mlflow.log_metric('loss', mean_loss, step=epoch)

        # NOTE(review): saving is done once after the final epoch; confirm
        # this matches the intended checkpointing frequency.
        model.save('model')
        # Attach the saved files and the GPT-2 module to the MLflow run.
        mlflow.log_artifacts('model')
        mlflow.pytorch.log_model(model.gpt2_model, "models")
if __name__ == '__main__':
    # Script entry point: build the dataset and loader, create the model and
    # optimizer, then run training under the 'ImageCaptioning' experiment.
    import os

    # Join path pieces with os.path.join so the result is well-formed
    # whether or not Config.DATASET_PATH ends with a separator.
    dataset = ImageCaptionDataset(
        caption_file=os.path.join(Config.DATASET_PATH, 'captions.csv'),  # Path to captions CSV file
        # The trailing '' keeps the final separator of the images folder path.
        file_path=os.path.join(Config.DATASET_PATH, 'images', ''),  # Path to images folder
    )

    # Create DataLoader for batch processing
    dataloader = DataLoader(
        dataset,
        batch_size=Config.BATCH_SIZE,  # Specify the batch size
        shuffle=True,  # Shuffle the data
        num_workers=4,  # Number of subprocesses for data loading
    )

    # Initialize the ImageCaptioningModel; only the GPT-2 parameters are
    # handed to the optimizer.
    model = ImageCaptioningModel()
    optimizer = torch.optim.Adam(model.gpt2_model.parameters(), lr=Config.LEARNING_RATE)
    # Passed through to train_model to satisfy its signature.
    loss_fn = torch.nn.CrossEntropyLoss()

    mlflow.set_experiment('ImageCaptioning')
    train_model(model, dataloader, optimizer, loss_fn)