"""Gradio demo: translate a short sign-language video into text.

Pipeline: decode the uploaded video with OpenCV, feed the frames to a
MoViNet model to obtain video embeddings, then let a T5 model generate text
from those embeddings.
"""

import cv2
import numpy as np
import gradio as gr

# import os
# os.chdir('modeling')
import tensorflow as tf, tf_keras
import tensorflow_hub as hub
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

from official.projects.movinet.modeling import movinet
from official.projects.movinet.modeling import movinet_model_a2_modified as movinet_model_modified


# Video encoder, loaded once at start-up and frozen (inference only).
movinet_path = 'movinet_checkpoints_a2_epoch9'
movinet_model = tf_keras.models.load_model(movinet_path)
movinet_model.trainable = False

# Text decoder and its tokenizer, also frozen.
tokenizer = AutoTokenizer.from_pretrained("t5-base")
t5_model = TFAutoModelForSeq2SeqLM.from_pretrained("deanna-emery/ASL_t5_word_epoch15_1204")
t5_model.trainable = False


def crop_center_square(frame):
    """Return the largest centered square region of ``frame``.

    Args:
        frame: Image array whose first two axes are (height, width).

    Returns:
        A slice of ``frame`` whose height and width both equal the shorter
        of the two input sides. Square input is returned unchanged.
    """
    height, width = frame.shape[0:2]
    if width > height:
        # Landscape: trim equal margins from the left and right.
        start_x = (width - height) // 2
        return frame[:, start_x:start_x + height]
    if height > width:
        # Portrait: trim equal margins from the top and bottom so the
        # subsequent resize does not distort the aspect ratio.
        start_y = (height - width) // 2
        return frame[start_y:start_y + width, :]
    return frame


def preprocess(filename, max_frames=0, resize=(224, 224)):
    """Decode a video file into a normalized, batched frame array.

    Each frame is center-cropped to a square, resized to ``resize``, and its
    channel order is reversed (OpenCV decodes to BGR, so this yields RGB).

    Args:
        filename: Path of the video file to read.
        max_frames: Stop after this many frames. ``0`` (the default) means
            no limit, because the check runs after at least one frame has
            been appended.
        resize: Target size passed to ``cv2.resize``.

    Returns:
        Array of shape ``(1, num_frames, H, W, 3)`` with pixel values
        divided by 255.

    Raises:
        ValueError: If no frame could be read from ``filename``.
    """
    video_capture = cv2.VideoCapture(filename)
    frames = []
    try:
        while video_capture.isOpened():
            ret, frame = video_capture.read()
            if not ret:
                break
            frame = crop_center_square(frame)
            frame = cv2.resize(frame, resize)
            frame = frame[:, :, [2, 1, 0]]
            frames.append(frame)

            if len(frames) == max_frames:
                break
    finally:
        # Always release the decoder handle, even if a frame fails to process.
        video_capture.release()

    if not frames:
        # Without this guard an empty array would reach the model and fail
        # there with a far less helpful error.
        raise ValueError(f"Could not read any frames from video: {filename!r}")

    video = np.array(frames) / 255.0
    # Add the leading batch axis.
    video = np.expand_dims(video, axis=0)
    return video


def translate(video_file, true_caption=None):
    """Translate an uploaded video into text.

    Args:
        video_file: Path of the uploaded video, as supplied by ``gr.Video``.
        true_caption: Value of the hidden "Caption" textbox; accepted so the
            signature matches the interface inputs, but not used.

    Returns:
        ``{"translation": [...]}`` where the list holds the decoded strings
        returned by ``tokenizer.batch_decode``.

    Raises:
        gr.Error: If no video was provided or it could not be decoded.
    """
    if not video_file:
        raise gr.Error("Please upload a video to translate.")

    try:
        video = preprocess(video_file, max_frames=0, resize=(224, 224))
    except ValueError as err:
        # Surface decoding problems in the UI instead of a bare stack trace.
        raise gr.Error(str(err)) from err

    embeddings = movinet_model(video)['vid_embedding']
    tokens = t5_model.generate(
        inputs_embeds=embeddings,
        max_new_tokens=128,
        temperature=0.1,
        no_repeat_ngram_size=2,
        do_sample=True,
        top_k=80,
        top_p=0.90,
    )

    translation = tokenizer.batch_decode(tokens, skip_special_tokens=True)

    # NOTE(review): the interface output is the plain "text" component, so
    # this dict is presumably shown via its string form - confirm intended.
    return {"translation": translation}


# Gradio App config
title = ""
description = """"""
examples = []

# Gradio App interface
gr.Interface(
    fn=translate,
    inputs=[
        gr.Video(label='Video', show_label=True, max_length=10, sources='upload'),
        gr.Textbox(label='Caption', show_label=True, interactive=False, visible=False),
    ],
    outputs="text",
    allow_flagging="never",
    title=title,
    description=description,
    examples=examples,
).launch()