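"""Gradio demo: translate ASL (American Sign Language) videos into English.

Uploaded video frames are read with OpenCV, encoded into embeddings by a
fine-tuned MoViNet-A2 model, and decoded into English text by a fine-tuned
T5 model.
"""
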
import cv2
import numpy as np
import gradio as gr

# Uncomment when running locally if the `official` package (TF Model Garden)
# must be imported from a modeling/ directory:
# import os
# os.chdir('modeling')

import tensorflow as tf, tf_keras
import tensorflow_hub as hub
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

from official.projects.movinet.modeling import movinet
from official.projects.movinet.modeling import movinet_model_a2_modified as movinet_model_modified
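# `movinet` is presumably imported for its side effects (registering model
# classes) so the saved checkpoint below can be deserialized;
# `movinet_model_a2_modified` appears to be a repo-local variant of the Model
# Garden's movinet_model that exposes the 'vid_embedding' output used in
# translate() below.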


# Video encoder: fine-tuned MoViNet-A2 checkpoint, frozen for inference
movinet_path = 'movinet_checkpoints_a2_epoch9'
movinet_model = tf_keras.models.load_model(movinet_path)
movinet_model.trainable = False

# Text decoder: T5 fine-tuned for ASL translation, with the stock t5-base tokenizer
tokenizer = AutoTokenizer.from_pretrained("t5-base")
t5_model = TFAutoModelForSeq2SeqLM.from_pretrained("deanna-emery/ASL_t5_word_epoch15_1204")
t5_model.trainable = False

def crop_center_square(frame):
    """Crop a landscape frame to a centered square; portrait frames pass through unchanged."""
    y, x = frame.shape[0:2]
    if x > y:
        start_x = int((x - y) / 2)
        return frame[:, start_x:start_x + y]
    return frame


def preprocess(filename, max_frames=0, resize=(224, 224)):
    """Read a video into a (1, num_frames, H, W, 3) float array scaled to [0, 1].

    max_frames=0 reads the whole video; a positive value stops after that many frames.
    """
    video_capture = cv2.VideoCapture(filename)
    frames = []
    try:
        while video_capture.isOpened():
            ret, frame = video_capture.read()
            if not ret:
                break
            frame = crop_center_square(frame)
            frame = cv2.resize(frame, resize)
            frame = frame[:, :, [2, 1, 0]]  # OpenCV loads BGR; reorder to RGB
            frames.append(frame)

            if len(frames) == max_frames:
                break
    finally:
        video_capture.release()

    video = np.array(frames) / 255.0       # scale pixels to [0, 1]
    video = np.expand_dims(video, axis=0)  # add a batch dimension
    return video

def translate(video_file, true_caption=None):
    """Full pipeline: video file -> MoViNet video embeddings -> T5 text."""
    video = preprocess(video_file, max_frames=0, resize=(224, 224))

    # Encode the clip, then decode the embeddings into text with sampling
    embeddings = movinet_model(video)['vid_embedding']
    tokens = t5_model.generate(inputs_embeds=embeddings,
                               max_new_tokens=128,
                               temperature=0.1,
                               no_repeat_ngram_size=2,
                               do_sample=True,
                               top_k=80,
                               top_p=0.90,
                               )

    translation = tokenizer.batch_decode(tokens, skip_special_tokens=True)

    # The interface output is a plain textbox, so return the decoded string
    return translation[0]
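
# Local smoke test (hypothetical example path), bypassing the Gradio UI:
#   print(translate("example_videos/hello.mp4"))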

# Gradio App config
title = ""

description = ""

examples = []

# Gradio App interface
gr.Interface(fn=translate,
             inputs=[gr.Video(label='Video', show_label=True, max_length=10, sources=['upload']),
                     # Hidden textbox matching translate()'s unused true_caption argument
                     gr.Textbox(label='Caption', show_label=True, interactive=False, visible=False)],
             outputs="text",
             allow_flagging="never",
             title=title,
             description=description,
             examples=examples,
             ).launch()
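
# launch() starts a local web server; launch(share=True) would also create a
# temporary public URL.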