# Spaces: Runtime error  (appears to be page-header text captured with the source; not part of the program)
import cv2
import gradio as gr
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
import tf_keras
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

# import os
# os.chdir('modeling')
from official.projects.movinet.modeling import movinet
from official.projects.movinet.modeling import movinet_model_a2_modified as movinet_model_modified
# Both networks are loaded once, at import time, and reused by every request.
# Each one is marked non-trainable immediately after loading.

# Video model, restored from a path relative to the working directory.
# NOTE(review): the `official.projects.movinet` imports are not referenced by
# name in this file; presumably they are needed so the saved model can be
# deserialized -- confirm before removing them.
movinet_path = 'modeling/movinet_checkpoints_a2_epoch9'
movinet_model = tf_keras.models.load_model(movinet_path)
movinet_model.trainable = False

# Text side: the "t5-base" tokenizer paired with the seq2seq weights published
# under the identifier below.
tokenizer = AutoTokenizer.from_pretrained("t5-base")
t5_model = TFAutoModelForSeq2SeqLM.from_pretrained(
    "deanna-emery/ASL_t5_word_epoch15_1204"
)
t5_model.trainable = False
def crop_center_square(frame):
    """Return the largest centered square region of an image.

    The longer of the two leading axes is trimmed equally on both sides so
    that height and width both equal the shorter side. Landscape frames are
    cropped horizontally and portrait frames vertically.

    Args:
        frame: Image indexed as ``[row, column, ...]`` that exposes ``shape``
            and supports two-axis slicing (e.g. a NumPy array).

    Returns:
        A slice of ``frame`` with equal height and width. A frame that is
        already square is returned as-is.
    """
    height, width = frame.shape[0:2]
    if height == width:
        return frame

    side = min(height, width)
    # Floor division keeps the offsets integral; when the surplus is odd the
    # extra pixel is dropped from the far (bottom/right) edge.
    top = (height - side) // 2
    left = (width - side) // 2
    return frame[top:top + side, left:left + side]
def preprocess(filename, max_frames=0, resize=(224,224)):
    """Decode a video file into a batched, scaled RGB frame array.

    Every decoded frame is passed through ``crop_center_square``, resized to
    ``resize``, reordered from OpenCV's BGR channel order to RGB, and the
    stacked result is divided by 255.

    Args:
        filename: Path (or other source) handed to ``cv2.VideoCapture``.
        max_frames: Stop after this many frames. With the default of 0 the
            limit never triggers and the whole stream is read.
        resize: Target size passed to ``cv2.resize`` as ``(width, height)``.

    Returns:
        Array with a leading batch axis of length 1, i.e. shape
        ``(1, num_frames, resize[1], resize[0], 3)``.

    Raises:
        ValueError: If no frame could be decoded (missing file, unsupported
            codec, empty stream). Without this check the function would
            return a ``(1, 0)`` array that only fails later, inside the
            model, with an unrelated shape error.
    """
    video_capture = cv2.VideoCapture(filename)
    frames = []
    try:
        while video_capture.isOpened():
            ret, frame = video_capture.read()
            if not ret:
                break
            frame = crop_center_square(frame)
            frame = cv2.resize(frame, resize)
            # OpenCV decodes to BGR; swap the channel axis to get RGB.
            frame = frame[:, :, [2, 1, 0]]
            frames.append(frame)

            if len(frames) == max_frames:
                break
    finally:
        # Always release the capture handle, even if decoding raised.
        video_capture.release()

    if not frames:
        raise ValueError(f"No frames could be read from video: {filename!r}")

    video = np.array(frames) / 255.0
    # Add the batch axis in front of the frame axis.
    video = np.expand_dims(video, axis=0)
    return video
def translate(video_file):
    """Translate a video into text.

    The clip is preprocessed, passed through ``movinet_model`` to obtain its
    ``'vid_embedding'`` output, and that output is given to
    ``t5_model.generate`` as ``inputs_embeds``. The generated token ids are
    decoded with ``tokenizer``.

    Args:
        video_file: Video path forwarded unchanged to ``preprocess``.

    Returns:
        ``{"translation": decoded}`` where ``decoded`` is whatever
        ``tokenizer.batch_decode`` returns for the generated tokens.
    """
    clip = preprocess(video_file, max_frames=0, resize=(224,224))
    vid_embedding = movinet_model(clip)['vid_embedding']

    # Sampling configuration for generation, collected in one place.
    generation_kwargs = dict(
        max_new_tokens=128,
        temperature=0.1,
        no_repeat_ngram_size=2,
        do_sample=True,
        top_k=80,
        top_p=0.90,
    )
    token_ids = t5_model.generate(inputs_embeds=vid_embedding, **generation_kwargs)

    decoded = tokenizer.batch_decode(token_ids, skip_special_tokens=True)
    # NOTE(review): `decoded` is a list of strings, so the dict value is a
    # list. Gradio's Label documents a {label: confidence} dict or a plain
    # string -- confirm this renders the translation as intended.
    return {"translation": decoded}
# Gradio App config
title = "ASL Translation (MoViNet + T5)"
examples = [
    ['videos/no.mp4'],
    ['videos/all.mp4'],
    ['videos/before.mp4'],
    ['videos/blue.mp4'],
    ['videos/white.mp4'],
    ['videos/accident2.mp4'],
]

# Gradio App interface
# Components come from the top-level `gr` namespace: the `gr.inputs` /
# `gr.outputs` namespaces were deprecated in Gradio 3 and are absent from
# Gradio 4+, where referencing them fails at startup with AttributeError.
demo = gr.Interface(
    fn=translate,
    inputs=[gr.Video(label="Video (*.mp4)")],
    outputs=[gr.Label(label='Translation')],
    allow_flagging="never",
    title=title,
    examples=examples,
)

# Launched at module level, as before, so running the script starts the app.
demo.launch()