Spaces:
Runtime error
Runtime error
import cv2 | |
import numpy as np | |
import gradio as gr | |
# import os | |
# os.chdir('modeling') | |
import tensorflow as tf, tf_keras | |
import tensorflow_hub as hub | |
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM | |
from official.projects.movinet.modeling import movinet | |
from official.projects.movinet.modeling import movinet_model_a2_modified as movinet_model_modified | |
# Fine-tuned MoViNet video encoder (A2 variant, judging by the checkpoint
# name and the `movinet_model_a2_modified` import), loaded from a local path.
movinet_path = 'movinet_checkpoints_a2_epoch9'
movinet_model = tf_keras.models.load_model(movinet_path)

# T5 seq2seq model that decodes MoViNet video embeddings into English text;
# it is paired with the stock "t5-base" tokenizer.
tokenizer = AutoTokenizer.from_pretrained("t5-base")
t5_model = TFAutoModelForSeq2SeqLM.from_pretrained("deanna-emery/ASL_t5_movinet_sentence")

# The app only runs inference, so freeze both networks.
for _frozen_model in (movinet_model, t5_model):
    _frozen_model.trainable = False
del _frozen_model
def crop_center_square(frame):
    """Return the largest centered square region of ``frame``.

    Handles landscape (width > height), portrait (height > width) and
    already-square frames alike; any trailing axes (e.g. color channels)
    are kept. Cropping both orientations matters because the caller resizes
    to a square target, which would otherwise distort portrait frames.

    Args:
        frame: Array whose first two axes are (height, width).

    Returns:
        A basic-slice view of ``frame`` with shape ``(side, side, ...)``,
        where ``side = min(height, width)``.
    """
    height, width = frame.shape[:2]
    side = min(height, width)
    # Integer division matches the previous int((x - y) / 2) truncation
    # for the landscape case.
    top = (height - side) // 2
    left = (width - side) // 2
    return frame[top:top + side, left:left + side]
def preprocess(filename, max_frames=0, resize=(224, 224)):
    """Decode a video file into a normalized, batched array of RGB frames.

    Each frame is center-cropped with ``crop_center_square``, resized with
    ``cv2.resize``, reordered from OpenCV's BGR channel order to RGB, and
    scaled to the ``[0, 1]`` range.

    Args:
        filename: Path of the video (anything ``cv2.VideoCapture`` accepts).
        max_frames: Stop after this many frames. ``0`` (the default) or a
            negative value reads the whole video.
        resize: Target size passed to ``cv2.resize``. Note that OpenCV
            expects ``(width, height)``.

    Returns:
        A float32 array of shape ``(1, num_frames, height, width, 3)``.

    Raises:
        ValueError: If no frames could be decoded from ``filename``.
    """
    video_capture = cv2.VideoCapture(filename)
    frames = []
    try:
        while video_capture.isOpened():
            ret, frame = video_capture.read()
            if not ret:
                break
            frame = crop_center_square(frame)
            frame = cv2.resize(frame, resize)
            # OpenCV decodes frames as BGR; reorder the channels to RGB.
            frame = frame[:, :, [2, 1, 0]]
            frames.append(frame)
            if 0 < max_frames <= len(frames):
                break
    finally:
        video_capture.release()
    if not frames:
        # Without this check an unreadable or empty file yields a (1, 0)
        # array that only fails much later, inside the model.
        raise ValueError(f"Could not decode any frames from {filename!r}")
    # float32 matches the model's compute dtype and halves memory compared
    # with the float64 that np.array(frames) / 255.0 would produce.
    video = np.asarray(frames, dtype=np.float32) / 255.0
    return np.expand_dims(video, axis=0)  # add the batch axis
def translate(video_file, true_caption=None):
    """Translate an ASL video into English text.

    Pipeline: ``preprocess`` the video, embed it with ``movinet_model``
    (using its ``'vid_embedding'`` output), then let ``t5_model`` generate
    tokens from those embeddings and decode them with ``tokenizer``.

    Args:
        video_file: Path to the uploaded video, as supplied by ``gr.Video``.
        true_caption: Value of the hidden "Caption" textbox. It appears to
            exist only so the examples table can carry reference captions;
            it is not used for inference.

    Returns:
        ``{"translation": list_of_str}``, the output of
        ``tokenizer.batch_decode`` (one string per batch element).

    Raises:
        gr.Error: If no video was supplied.
    """
    if video_file is None:
        # e.g. the user pressed Submit without uploading anything.
        raise gr.Error("Please upload a video to translate.")
    video = preprocess(video_file, max_frames=0, resize=(224, 224))
    embeddings = movinet_model(video)['vid_embedding']
    # Sampling at a low temperature stays close to greedy decoding, but the
    # output is still stochastic from run to run.
    tokens = t5_model.generate(
        inputs_embeds=embeddings,
        max_new_tokens=128,
        temperature=0.1,
        no_repeat_ngram_size=2,
        do_sample=True,
        top_k=80,
        top_p=0.90,
    )
    translation = tokenizer.batch_decode(tokens, skip_special_tokens=True)
    return {"translation": translation}
# Gradio App config
title = "American Sign Language Translation: An Approach Combining MoViNets and T5"
description = """
This application hosts a model for translation of American Sign Language (ASL).
The model comprises of a fine-tuned MoViNet CNN model to generate video embeddings and a T5 encoder-decoder model
to generate translations from the video embeddings. This model architecture achieves a BLEU score of 1.98
and an average cosine similarity score of 0.21 when trained and evaluated on the YouTube-ASL dataset.
More information about the model training and instructions to download the models
can be found in our <a href=https://github.com/deanna-emery/ASL-Translator>GitHub repository</a>.
You can also find an overview of the project approach
<a href=https://www.ischool.berkeley.edu/projects/2023/signsense-american-sign-language-translation>here</a>.
A limitation of this architecture is the size of the MoViNets model, making it especially slow during inference on a CPU.
We do not recommend uploading videos longer than 4 seconds as the video embedding generation may take some time.
The application does not accept videos that are longer than 10 seconds.
We have provided some pre-cached videos with their original captions and translations as examples.
"""
# Each example row is [video path, reference caption], matching the two
# Interface inputs (video, hidden caption textbox).
examples = [
    ["videos/My_second_ASL_professors_name_was_Will_White.mp4", "My second ASL professor's name was Will White"],
    ['videos/You_are_my_sunshine.mp4', 'You are my sunshine'],
    ['videos/scrub_your_hands_for_at_least_20_seconds.mp4', 'scrub your hands for at least 20 seconds'],
    ['videos/no.mp4', 'no'],
    ["videos/i_feel_rejuvenated_by_this_beautiful_weather.mp4", "I feel rejuvenated by this beautiful weather"],
    ["videos/north_dakota_they_dont_need.mp4", "... north dakota they don't need ..."],
]


def _translate_to_text(video_file, true_caption=None):
    """Gradio adapter around ``translate`` that returns plain text.

    ``translate`` returns ``{"translation": [str, ...]}``. Passing that dict
    to a ``"text"`` output would show its stringified repr, so join the
    decoded sentence(s) into a single string instead.
    """
    return " ".join(translate(video_file, true_caption)["translation"])


# Gradio App interface
gr.Interface(
    fn=_translate_to_text,
    inputs=[
        gr.Video(label='Video', show_label=True, max_length=10, sources='upload'),
        gr.Textbox(label='Caption', show_label=True, interactive=False, visible=False),
    ],
    outputs="text",
    allow_flagging="never",
    title=title,
    description=description,
    examples=examples,
).launch()