File size: 4,333 Bytes
f6a67e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
from __future__ import print_function
from __future__ import absolute_import
__author__ = 'Tony Beltramelli - www.tonybeltramelli.com'

from .Vocabulary import *
# from .BeamSearch import *
from .Utils import *


class Sampler:
    """Generate a token sequence from a model, one prediction step at a time.

    The sampler keeps a sliding context of the last ``context_length``
    entries, feeds it to the model together with the input image, and maps
    every predicted index back to a token through the vocabulary.
    """

    def __init__(self, voc_path, input_shape, output_size, context_length):
        """Set up the vocabulary and remember the model dimensions.

        Args:
            voc_path: Passed to ``Vocabulary.retrieve`` (presumably the
                location of the stored vocabulary -- confirm in Vocabulary).
            input_shape: Stored as-is and printed; not otherwise used here.
            output_size: Length of the one-hot vectors built while sampling.
            context_length: Number of entries kept in the sliding context.
        """
        vocabulary = Vocabulary()
        vocabulary.retrieve(voc_path)
        self.voc = vocabulary

        self.input_shape, self.output_size = input_shape, output_size

        print("Vocabulary size: {}".format(self.voc.size))
        print("Input shape: {}".format(self.input_shape))
        print("Output size: {}".format(self.output_size))

        self.context_length = context_length

    def predict_greedy(self, model, input_img, require_sparse_label=True, sequence_length=150, verbose=False):
        """Decode a sequence by always taking the highest-scoring index.

        Args:
            model: Object exposing ``predict(input_img, context_batch)``; its
                return value is handed to ``np.argmax``.
            input_img: Forwarded unchanged to ``model.predict`` at every step.
            require_sparse_label: When true, the context is converted with
                ``Utils.sparsify`` and each prediction is appended to it as a
                one-hot vector of length ``output_size``; otherwise the raw
                predicted index is appended.
            sequence_length: Maximum number of prediction steps.
            verbose: When true, print a progress line before each step.

        Returns:
            A ``(predictions, out_probas)`` tuple. ``predictions`` is
            ``START_TOKEN`` followed by every decoded token concatenated in
            order; decoding stops early once ``END_TOKEN`` has been appended.
            ``out_probas`` is the list of raw model outputs, one per step.
        """
        token_ids = self.voc.vocabulary

        # Seed window: padding in every slot except the last, which holds
        # the start token.
        window = [token_ids[PLACEHOLDER]] * (self.context_length - 1) + [token_ids[START_TOKEN]]
        if require_sparse_label:
            window = Utils.sparsify(window, self.output_size)

        predictions = START_TOKEN
        out_probas = []

        for step in range(sequence_length):
            if verbose:
                print("predicting {}/{}...".format(step, sequence_length))

            scores = model.predict(input_img, np.array([window]))
            # NOTE(review): argmax runs over the flattened output, so this
            # assumes one prediction per call -- confirm against the model.
            best = np.argmax(scores)
            out_probas.append(scores)

            # Entry that will occupy the last slot of the next window.
            if require_sparse_label:
                newest = np.zeros(self.output_size)
                newest[best] = 1
            else:
                newest = best

            # Slide the window: drop the oldest entry, append the newest.
            window = [window[k] for k in range(1, self.context_length)]
            window.append(newest)

            token = self.voc.token_lookup[best]
            predictions = predictions + token

            if token == END_TOKEN:
                break

        return predictions, out_probas

    # def recursive_beam_search(self, model, input_img, current_context, beam, current_node, sequence_length):
    #     probas = model.predict(input_img, np.array([current_context]))

    #     predictions = []
    #     for i in range(0, len(probas)):
    #         predictions.append((i, probas[i], probas))

    #     nodes = []
    #     for i in range(0, len(predictions)):
    #         prediction = predictions[i][0]
    #         score = predictions[i][1]
    #         output_probas = predictions[i][2]
    #         nodes.append(Node(prediction, score, output_probas))

    #     beam.add_nodes(current_node, nodes)

    #     if beam.is_valid():
    #         beam.prune_leaves()
    #         if sequence_length == 1 or self.voc.token_lookup[beam.root.max_child().key] == END_TOKEN:
    #             return

    #         for node in beam.get_leaves():
    #             prediction = node.key

    #             new_context = []
    #             for j in range(1, self.context_length):
    #                 new_context.append(current_context[j])
    #             sparse_label = np.zeros(self.output_size)
    #             sparse_label[prediction] = 1
    #             new_context.append(sparse_label)

    #             self.recursive_beam_search(model, input_img, new_context, beam, node, sequence_length - 1)

    # def predict_beam_search(self, model, input_img, beam_width=3, require_sparse_label=True, sequence_length=150):
    #     predictions = START_TOKEN
    #     out_probas = []

    #     current_context = [self.voc.vocabulary[PLACEHOLDER]] * (self.context_length - 1)
    #     current_context.append(self.voc.vocabulary[START_TOKEN])
    #     if require_sparse_label:
    #         current_context = Utils.sparsify(current_context, self.output_size)

    #     beam = BeamSearch(beam_width=beam_width)

    #     self.recursive_beam_search(model, input_img, current_context, beam, beam.root, sequence_length)

    #     predicted_sequence, probas_sequence = beam.search()

    #     for k in range(0, len(predicted_sequence)):
    #         prediction = predicted_sequence[k]
    #         probas = probas_sequence[k]
    #         out_probas.append(probas)

    #         predictions += self.voc.token_lookup[prediction]

    #     return predictions, out_probas