File size: 3,788 Bytes
79a2238
 
 
 
 
 
 
 
93528c6
 
79a2238
 
0a01a25
79a2238
 
 
 
3005acd
79a2238
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a0abe4
79a2238
 
 
 
a7a99aa
 
 
 
 
 
79a2238
 
a7a99aa
 
 
 
 
 
 
 
 
3005acd
 
 
 
 
 
 
 
 
79a2238
1ef90ad
ceb9afe
1ef90ad
 
 
3005acd
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import cv2
import numpy as np
import gradio as gr

import tensorflow as tf, tf_keras
import tensorflow_hub as hub
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

from official.projects.movinet.modeling import movinet
from official.projects.movinet.modeling import movinet_model_a2_modified as movinet_model_modified


# Load the video model once at import time so every request reuses it.
# The path is relative, so it is resolved against the process's working directory.
movinet_path = 'movinet_checkpoints_a2_epoch9'
movinet_model = tf_keras.models.load_model(movinet_path)
# The app only runs inference; mark the model as non-trainable.
movinet_model.trainable = False

# Tokenizer comes from the "t5-base" checkpoint, while the seq2seq weights come
# from a separate checkpoint; `translate` uses the tokenizer only for decoding.
tokenizer = AutoTokenizer.from_pretrained("t5-base")
t5_model = TFAutoModelForSeq2SeqLM.from_pretrained("deanna-emery/ASL_t5_movinet_sentence")
t5_model.trainable = False

def crop_center_square(frame):
    """Crop the largest centered square out of an image frame.

    Both orientations are handled: landscape frames lose equal margins on the
    left and right, portrait frames lose equal margins on the top and bottom.
    When the surplus is odd, the extra pixel is dropped from the far
    (right/bottom) side.

    Args:
        frame: Array whose first two axes are (height, width); any trailing
            axes (e.g. colour channels) are left untouched.

    Returns:
        A slice of ``frame`` whose height and width are both
        ``min(height, width)``.
    """
    height, width = frame.shape[0:2]
    side = min(height, width)
    # One of these offsets is always 0: only the longer axis is trimmed.
    top = (height - side) // 2
    left = (width - side) // 2
    return frame[top:top + side, left:left + side]
    

def preprocess(filename, max_frames=0, resize=(224,224)):
    """Decode a video file into a normalised frame array with a batch axis.

    Each decoded frame is passed through ``crop_center_square``, resized to
    ``resize`` and has its first and third colour channels swapped
    (OpenCV decodes frames as BGR, so this yields RGB order).

    Args:
        filename: Path of the video file handed to ``cv2.VideoCapture``.
        max_frames: Stop after this many frames; ``0`` (the default) means
            read until the end of the video.
        resize: Target size passed to ``cv2.resize`` for every frame.

    Returns:
        A float array of shape ``(1, num_frames, ...)`` with pixel values
        divided by 255.

    Raises:
        ValueError: If no frame could be read (missing, unreadable or empty
            video). Without this check the result would have shape ``(1, 0)``
            and only fail later, inside the model, with an opaque error.
    """
    video_capture = cv2.VideoCapture(filename)
    frames = []
    try:
        while video_capture.isOpened():
            ret, frame = video_capture.read()
            if not ret:
                break
            frame = crop_center_square(frame)
            frame = cv2.resize(frame, resize)
            # Reorder channels: index 2 first, index 0 last (BGR -> RGB).
            frame = frame[:, :, [2, 1, 0]]
            frames.append(frame)

            # max_frames == 0 disables the limit.
            if max_frames > 0 and len(frames) >= max_frames:
                break
    finally:
        # Always free the decoder handle, even if a frame operation raised.
        video_capture.release()

    if not frames:
        raise ValueError(f"No frames could be read from video: {filename!r}")

    video = np.array(frames) / 255.0
    # Add the leading batch axis.
    video = np.expand_dims(video, axis=0)
    return video

def translate(video_file):
    """Run the MoViNet + T5 pipeline on a video file and return the decoded text.

    The video is turned into a frame array by ``preprocess`` (no frame limit,
    224x224 frames), the ``'vid_embedding'`` output of ``movinet_model`` is fed
    to ``t5_model.generate`` as ``inputs_embeds``, and the generated tokens are
    decoded with ``tokenizer``.

    Args:
        video_file: Path of the video to translate.

    Returns:
        A dict with the single key ``"translation"`` holding the result of
        ``tokenizer.batch_decode`` for the generated tokens.
    """
    clip = preprocess(video_file, max_frames=0, resize=(224, 224))
    vid_embedding = movinet_model(clip)['vid_embedding']

    # Sampling configuration forwarded verbatim to `generate`.
    generation_options = dict(
        max_new_tokens=128,
        temperature=0.1,
        no_repeat_ngram_size=2,
        do_sample=True,
        top_k=80,
        top_p=0.90,
    )
    token_ids = t5_model.generate(inputs_embeds=vid_embedding, **generation_options)

    decoded = tokenizer.batch_decode(token_ids, skip_special_tokens=True)
    return {"translation": decoded}

# Gradio App config
title = "ASL Translation (MoViNet + T5)"
# One row per example video; each row holds a single value because the
# interface has a single input component. Every row needs a trailing comma:
# two adjacent list literals are parsed as an indexing expression and raise
# TypeError when the module is imported.
examples = [
    ["videos/My second ASL professor's name was Will White.mp4"],
    ['videos/You are my sunshine.mp4'],
    ['videos/scrub your hands for at least 20 seconds.mp4'],
    ['videos/no.mp4'],
    ['videos/all.mp4'],
    ['videos/white.mp4'],
]

# examples = [
#         ["videos/My second ASL professor's name was Will White.mp4", "My second ASL professor's name was Will White"],
#         ['videos/You are my sunshine.mp4', 'You are my sunshine'],
#         ['videos/scrub your hands for at least 20 seconds.mp4', 'scrub your hands for at least 20 seconds'],
#         ['videos/no.mp4', 'no'],
#         ['videos/all.mp4', 'all'],
#         ['videos/white.mp4', 'white']
#     ]

# Text shown under the title in the Gradio UI.
# NOTE(review): this text describes an I3D word-level classifier trained on
# WLASL, while the title and the models loaded by this app are MoViNet + T5 —
# it looks carried over from another demo; confirm and update the wording.
description = (
    "Gradio demo of word-level sign language classification using I3D model pretrained on the WLASL video dataset. "
    "WLASL is a large-scale dataset containing more than 2000 words in American Sign Language. "
    "Examples used in the demo are videos from the test subset. "
    "Note that WLASL100 contains 100 words while WLASL2000 contains 2000."
)


# HTML snippet rendered below the interface.
article = "More information about the trained models can be found <a href=https://github.com/deanna-emery/ASL-Translator/>here</a>."


# Gradio App interface: wrap `translate` in a video-in / text-out UI and
# start serving it as soon as the module is executed.
gr.Interface(
    fn=translate,
    inputs="video",
    outputs="text",
    title=title,
    description=description,
    article=article,
    examples=examples,
    allow_flagging="never",
).launch()