# Spaces: Runtime error
import re | |
import string | |
import gradio as gr | |
import tensorflow as tf | |
from load_model import build | |
def custom_standardization(s):
    """Normalize caption text and wrap it in start/end markers.

    The text is lower-cased, every character in ``string.punctuation`` is
    removed, and the result is joined with ``[START]`` in front and ``[END]``
    behind, separated by single spaces.  The markers are added after the
    punctuation is stripped, so their brackets are preserved.

    Args:
        s: String tensor (or any value ``tf.strings.lower`` accepts).

    Returns:
        The standardized string tensor.
    """
    punctuation_class = f'[{re.escape(string.punctuation)}]'
    lowered = tf.strings.lower(s)
    stripped = tf.strings.regex_replace(lowered, punctuation_class, '')
    return tf.strings.join(['[START]', stripped, '[END]'], separator=' ')
# Build the captioning model once at import time; every request reuses it.
# NOTE(review): `build` comes from the project-local `load_model` module --
# what it loads (weights, vocabulary) is not visible from this file.
model = build()
def single_img_transcribe(image, temperature=1, max_length=50):
    """Generate one caption for ``image`` by decoding one token at a time.

    Args:
        image: Array-like image supporting NumPy-style indexing; a leading
            batch axis is added before it is handed to
            ``model.feature_extractor``.  Presumably (height, width,
            channels) -- TODO confirm the size/dtype the extractor expects.
        temperature: ``0`` picks the arg-max token at every step (greedy
            decoding).  Any other value divides the predictions before they
            are sampled with ``tf.random.categorical``.
        max_length: Upper bound on the number of generated tokens.

    Returns:
        The caption as a ``str`` with words separated by single spaces and
        without the ``[START]``/``[END]`` markers.
    """
    start_tokens = model.word_to_index([['[START]']])  # (batch, sequence)
    # The id of '[END]' is loop-invariant, so look it up only once.
    end_token = model.word_to_index('[END]')
    img_features = model.feature_extractor(image[tf.newaxis, ...])

    tokens = start_tokens  # (batch, sequence)
    reached_end = False
    for _ in range(max_length):
        preds = model((img_features, tokens)).numpy()  # (batch, sequence, vocab)
        preds = preds[:, -1, :]  # (batch, vocab): only the newest position
        if temperature == 0:
            next_token = tf.argmax(preds, axis=-1)[:, tf.newaxis]  # (batch, 1)
        else:
            next_token = tf.random.categorical(
                preds / temperature, num_samples=1)  # (batch, 1)
        tokens = tf.concat([tokens, next_token], axis=1)  # (batch, sequence)

        if next_token[0] == end_token:
            reached_end = True
            break

    # Always drop the leading '[START]'.  Drop the trailing token only when it
    # really is '[END]'; if the length budget ran out first, the last token is
    # a genuine word and must be kept.
    generated = tokens[0, 1:-1] if reached_end else tokens[0, 1:]
    words = model.index_to_word(generated)
    result = tf.strings.reduce_join(words, axis=-1, separator=' ')
    return result.numpy().decode()
def img_transcribes(image):
    """Caption ``image`` three times, at temperatures 0, 0.5 and 1.

    Args:
        image: Image forwarded unchanged to ``single_img_transcribe``.

    Returns:
        A list with one caption per temperature, in the order 0, 0.5, 1.
    """
    temperatures = (0, 0.5, 1)
    return [single_img_transcribe(image, temp) for temp in temperatures]
# The caption code indexes the image like an array (`image[tf.newaxis, ...]`),
# which a PIL image does not support, so ask Gradio for a NumPy array instead.
# One text output per caption returned by `img_transcribes`.
gr.Interface(
    fn=img_transcribes,
    inputs=gr.Image(type="numpy"),
    outputs=["text", "text", "text"],
).launch()