import cv2
import numpy as np
import gradio as gr

# import os
# os.chdir('modeling')

import tensorflow as tf, tf_keras
import tensorflow_hub as hub
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

from official.projects.movinet.modeling import movinet
from official.projects.movinet.modeling import movinet_model_a2_modified as movinet_model_modified
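# Note: tf, hub, and the base `movinet` module are not referenced directly below;
# they are presumably imported for side effects (e.g., registering custom ops/layers)
# needed when loading the MoViNet checkpoint.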


# Load the fine-tuned MoViNet video encoder; it is frozen because the app only runs inference.
movinet_path = 'movinet_checkpoints_a2_epoch9'
movinet_model = tf_keras.models.load_model(movinet_path)
movinet_model.trainable = False

# Load the T5 tokenizer and the T5 model fine-tuned to decode MoViNet video embeddings into English.
tokenizer = AutoTokenizer.from_pretrained("t5-base")
t5_model = TFAutoModelForSeq2SeqLM.from_pretrained("deanna-emery/ASL_t5_movinet_sentence")
t5_model.trainable = False

def crop_center_square(frame):
    """Crop a landscape frame to a centered square; other frames are returned unchanged."""
    y, x = frame.shape[0:2]
    if x > y:
        start_x = int((x - y) / 2)
        end_x = start_x + y
        return frame[:, start_x:end_x]
    else:
        return frame


def preprocess(filename, max_frames=0, resize=(224, 224)):
    """Read a video file into a normalized array of RGB frames.

    max_frames=0 reads the entire video; a positive value stops after that many frames.
    """
    video_capture = cv2.VideoCapture(filename)
    frames = []
    try:
        while video_capture.isOpened():
            ret, frame = video_capture.read()
            if not ret:
                break
            frame = crop_center_square(frame)
            frame = cv2.resize(frame, resize)
            frame = frame[:, :, [2, 1, 0]]  # BGR (OpenCV default) -> RGB
            frames.append(frame)

            if len(frames) == max_frames:
                break
    finally:
        video_capture.release()

    video = np.array(frames) / 255.0  # scale pixel values to [0, 1]
    video = np.expand_dims(video, axis=0)  # add batch dimension
    return video

def translate(video_file, true_caption=None):
    # true_caption is supplied by the hidden caption textbox for the examples; it is not used in inference.
    video = preprocess(video_file, max_frames=0, resize=(224, 224))

    # MoViNet produces a sequence of video embeddings that T5 consumes in place of token embeddings.
    embeddings = movinet_model(video)['vid_embedding']
    tokens = t5_model.generate(inputs_embeds=embeddings,
                               max_new_tokens=128,
                               temperature=0.1,
                               no_repeat_ngram_size=2,
                               do_sample=True,
                               top_k=80,
                               top_p=0.90,
                               )

    translation = tokenizer.batch_decode(tokens, skip_special_tokens=True)

    return {"translation": translation}
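# Example (uses one of the bundled clips listed in `examples` below):
#   translate("videos/no.mp4")  # -> {"translation": [<decoded sentence>]}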

# Gradio App config
title = "American Sign Language Translation: An Approach Combining MoViNets and T5"

description = """
This application surfaces a model for translating American Sign Language (ASL). 
The model consists of a fine-tuned MoViNet CNN that generates video embeddings and a T5 encoder-decoder model 
that generates translations from those embeddings. This architecture achieves a BLEU score of 1.98 
and an average cosine similarity score of 0.21 when trained and evaluated on the YouTube-ASL dataset. 
More information about the model training and instructions to download the models can be found in our GitHub repository <a href=https://github.com/deanna-emery/ASL-Translator>here</a>.

A limitation of this architecture is the size of the MoViNet model, which makes it especially slow during inference on a CPU. 
We do not recommend uploading videos longer than 4 seconds, as generating the video embeddings may take some time. 
The application does not accept videos that are longer than 10 seconds.
We have provided some pre-cached videos with their original captions and translations as examples.
"""


examples = [
        ["videos/My_second_ASL_professors_name_was_Will_White.mp4", "My second ASL professor's name was Will White"],
        ['videos/You_are_my_sunshine.mp4', 'You are my sunshine'],
        ['videos/scrub_your_hands_for_at_least_20_seconds.mp4', 'scrub your hands for at least 20 seconds'],
        ['videos/no.mp4', 'no'],
    ]


article =  """The captions for the example videos are as follows in order: \n
1. 'My second ASL professor's name was Will White'
2. 'You are my sunshine'
3. 'scrub your hands for at least 20 seconds'
4. 'no'
"""

# Gradio App interface
gr.Interface(fn=translate,
             inputs=[gr.Video(label='Video', show_label=True, max_length=10, sources=['upload']),
                     gr.Textbox(label='Caption', show_label=True, interactive=False, visible=False)],
             outputs="text",
             allow_flagging="never",
             title=title,
             description=description,
             examples=examples,
             ).launch()