import cv2
import numpy as np
import gradio as gr
import tensorflow as tf
import tf_keras
import tensorflow_hub as hub
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM
from official.projects.movinet.modeling import movinet
from official.projects.movinet.modeling import movinet_model_a2_modified as movinet_model_modified

# NOTE(review): tf, hub, movinet and movinet_model_modified are not referenced
# below. They are kept in case importing them has side effects that
# load_model relies on (e.g. registering custom layers) -- confirm before removing.

# Fine-tuned MoViNet-A2 video encoder, loaded from a local checkpoint directory
# (relative path, so the app must be started from the repository root).
movinet_path = 'modeling/movinet_checkpoints_a2_epoch9'
movinet_model = tf_keras.models.load_model(movinet_path)
movinet_model.trainable = False

# T5 decoder that turns the video embedding into English text.
tokenizer = AutoTokenizer.from_pretrained("t5-base")
t5_model = TFAutoModelForSeq2SeqLM.from_pretrained("deanna-emery/ASL_t5_word_epoch15_1204")
t5_model.trainable = False


def crop_center_square(frame):
    """Return the largest centered square crop of an image array.

    Handles landscape (width > height) and portrait (height > width)
    frames alike; square frames are returned unchanged.

    Args:
        frame: Image array indexed as (rows, cols, ...).

    Returns:
        A view of ``frame`` cropped to ``min(rows, cols)`` on both axes.
    """
    y, x = frame.shape[0:2]
    side = min(x, y)
    # Integer floor division matches the original int((x - y) / 2) offset
    # for landscape frames, and now also centers portrait frames vertically.
    start_x = (x - side) // 2
    start_y = (y - side) // 2
    return frame[start_y:start_y + side, start_x:start_x + side]


def preprocess(filename, max_frames=0, resize=(224, 224)):
    """Decode a video file into a normalized RGB frame batch for MoViNet.

    Args:
        filename: Path to a video file readable by OpenCV.
        max_frames: Stop after this many frames; 0 reads the whole video.
        resize: Target size passed to ``cv2.resize`` as (width, height).

    Returns:
        A float32 array of shape (1, num_frames, height, width, 3) with
        values in [0, 1] and channels in RGB order.

    Raises:
        ValueError: If no frame could be decoded from ``filename``.
    """
    video_capture = cv2.VideoCapture(filename)
    frames = []
    try:
        while video_capture.isOpened():
            ret, frame = video_capture.read()
            if not ret:
                break
            frame = crop_center_square(frame)
            frame = cv2.resize(frame, resize)
            frame = frame[:, :, [2, 1, 0]]  # OpenCV decodes BGR; reorder to RGB
            frames.append(frame)
            if len(frames) == max_frames:
                break
    finally:
        video_capture.release()

    if not frames:
        # Without this guard an empty video becomes a (1, 0) array and fails
        # deep inside the model with an unhelpful shape error.
        raise ValueError(f"No frames could be read from video: {filename!r}")

    # float32 is the dtype TF models are normally traced with, and it uses
    # half the memory of the float64 that a plain `/ 255.0` would produce.
    video = np.asarray(frames, dtype=np.float32) / 255.0
    return np.expand_dims(video, axis=0)


def translate(video_file):
    """Translate an ASL video into English text.

    Args:
        video_file: Path to the uploaded video (as passed by the Gradio
            video component).

    Returns:
        The decoded translation string, or "" if nothing was decoded.

    Raises:
        ValueError: Propagated from ``preprocess`` for unreadable videos.
    """
    video = preprocess(video_file, max_frames=0, resize=(224, 224))
    embeddings = movinet_model(video)['vid_embedding']
    tokens = t5_model.generate(
        inputs_embeds=embeddings,
        max_new_tokens=128,
        temperature=0.1,
        no_repeat_ngram_size=2,
        do_sample=True,
        top_k=80,
        top_p=0.90,
    )
    translation = tokenizer.batch_decode(tokens, skip_special_tokens=True)
    # gr.Label interprets a dict as {class_name: confidence} and would show the
    # key ("translation") as the top label; a plain string is shown verbatim.
    # The batch holds a single video, so take the first decoded sequence.
    return translation[0] if translation else ""


# Gradio App config
title = "ASL Translation (MoViNet + T5)"

examples = [
    ['videos/no.mp4'],
    ['videos/all.mp4'],
    ['videos/before.mp4'],
    ['videos/blue.mp4'],
    ['videos/white.mp4'],
    ['videos/accident2.mp4'],
]
# Gradio App interface
# gr.Video / gr.Label replace the gr.inputs / gr.outputs namespaces, which are
# deprecated in Gradio 3.x and removed in 4.x; the components are equivalent.
gr.Interface(
    fn=translate,
    inputs=[gr.Video(label="Video (*.mp4)")],
    outputs=[gr.Label(label='Translation')],
    allow_flagging="never",
    title=title,
    examples=examples,
).launch()