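"""Gradio demo: English translation of American Sign Language (ASL) videos.

A fine-tuned MoViNet encoder turns the input video into a sequence of
embeddings, which a fine-tuned T5 model decodes into an English sentence.
"""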
import cv2
import numpy as np
import gradio as gr
import tensorflow as tf
import tf_keras
import tensorflow_hub as hub
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

# MoViNet modeling modules (presumably imported so the custom layers in the
# saved checkpoint below can be deserialized by tf_keras).
from official.projects.movinet.modeling import movinet
from official.projects.movinet.modeling import movinet_model_a2_modified as movinet_model_modified
# Load the fine-tuned MoViNet encoder from a local checkpoint directory.
movinet_path = 'movinet_checkpoints_a2_epoch9'
movinet_model = tf_keras.models.load_model(movinet_path)
movinet_model.trainable = False

# Load the T5 tokenizer and the fine-tuned T5 decoder.
tokenizer = AutoTokenizer.from_pretrained("t5-base")
t5_model = TFAutoModelForSeq2SeqLM.from_pretrained("deanna-emery/ASL_t5_movinet_sentence")
t5_model.trainable = False
def crop_center_square(frame):
    """Crop a landscape frame to a centered square; portrait and square frames pass through."""
    y, x = frame.shape[0:2]
    if x > y:
        start_x = (x - y) // 2
        return frame[:, start_x:start_x + y]
    return frame
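
# For example, a 640x480 frame is cropped to its central 480x480 square,
# while a 480x640 (portrait) frame is returned unchanged.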
def preprocess(filename, max_frames=0, resize=(224, 224)):
    """Read a video into a normalized array of shape (1, num_frames, H, W, 3).

    With max_frames=0 (the default), every frame is read.
    """
    video_capture = cv2.VideoCapture(filename)
    frames = []
    try:
        while video_capture.isOpened():
            ret, frame = video_capture.read()
            if not ret:
                break
            frame = crop_center_square(frame)
            frame = cv2.resize(frame, resize)
            frame = frame[:, :, [2, 1, 0]]  # BGR (OpenCV) -> RGB
            frames.append(frame)
            if len(frames) == max_frames:
                break
    finally:
        video_capture.release()

    video = np.array(frames) / 255.0       # scale pixel values to [0, 1]
    video = np.expand_dims(video, axis=0)  # add a batch dimension
    return video
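
# Shape check with one of the bundled example clips:
#   preprocess("videos/no.mp4").shape  # -> (1, num_frames, 224, 224, 3)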
def translate(video_file):
    """Run the full pipeline: preprocess the video, encode it, and decode a sentence."""
    video = preprocess(video_file, max_frames=0, resize=(224, 224))

    # Encode the video; the model returns a dict containing the clip embeddings.
    embeddings = movinet_model(video)['vid_embedding']

    # Decode with sampling; the video embeddings are fed to T5 in place of
    # token embeddings.
    tokens = t5_model.generate(inputs_embeds=embeddings,
                               max_new_tokens=128,
                               temperature=0.1,
                               no_repeat_ngram_size=2,
                               do_sample=True,
                               top_k=80,
                               top_p=0.90)
    translation = tokenizer.batch_decode(tokens, skip_special_tokens=True)
    return translation[0]  # one video per call, so return the single decoded string
# Gradio App config
title = "ASL Translation (MoViNet + T5)"
examples = [
    ["videos/My second ASL professor's name was Will White.mp4"],
    ['videos/You are my sunshine.mp4'],
    ['videos/scrub your hands for at least 20 seconds.mp4'],
    ['videos/no.mp4'],
    ['videos/all.mp4'],
    ['videos/white.mp4'],
]
# Alternative examples list pairing each clip with its reference translation:
# examples = [
#     ["videos/My second ASL professor's name was Will White.mp4", "My second ASL professor's name was Will White"],
#     ['videos/You are my sunshine.mp4', 'You are my sunshine'],
#     ['videos/scrub your hands for at least 20 seconds.mp4', 'scrub your hands for at least 20 seconds'],
#     ['videos/no.mp4', 'no'],
#     ['videos/all.mp4', 'all'],
#     ['videos/white.mp4', 'white'],
# ]
description = "Gradio demo of word-level sign language classification using I3D model pretrained on the WLASL video dataset. " \ | |
"WLASL is a large-scale dataset containing more than 2000 words in American Sign Language. " \ | |
"Examples used in the demo are videos from the the test subset. " \ | |
"Note that WLASL100 contains 100 words while WLASL2000 contains 2000." | |
article = "More information about the trained models can be found <a href=https://github.com/deanna-emery/ASL-Translator/>here</a>." | |
# Gradio App interface
gr.Interface(fn=translate,
             inputs="video",
             outputs="text",
             allow_flagging="never",
             title=title,
             description=description,
             examples=examples,
             article=article).launch()