from torchvision import transforms
import torch
from PIL import Image

from vocab import Vocab
from model import Decoder, Encoder
from config import Config


def generate_caption(image: torch.Tensor,
                     image_encoder: Encoder,
                     emb_layer: torch.nn.Embedding,
                     image_decoder: Decoder,
                     vocab: Vocab,
                     device: torch.device) -> list[str]:
    """ Generate caption of a single image of size (1, 3, 224, 224)

    Returns:
        list[str]: caption for given image
    """

    image = image.to(device)
    # image: (3, 224, 224)
    image = image.unsqueeze(0)
    # image: (1, 3, 224, 224)

    features = image_encoder(image)
    # features: (1, IMAGE_EMB_DIM)
    features = features.to(device)
    features = features.unsqueeze(0)
    # features: (1, 1, IMAGE_EMB_DIM)

    hidden = image_decoder.hidden_state_0
    cell = image_decoder.cell_state_0
    # hidden, cell : (NUM_LAYER, 1, HIDDEN_DIM)

    sentence = []

    # start with '<sos>' as first word
    previous_word = vocab.index2word[vocab.SOS]

    MAX_LENGTH = 20

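    # Greedy decoding: at each step, embed the previous word, run the decoder,
    # and take the argmax over the vocabulary as the next word.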
    for i in range(MAX_LENGTH):

        input_word_id = vocab.word_to_index(previous_word)
        input_word_tensor = torch.tensor([input_word_id]).unsqueeze(0)
        # input_word_tensor : (1, 1)

        input_word_tensor = input_word_tensor.to(device)
        lstm_input = emb_layer(input_word_tensor)
        # lstm_input : (1, 1, WORD_EMB_DIM)

        next_word_pred, (hidden, cell) = image_decoder(lstm_input, features, hidden, cell)
        # next_word_pred : (1, 1, VOCAB_SIZE)

        next_word_pred = next_word_pred[0, 0, :]
        # next_word_pred : (VOCAB_SIZE)

        next_word_id = torch.argmax(next_word_pred)
        next_word = vocab.index_to_word(int(next_word_id.item()))

        # stop if we predict '<eos>'
        if next_word == vocab.index2word[vocab.EOS]:
            break

        sentence.append(next_word)
        previous_word = next_word

    return sentence


def main_caption(image) -> str:
    """ Generate a caption string for an RGB image given as a numpy array of shape (H, W, 3). """

    config = Config()

    vocab = Vocab()
    vocab.load_vocab(config.VOCAB_FILE)

    image = Image.fromarray(image.astype('uint8'), 'RGB')

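    # Standard ImageNet preprocessing: resize, center-crop to 224x224,
    # convert to a tensor, and normalize with the ImageNet mean/std.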
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    image = transform(image)

    image_encoder = Encoder(image_emb_dim=config.IMAGE_EMB_DIM,
                            device=config.DEVICE)
    emb_layer = torch.nn.Embedding(num_embeddings=config.VOCAB_SIZE,
                                   embedding_dim=config.WORD_EMB_DIM,
                                   padding_idx=vocab.PADDING_INDEX)
    image_decoder = Decoder(image_emb_dim=config.IMAGE_EMB_DIM,
                            word_emb_dim=config.WORD_EMB_DIM,
                            hidden_dim=config.HIDDEN_DIM,
                            num_layers=config.NUM_LAYER,
                            vocab_size=config.VOCAB_SIZE,
                            device=config.DEVICE)

    emb_layer.eval()
    image_encoder.eval()
    image_decoder.eval()

    emb_layer.load_state_dict(torch.load(f=config.EMBEDDING_WEIGHT_FILE, map_location=config.DEVICE))
    image_encoder.load_state_dict(torch.load(f=config.ENCODER_WEIGHT_FILE, map_location=config.DEVICE))
    image_decoder.load_state_dict(torch.load(f=config.DECODER_WEIGHT_FILE, map_location=config.DEVICE))

    emb_layer = emb_layer.to(config.DEVICE)
    image_encoder = image_encoder.to(config.DEVICE)
    image_decoder = image_decoder.to(config.DEVICE)
    image = image.to(config.DEVICE)

    with torch.no_grad():
        sentence = generate_caption(image, image_encoder, emb_layer, image_decoder, vocab, device=config.DEVICE)
    description = ' '.join(sentence)

    return description
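

# A minimal usage sketch, assuming the image is loaded with PIL and passed in
# as a numpy RGB array. "example.jpg" is a placeholder path; Config must point
# at existing vocab and weight files for this to run.
if __name__ == "__main__":
    import numpy as np

    img = np.array(Image.open("example.jpg").convert("RGB"))  # (H, W, 3) uint8
    print(main_caption(img))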