# Scrape residue from the hosting page header ("Spaces: Sleeping Sleeping"),
# kept as a comment so the module parses.
import argparse
import os

import torch
from PIL import Image
from torchvision import transforms
from transformers import ViTModel, ViTFeatureExtractor, GPT2LMHeadModel, GPT2Tokenizer

from config.config import Config
from models.model import ImageCaptioningModel
class ImageCaptioningInference:
    """Generate captions for images with a trained ImageCaptioningModel.

    The only attributes of the wrapped model used here are
    ``extract_image_features``, ``gpt2_model`` (via ``.generate``) and
    ``tokenizer`` (via ``.decode`` and its ``bos``/``eos`` token ids).
    """

    def __init__(self, model):
        """Store the model and build the image preprocessing pipeline.

        Args:
            model: a loaded ``ImageCaptioningModel``.
        """
        self.model = model
        # Input tensors are moved here; the model is assumed to already live
        # on the same device — TODO confirm in ImageCaptioningModel.load.
        self.device = Config.DEVICE
        # Resize to 224x224 and convert to a float tensor.
        # NOTE(review): no mean/std normalization is applied; this must match
        # the preprocessing used at training time — confirm.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

    def infer_image(self, image, num_beams=3, max_length=50):
        """Return a caption for a single image.

        Args:
            image: a PIL image, or a path (``str`` / ``os.PathLike``) to an
                image file.
            num_beams: beam width forwarded to ``generate_caption``.
            max_length: generation length limit forwarded to
                ``generate_caption``.

        Returns:
            The decoded caption string.
        """
        if isinstance(image, (str, os.PathLike)):
            # convert() returns a fully loaded copy, so the file handle can be
            # released as soon as the with-block exits.
            with Image.open(image) as opened:
                image = opened.convert("RGB")
        else:
            # ToTensor keeps the source channel count: "L" gives 1 channel and
            # "RGBA" gives 4. Force the 3 channels the encoder expects.
            image = image.convert("RGB")

        pixels = self.transform(image).unsqueeze(0).to(self.device)  # add batch dim

        # Inference only: skip building an autograd graph.
        # NOTE(review): dropout is only disabled if the caller put the model in
        # eval mode — confirm ImageCaptioningModel.load does this.
        with torch.no_grad():
            image_features = self.model.extract_image_features(pixels)
            return self.generate_caption(
                image_features, num_beams=num_beams, max_length=max_length
            )

    def generate_caption(self, image_features, num_beams=3, max_length=50):
        """Beam-search a caption from image features.

        Args:
            image_features: output of ``model.extract_image_features``;
                assumed to be (batch_size, hidden_size) — TODO confirm.
            num_beams: beam width for ``generate``.
            max_length: ``max_length`` passed to ``generate``.

        Returns:
            The first generated sequence decoded with special tokens removed.
        """
        # Feed the image features to GPT-2 as a single prefix embedding.
        prefix = image_features.unsqueeze(1)  # [batch_size, 1, hidden_size]
        tokenizer = self.model.tokenizer

        output = self.model.gpt2_model.generate(
            inputs_embeds=prefix,
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=True,
            # EOS doubles as the pad id (presumably no dedicated pad token).
            pad_token_id=tokenizer.eos_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
        return tokenizer.decode(output[0], skip_special_tokens=True)
if __name__ == "__main__":
    # Both arguments are optional; the defaults reproduce the original
    # hard-coded behaviour ('model' directory, 'test_img.jpg').
    parser = argparse.ArgumentParser(description="Generate a caption for an image.")
    parser.add_argument(
        "image_path",
        nargs="?",
        default="test_img.jpg",
        help="path to the input image (default: test_img.jpg)",
    )
    parser.add_argument(
        "--model-dir",
        default="model",
        help="directory containing the saved model (default: model)",
    )
    args = parser.parse_args()

    # Load the trained model and wrap it for inference.
    model = ImageCaptioningModel()
    model.load(args.model_dir)
    inference_model = ImageCaptioningInference(model)

    # Release the file handle once pixels are loaded. convert("RGB") returns
    # a loaded copy and guarantees the 3 channels the transform pipeline needs.
    with Image.open(args.image_path) as img:
        image = img.convert("RGB")

    # Perform inference and print the generated caption.
    caption = inference_model.infer_image(image)
    print("Generated Caption:", caption)