File size: 2,782 Bytes
79a2238
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import cv2
import numpy as np
import gradio as gr

import os
os.chdir('models')

import tensorflow as tf, tf_keras
import tensorflow_hub as hub
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

from official.projects.movinet.modeling import movinet
from official.projects.movinet.modeling import movinet_model_a2_modified as movinet_model_modified


# Video encoder: a saved MoViNet checkpoint directory (the name suggests the
# A2 variant, epoch 9 — confirm). The path is relative, so it resolves inside
# the 'models' directory the script chdir'd into above.
movinet_path = 'movinet_checkpoints_a2_epoch9'
movinet_model = tf_keras.models.load_model(movinet_path)

# Text decoder: the stock "t5-base" tokenizer paired with seq2seq weights
# pulled from the "deanna-emery/t5_word_epoch12_1203" Hub repository.
tokenizer = AutoTokenizer.from_pretrained("t5-base")
t5_model = TFAutoModelForSeq2SeqLM.from_pretrained("deanna-emery/t5_word_epoch12_1203")

# The app only runs inference, so both networks are marked non-trainable.
movinet_model.trainable = False
t5_model.trainable = False

def crop_center_square(frame):
    """Trim a landscape frame to its centered square; pass others through.

    Only the width is ever cropped: when the frame is wider than it is tall,
    the central band of columns whose width equals the frame height is kept.
    Portrait and already-square frames are returned unchanged (portrait frames
    are therefore NOT made square).

    Args:
        frame: Image array indexed as ``frame[row, col, ...]``.

    Returns:
        A column slice of ``frame`` for landscape input, otherwise ``frame``
        itself.
    """
    height, width = frame.shape[:2]
    if width <= height:
        return frame
    # For positive sizes this integer offset equals int((width - height) / 2),
    # and offset + height equals int((width - height) / 2 + height).
    left = (width - height) // 2
    return frame[:, left:left + height]
    

def preprocess(filename, max_frames=0, resize=(224,224)):
    """Decode a video file into a batched, [0, 1]-scaled RGB frame array.

    Each decoded frame is passed through ``crop_center_square``, resized with
    ``cv2.resize``, reordered from OpenCV's BGR channel order to RGB, and
    divided by 255.

    Args:
        filename: Path to a video readable by ``cv2.VideoCapture``.
        max_frames: Stop after this many frames. ``0`` (or any value below 1)
            reads the whole video.
        resize: Target size passed to ``cv2.resize`` as ``(width, height)``.

    Returns:
        A float32 array of shape ``(1, num_frames, resize[1], resize[0], C)``
        (C is 3 for the default BGR decode).

    Raises:
        ValueError: If the video cannot be opened or yields no frames.
    """
    video_capture = cv2.VideoCapture(filename)
    frames = []
    try:
        while video_capture.isOpened():
            ok, frame = video_capture.read()
            if not ok:
                break  # end of stream (or a frame that failed to decode)
            frame = crop_center_square(frame)
            frame = cv2.resize(frame, resize)
            frame = frame[:, :, [2, 1, 0]]  # BGR (OpenCV) -> RGB
            frames.append(frame)

            if len(frames) == max_frames:
                break
    finally:
        video_capture.release()

    if not frames:
        # Without this check an empty list becomes a (1, 0) array that fails
        # deep inside the model with an opaque shape error.
        raise ValueError(f"Could not read any frames from video: {filename!r}")

    # Build as float32: dividing a uint8 array by 255.0 would yield float64,
    # doubling memory for long clips. float32 is presumably the model's input
    # dtype — confirm against the MoViNet checkpoint's input spec.
    video = np.asarray(frames, dtype=np.float32) / 255.0
    video = np.expand_dims(video, axis=0)
    return video

def translate(video_file):
    """Translate an ASL video into text.

    The video is run through ``preprocess``, then encoded by
    ``movinet_model``, whose ``'vid_embedding'`` output is fed to
    ``t5_model.generate`` as ``inputs_embeds``. The generated tokens are
    decoded with ``tokenizer``.

    Args:
        video_file: Path to the video supplied by the Gradio video input.
            This is presumably ``None`` when the user submits without a
            video; verify for the pinned Gradio version.

    Returns:
        The decoded translation as a plain string. Gradio's ``Label`` output
        renders a string as its label directly. A dict is instead read as
        ``{label: confidence}``, so returning ``{"translation": [...]}``
        displayed the word "translation" rather than the translated text.

    Raises:
        gr.Error: If no video was provided; Gradio shows the message in the UI.
    """
    if video_file is None:
        raise gr.Error("Please upload or record a video first.")

    video = preprocess(video_file, max_frames=0, resize=(224, 224))

    embeddings = movinet_model(video)['vid_embedding']
    tokens = t5_model.generate(
        inputs_embeds=embeddings,
        max_new_tokens=128,
        # Sampling restricted by top-k / top-p with a low temperature, and
        # repeated bigrams blocked.
        temperature=0.1,
        no_repeat_ngram_size=2,
        do_sample=True,
        top_k=80,
        top_p=0.90,
    )

    # batch_decode returns one string per batch item. preprocess always builds
    # a batch of one, but join defensively so the result is a single string.
    translation = tokenizer.batch_decode(tokens, skip_special_tokens=True)
    return " ".join(text.strip() for text in translation)

# Gradio App config
title = "ASL Translation (MoViNet + T5)"
# Example clips. These relative paths resolve against the current working
# directory, which the script switched to 'models' at import time.
examples = [
    ['videos/no.mp4'],
    ['videos/all.mp4'],
    ['videos/before.mp4'],
    ['videos/blue.mp4'],
    ['videos/white.mp4'],
    ['videos/accident2.mp4'],
]

# Gradio App interface.
# Top-level components (gr.Video / gr.Label) replace the gr.inputs /
# gr.outputs namespaces, which were deprecated in Gradio 3.x and removed in
# Gradio 4.0.
gr.Interface(
    fn=translate,
    inputs=[gr.Video(label="Video (*.mp4)")],
    outputs=[gr.Label(label='Translation')],
    allow_flagging="never",
    title=title,
    examples=examples,
).launch()