File size: 2,287 Bytes
c1a75e2 8ffd571 c1a75e2 8ffd571 c1a75e2 8ffd571 90f6e7b 8ffd571 bace7c0 8ffd571 90f6e7b 8ffd571 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
import functools
import zipfile

import gradio as gr
import imageio
import numpy as np
import tensorflow as tf
from tensorflow import keras

from labels import K400_label_map, SSv2_label_map
from utils import read_video, frame_sampling
from utils import num_frames, patch_size, input_size
# Dataset tag -> label dictionary (imported from ``labels``) for that dataset.
LABEL_MAPS = dict(
    K400=K400_label_map,
    SSv2=SSv2_label_map,
)
# Checkpoints offered in the UI. Each name is also the stem of the ``.zip``
# release asset that ``get_model`` downloads.
ALL_MODELS = [
    'TFVideoSwinT_K400_IN1K_P244_W877_32x224',
    'TFVideoSwinB_SSV2_K400_P244_W1677_32x224',
]

# Gradio example rows of [video path, model name]: the k400 clip is paired
# with the first checkpoint and the ssv2 clip with the second.
sample_example = [
    [video_path, model_name]
    for video_path, model_name in zip(
        ("examples/k400.mp4", "examples/ssv2.mp4"), ALL_MODELS
    )
]
# Only successful calls are cached (lru_cache never stores exceptions), and
# unknown names are rejected below, so the cache holds at most one entry
# per name in ALL_MODELS.
@functools.lru_cache(maxsize=None)
def get_model(model_type):
    """Download, unpack and load a VideoSwin checkpoint plus its label map.

    Results are memoised per ``model_type``. The release ``.zip`` is extracted
    and the Keras model deserialised once per process, not on every
    inference request.

    Args:
        model_type: One of ``ALL_MODELS``. It is also the stem of the release
            asset ``<model_type>.zip`` and the local path passed to
            ``keras.models.load_model``.

    Returns:
        ``(model, label_map)``. ``label_map`` is the inverse of the dataset's
        entry in ``LABEL_MAPS``: its values become keys and its keys values.

    Raises:
        ValueError: If ``model_type`` is not one of ``ALL_MODELS``.
    """
    # The name is interpolated into a URL and used as a filesystem path.
    # Accept only the known checkpoints, since Gradio endpoints can be called
    # directly, bypassing the dropdown.
    if model_type not in ALL_MODELS:
        raise ValueError(
            f'Unknown model {model_type!r}; expected one of {ALL_MODELS}'
        )
    model_path = keras.utils.get_file(
        origin=f'https://github.com/innat/VideoSwin/releases/download/v1.1/{model_type}.zip',
    )
    with zipfile.ZipFile(model_path, 'r') as zip_ref:
        zip_ref.extractall('./')
    # Assumes the archive unpacks to a directory named ``model_type`` in the
    # working directory. TODO: confirm against the release assets.
    model = keras.models.load_model(model_type)
    # The checkpoint names appear to encode the fine-tuning dataset right
    # after the architecture tag (``..._SSV2_K400_...`` vs ``..._K400_IN1K_...``).
    # Key on that token rather than on the window size.
    data_type = 'SSv2' if '_SSV2_' in model_type else 'K400'
    label_map = LABEL_MAPS[data_type]
    # Invert so the model's output indices can be looked up directly
    # (presumably the source dict maps class name -> index; verify in labels).
    label_map = {v: k for k, v in label_map.items()}
    return model, label_map
def inference(video_file, model_type):
    """Classify a video and return a probability for every class.

    Args:
        video_file: The value Gradio passes from the video input, forwarded
            unchanged to ``read_video``.
        model_type: Checkpoint name, forwarded to ``get_model``.

    Returns:
        Dict of ``label -> probability`` covering all classes, inserted in
        descending order of probability.
    """
    # Decode the video and sample ``num_frames`` frames from it
    # (the sampling strategy is implemented in ``utils.frame_sampling``).
    clip = frame_sampling(read_video(video_file), num_frames=num_frames)

    model, index_to_label = get_model(model_type)
    model.trainable = False

    # Prepend a batch axis of size 1 and run the forward pass in inference
    # mode. Then softmax the scores and drop the batch axis again.
    logits = model(clip[np.newaxis, ...], training=False)
    probs = tf.nn.softmax(logits).numpy().squeeze(0)

    ranking = np.argsort(probs)[::-1]
    return {index_to_label[idx]: float(probs[idx]) for idx in ranking}
def main():
    """Build the Gradio demo interface and start serving it.

    The interface takes an uploaded video and a checkpoint name, runs
    ``inference``, and shows the three highest-scoring classes.
    """
    iface = gr.Interface(
        fn=inference,
        inputs=[
            # ``type`` is not an option of ``gr.Video``. The component hands
            # the handler the uploaded video's file path.
            gr.Video(label="Input Video"),
            gr.Dropdown(
                choices=ALL_MODELS,
                # ``value`` sets the initial selection. ``default`` is the
                # legacy (pre-3.0) keyword and does not preselect anything,
                # which let ``model_type`` arrive as None.
                value=ALL_MODELS[0],
                label="Model",
            ),
        ],
        outputs=gr.Label(num_top_classes=3, label='scores'),
        examples=sample_example,
        title="VideoSwin: Video Swin Transformer",
        description="Keras reimplementation of <a href='https://github.com/innat/VideoSwin'>VideoSwin</a> is presented here.",
    )
    iface.launch()
# Launch the demo only when this file is run as a script, not on import.
if __name__ == '__main__':
    main()