# VenkateshRoshan — commit bf9aafc: inference code added
from PIL import Image
from models.model import ImageCaptioningModel
from torchvision import transforms
import torch
import torch
from transformers import ViTModel, ViTFeatureExtractor, GPT2LMHeadModel, GPT2Tokenizer
from PIL import Image
from config.config import Config
class ImageCaptioningInference:
    """Single-image caption generation using a trained ImageCaptioningModel.

    The wrapped model is expected to expose ``extract_image_features()``,
    a ``gpt2_model`` with ``generate()``, and a ``tokenizer`` (GPT-2 style).
    """

    def __init__(self, model):
        """
        Args:
            model: trained ImageCaptioningModel instance.
        """
        self.model = model
        self.device = Config.DEVICE
        # NOTE(review): no normalization applied here — confirm this matches
        # the preprocessing used at training time (ViT pipelines usually
        # normalize with the feature extractor's mean/std).
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor()
        ])

    def infer_image(self, image, num_beams=3, max_length=50):
        """Generate a caption for a PIL image.

        Args:
            image: PIL.Image to caption.
            num_beams: beam width forwarded to beam search (default 3).
            max_length: maximum caption length in tokens (default 50).

        Returns:
            str: the generated caption.
        """
        # Preprocess: resize + tensorize, add batch dim, move to device.
        image = self.transform(image).unsqueeze(0).to(self.device)
        # Encode the image into a feature vector.
        image_features = self.model.extract_image_features(image)
        # Decode features into text; decoding knobs are now configurable
        # instead of being silently fixed by generate_caption's defaults.
        return self.generate_caption(
            image_features, num_beams=num_beams, max_length=max_length
        )

    def generate_caption(self, image_features, num_beams=3, max_length=50):
        """Beam-search decode a caption from encoded image features.

        Args:
            image_features: assumed shape [batch_size, hidden_size] — TODO
                confirm against ``extract_image_features``.
            num_beams: beam width for beam search.
            max_length: maximum generated sequence length.

        Returns:
            str: decoded caption for the best beam of the first batch item.
        """
        # GPT-2 expects a sequence dimension: [batch_size, 1, hidden_size].
        image_features = image_features.unsqueeze(1)
        output = self.model.gpt2_model.generate(
            inputs_embeds=image_features,
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=True,
            # GPT-2 has no pad token; reusing EOS is the standard workaround.
            pad_token_id=self.model.tokenizer.eos_token_id,
            bos_token_id=self.model.tokenizer.bos_token_id,
            eos_token_id=self.model.tokenizer.eos_token_id
        )
        # Take the first (best) sequence and strip special tokens.
        return self.model.tokenizer.decode(output[0], skip_special_tokens=True)
if __name__ == "__main__":
    # Restore the trained captioning model from its saved directory.
    saved_model_dir = 'model'
    captioner = ImageCaptioningModel()
    captioner.load(saved_model_dir)

    # Wrap the model in the inference helper.
    runner = ImageCaptioningInference(captioner)

    # Caption one sample image and print the result.
    sample_path = 'test_img.jpg'
    sample_image = Image.open(sample_path)
    print("Generated Caption:", runner.infer_image(sample_image))