Update app.py
app.py (CHANGED)
@@ -599,76 +599,76 @@ def load_and_preprocess_image(img):
     return np.expand_dims(img, axis=0)


-def generate_caption_coca(image):
-    # (remaining removed lines of the old implementation are not shown in this view)
+# def generate_caption_coca(image):
+#     img_processed = load_and_preprocess_image(image)

+#     _, cap_features = coca_model.encoder.predict(img_processed, verbose=0)
+#     cap_features = cap_features.astype(np.float32)

+#     start_token_id = word_index[start_token]
+#     end_token_id = word_index[end_token]
+#     sequence = [start_token_id]
+#     text_input = np.zeros((1, sentence_length - 1), dtype=np.float32)

+#     for t in range(sentence_length - 1):
+#         text_input[0, :len(sequence)] = sequence

+#         _, logits = coca_model.decoder.predict(
+#             [text_input, cap_features],
+#             verbose=0
+#         )
+#         next_token = np.argmax(logits[0, t, :])

+#         sequence.append(next_token)
+#         if next_token == end_token_id or len(sequence) >= (sentence_length - 1):
+#             break

+#     caption = " ".join(
+#         [index_word[token] for token in sequence
+#          if token not in {word_index[start_token], word_index[end_token]}]
+#     )

+#     return caption


+def generate_caption_coca(image):
+    img_processed = load_and_preprocess_image(image)
+    _, cap_features = coca_model.encoder.predict(img_processed, verbose=0)

+    beams = [([word_index[start_token]], 0.0)]

+    for _ in range(max_length):
+        new_beams = []
+        for seq, log_prob in beams:
+            if seq[-1] == word_index[end_token]:
+                new_beams.append((seq, log_prob))
+                continue

+            text_input = np.zeros((1, max_length), dtype=np.int32)
+            text_input[0, :len(seq)] = seq

+            predictions = coca_model.decoder.predict([text_input, cap_features], verbose=0)
+            _, logits = predictions
+            logits = logits[0, len(seq) - 1, :]
+            probs = np.exp(logits - np.max(logits))
+            probs /= probs.sum()

+            top_k = np.argsort(-probs)[:beam_width]
+            for token in top_k:
+                new_seq = seq + [token]
+                new_log_prob = (log_prob * len(seq) + np.log(probs[token])) / (len(seq) + 1)

+                if has_repeated_ngrams(new_seq, n=2):
+                    new_log_prob -= 0.5

+                new_beams.append((new_seq, new_log_prob))

+        beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam_width]
+        if all(beam[0][-1] == word_index[end_token] for beam in beams):
+            break

+    best_seq = max(beams, key=lambda x: x[1])[0]
+    return " ".join(index_word[i] for i in best_seq if i not in {word_index[start_token], word_index[end_token]})


 def generate_caption_rnn(image):
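Note that the new beam-search path relies on has_repeated_ngrams, beam_width, and max_length, which are defined elsewhere in app.py and do not appear in this hunk. As a minimal sketch of what the n-gram check might look like (hypothetical, assuming it simply flags any n-gram that occurs more than once in the candidate sequence):

def has_repeated_ngrams(seq, n=2):
    # Hypothetical helper, not part of this diff: collect all n-grams of the
    # token sequence and report whether any of them occurs more than once.
    ngrams = [tuple(seq[i:i + n]) for i in range(len(seq) - n + 1)]
    return len(ngrams) != len(set(ngrams))

Under that assumption, the decoder subtracts a flat 0.5 from the length-normalised log-probability of any candidate that starts repeating bigrams, which discourages looping captions.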