|
import functools
import zipfile

import gradio as gr
import imageio
import numpy as np
import tensorflow as tf
from tensorflow import keras

from labels import K400_label_map, SSv2_label_map
from utils import read_video, frame_sampling
from utils import num_frames, patch_size, input_size
|
|
|
|
|
# Dataset key -> label map imported from the ``labels`` module. The keys are
# exactly the two values get_model() selects between; get_model() inverts the
# chosen map before use (presumably label -> index here; confirm in labels.py).
LABEL_MAPS = dict(
    K400=K400_label_map,
    SSv2=SSv2_label_map,
)
|
|
|
# Checkpoint names offered in the UI dropdown. get_model() downloads
# "<name>.zip" from the VideoSwin v1.1 GitHub release, so every entry must
# match the name of a release asset.
ALL_MODELS = list(
    (
        'TFVideoSwinT_K400_IN1K_P244_W877_32x224',
        'TFVideoSwinB_SSV2_K400_P244_W1677_32x224',
    )
)
|
|
|
# Example rows for gr.Interface: one [video path, model name] pair per bundled
# clip, matching the interface's (video, dropdown) inputs. The first clip is
# paired with the first model, the second clip with the second model.
sample_example = [
    [clip_path, model_name]
    for clip_path, model_name in zip(
        ("examples/k400.mp4", "examples/ssv2.mp4"),
        ALL_MODELS,
    )
]
|
|
|
|
|
@functools.lru_cache(maxsize=None)
def get_model(model_type):
    """Download, extract and load a VideoSwin checkpoint with its label map.

    Results are memoised per ``model_type``. Repeated UI requests therefore
    reuse the loaded Keras model instead of re-extracting the zip archive and
    re-loading the model on every inference call. ``lru_cache`` does not
    cache calls that raise, so a failed download is retried next time. In
    practice the key space is limited to the dropdown's ``ALL_MODELS``
    choices. The returned objects are shared between callers and must not
    be mutated.

    Args:
        model_type: Checkpoint name. ``<model_type>.zip`` is fetched from the
            innat/VideoSwin v1.1 GitHub release via ``keras.utils.get_file``.

    Returns:
        ``(model, label_map)``, where ``label_map`` is the inverse of the
        selected ``LABEL_MAPS`` entry. ``inference`` indexes it by class index
        to obtain the label.
    """
    model_path = keras.utils.get_file(
        origin=f'https://github.com/innat/VideoSwin/releases/download/v1.1/{model_type}.zip',
    )
    with zipfile.ZipFile(model_path, 'r') as zip_ref:
        # Extracted into the current working directory. load_model() below
        # reads the relative path "<model_type>" (assumes the archive holds a
        # top-level directory of that name -- confirm against the release).
        zip_ref.extractall('./')

    model = keras.models.load_model(model_type)

    # Pick the label map of the dataset the model was fine-tuned on.
    # "TFVideoSwinB_SSV2_K400_..." also contains "K400" (its pretraining
    # dataset), so a plain "'K400' in model_type" test paired the SSv2 model
    # with the K400 labels. Check for SSV2 first, case-insensitively.
    data_type = 'SSv2' if 'SSV2' in model_type.upper() else 'K400'

    label_map = LABEL_MAPS[data_type]
    label_map = {v: k for k, v in label_map.items()}

    return model, label_map
|
|
|
|
|
def inference(video_file, model_type):
    """Classify a video with the selected VideoSwin checkpoint.

    Args:
        video_file: Value from the Gradio video input, passed unchanged to
            ``read_video``.
        model_type: Checkpoint name chosen in the UI; forwarded to
            ``get_model``.

    Returns:
        Dict mapping every class label to its softmax probability, inserted
        from most to least likely. The ``gr.Label`` output displays the top 3.
    """
    # Decode the video, then sample ``num_frames`` frames from it
    # (sampling strategy lives in utils.frame_sampling).
    clip = frame_sampling(read_video(video_file), num_frames=num_frames)

    model, index_to_label = get_model(model_type)
    model.trainable = False

    # Prepend a batch axis of size 1 and run a forward pass in inference mode.
    # The softmax converts the outputs to probabilities, then the batch axis
    # is dropped again.
    # NOTE(review): assumes the model returns unnormalised logits -- confirm.
    logits = model(clip[None, ...], training=False)
    scores = tf.nn.softmax(logits).numpy().squeeze(0)

    # argsort is ascending; reverse it to rank classes by descending score.
    ranked = np.argsort(scores)[::-1]
    return {index_to_label[idx]: float(scores[idx]) for idx in ranked}
|
|
|
|
|
def main():
    """Build the VideoSwin Gradio demo and launch its web server.

    The interface takes an uploaded video and a checkpoint name from
    ``ALL_MODELS``, runs ``inference``, and shows the top-3 class scores.
    """
    iface = gr.Interface(
        fn=inference,
        inputs=[
            # gr.Video has no ``type`` parameter. Gradio 3 ignores it with an
            # "unused kwarg" warning and newer releases reject it, so it is
            # not passed.
            gr.Video(label="Input Video"),
            gr.Dropdown(
                choices=ALL_MODELS,
                # ``value`` sets the initial selection. The old ``default``
                # keyword is ignored, so the dropdown started empty and a
                # submit without picking a model sent ``None`` to inference.
                value=ALL_MODELS[0],
                label="Model",
            ),
        ],
        outputs=[
            gr.Label(num_top_classes=3, label='scores'),
        ],
        examples=sample_example,
        title="VideoSwin: Video Swin Transformer",
        description="Keras reimplementation of <a href='https://github.com/innat/VideoSwin'>VideoSwin</a> is presented here.",
    )

    iface.launch()
|
|
|
if __name__ == "__main__":
    # Start the demo only when run as a script, not when imported.
    main()