File size: 2,226 Bytes
bf9aafc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from PIL import Image
from models.model import ImageCaptioningModel
from torchvision import transforms
import torch

import torch
from transformers import ViTModel, ViTFeatureExtractor, GPT2LMHeadModel, GPT2Tokenizer
from PIL import Image
from config.config import Config

class ImageCaptioningInference:
    """Generate captions for single images with an ImageCaptioningModel.

    The wrapped model must expose ``extract_image_features``, ``gpt2_model``
    (with a Hugging Face style ``generate``) and ``tokenizer``. This class
    uses only those attributes.
    """

    def __init__(self, model):
        """Store the model and build the image preprocessing pipeline.

        Args:
            model: The captioning model. NOTE(review): input tensors are moved
                to ``Config.DEVICE``, but this class does not move the model
                itself, so the model is assumed to already be on that
                device. Confirm.
        """
        self.model = model
        self.device = Config.DEVICE
        # Resize to a fixed 224x224 and convert to a float tensor in [0, 1].
        # No mean/std normalization is applied here. Presumably it is handled
        # by (or not needed for) extract_image_features; confirm.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

    def infer_image(self, image, num_beams=3, max_length=50):
        """Generate a caption for a single image.

        Args:
            image: A PIL image (anything accepted by ``self.transform``).
                Images whose ``mode`` is not ``"RGB"`` (for example ``"L"``,
                ``"RGBA"`` or ``"P"``) are converted to RGB first, so the
                resulting tensor always has 3 channels.
            num_beams: Beam width forwarded to ``generate_caption``.
            max_length: Maximum generation length forwarded to
                ``generate_caption``.

        Returns:
            The decoded caption string.
        """
        # ToTensor keeps the source channel count. A grayscale or RGBA image
        # would otherwise give a 1- or 4-channel tensor, so normalize to RGB.
        if getattr(image, "mode", "RGB") != "RGB":
            image = image.convert("RGB")

        # [C, H, W] -> [1, C, H, W]: a batch of one, on the inference device.
        pixels = self.transform(image).unsqueeze(0).to(self.device)

        # This is inference only, so skip autograd bookkeeping.
        with torch.no_grad():
            image_features = self.model.extract_image_features(pixels)
            return self.generate_caption(
                image_features, num_beams=num_beams, max_length=max_length
            )

    def generate_caption(self, image_features, num_beams=3, max_length=50):
        """Decode a caption from image features with GPT-2 beam search.

        Args:
            image_features: Either ``[batch, hidden]`` (one embedding per
                image, expanded here to a length-1 prefix) or
                ``[batch, seq, hidden]`` (used as-is as the prefix).
            num_beams: Beam width passed to ``generate``.
            max_length: ``max_length`` passed to ``generate``.

        Returns:
            The caption for the first sequence of ``generate``'s output,
            with special tokens removed.
        """
        # Only 2-D features need a sequence axis. 3-D features already have one.
        if image_features.dim() == 2:
            image_features = image_features.unsqueeze(1)

        # pad_token_id is set to eos_token_id below, so generate cannot tell
        # padding from real positions. Every prefix position is real, so pass
        # an explicit all-ones mask.
        attention_mask = torch.ones(
            image_features.shape[:2],
            dtype=torch.long,
            device=image_features.device,
        )

        output = self.model.gpt2_model.generate(
            inputs_embeds=image_features,
            attention_mask=attention_mask,
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=True,
            pad_token_id=self.model.tokenizer.eos_token_id,
            bos_token_id=self.model.tokenizer.bos_token_id,
            eos_token_id=self.model.tokenizer.eos_token_id,
        )

        # Only the first returned sequence is decoded.
        return self.model.tokenizer.decode(output[0], skip_special_tokens=True)

if __name__ == "__main__":
    import argparse

    # Defaults match the previously hard-coded paths, so running the script
    # with no arguments behaves as before.
    parser = argparse.ArgumentParser(
        description="Generate a caption for an image with a saved ImageCaptioningModel."
    )
    parser.add_argument(
        "--model-dir",
        default="model",
        help="directory passed to ImageCaptioningModel.load (default: %(default)s)",
    )
    parser.add_argument(
        "--image",
        default="test_img.jpg",
        help="path of the image to caption (default: %(default)s)",
    )
    args = parser.parse_args()

    # Load the saved captioning model.
    model = ImageCaptioningModel()
    model.load(args.model_dir)

    # Wrap it in the inference helper (preprocessing + beam-search decoding).
    inference_model = ImageCaptioningInference(model)

    # Open the image in a context manager so the file handle is released.
    # convert() loads the pixel data and guarantees 3-channel RGB input.
    with Image.open(args.image) as img:
        image = img.convert("RGB")

    # Perform inference and print the generated caption.
    caption = inference_model.infer_image(image)
    print("Generated Caption:", caption)