innat commited on
Commit
3ee2d3e
·
1 Parent(s): 0b2cebb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -53
app.py CHANGED
@@ -11,23 +11,7 @@ from utils import IMAGENET_MEAN, IMAGENET_STD, num_frames, patch_size, input_siz
11
  from labels import K400_label_map, SSv2_label_map, UCF_label_map
12
 
13
 
14
- MODELS = {
15
- 'K400': [
16
- './TFVideoMAE_S_K400_16x224_FT',
17
- './TFVideoMAE_S_K400_16x224_PT'
18
- ],
19
- 'SSv2': [
20
- './TFVideoMAE_S_K400_16x224_FT',
21
- './TFVideoMAE_S_K400_16x224_PT'
22
- ],
23
- 'UCF' : [
24
- 'innat/videomae/TFVideoMAE_S_K400_16x224_FT',
25
- './TFVideoMAE_S_K400_16x224_PT'
26
- ]
27
- }
28
-
29
-
30
- def tube_mask_generator():
31
  window_size = (
32
  num_frames // 2,
33
  input_size // patch_size[0],
@@ -35,7 +19,7 @@ def tube_mask_generator():
35
  )
36
  tube_mask = TubeMaskingGenerator(
37
  input_size=window_size,
38
- mask_ratio=0.70
39
  )
40
  make_bool = tube_mask()
41
  bool_masked_pos_tf = tf.constant(make_bool, dtype=tf.int32)
@@ -44,28 +28,17 @@ def tube_mask_generator():
44
  return bool_masked_pos_tf
45
 
46
 
47
- def video_to_gif(video_array, gif_filename):
48
- imageio.mimsave(
49
- gif_filename, video_array, duration=100
50
- )
51
-
52
-
53
  def get_model(data_type):
54
- print()
55
- print('-------------------- ', data_type)
56
- print()
57
-
58
- data_type ='K400'
59
  ft_model = keras.models.load_model(MODELS[data_type][0])
60
  pt_model = keras.models.load_model(MODELS[data_type][1])
61
  label_map = {v: k for k, v in K400_label_map.items()}
62
  return ft_model, pt_model, label_map
63
 
64
 
65
- def inference(video_file, dataset_type):
66
  container = read_video(video_file)
67
  frames = frame_sampling(container, num_frames=num_frames)
68
- bool_masked_pos_tf = tube_mask_generator()
69
  ft_model, pt_model, label_map = get_model(dataset_type)
70
  ft_model.trainable = False
71
  pt_model.trainable = False
@@ -97,25 +70,57 @@ def inference(video_file, dataset_type):
97
  return confidences, combined_gif
98
 
99
 
100
- gr.Interface(
101
- fn=inference,
102
- inputs=[
103
- gr.Video(type="file"),
104
- gr.Radio(
105
- ['K400', 'SSv2', 'UCF'],
106
- type='value',
107
- default='K400',
108
- label='Dataset',
109
- ),
110
- ],
111
- outputs=[
112
- gr.Label(num_top_classes=3, label='confidence scores'),
113
- gr.Image(type="filepath", label='reconstructed masked autoencoder')
114
- ],
115
- examples=[
116
- ["examples/k400.mp4"],
117
- ["examples/k400.mp4"],
118
- ["examples/k400.mp4"],
119
- ],
120
- title="VideoMAE",
121
- ).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  from labels import K400_label_map, SSv2_label_map, UCF_label_map
12
 
13
 
14
+ def tube_mask_generator(mask_ratio):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  window_size = (
16
  num_frames // 2,
17
  input_size // patch_size[0],
 
19
  )
20
  tube_mask = TubeMaskingGenerator(
21
  input_size=window_size,
22
+ mask_ratio=mask_ratio
23
  )
24
  make_bool = tube_mask()
25
  bool_masked_pos_tf = tf.constant(make_bool, dtype=tf.int32)
 
28
  return bool_masked_pos_tf
29
 
30
 
 
 
 
 
 
 
31
  def get_model(data_type):
 
 
 
 
 
32
  ft_model = keras.models.load_model(MODELS[data_type][0])
33
  pt_model = keras.models.load_model(MODELS[data_type][1])
34
  label_map = {v: k for k, v in K400_label_map.items()}
35
  return ft_model, pt_model, label_map
36
 
37
 
38
+ def inference(video_file, dataset_type, mask_ratio):
39
  container = read_video(video_file)
40
  frames = frame_sampling(container, num_frames=num_frames)
41
+ bool_masked_pos_tf = tube_mask_generator(mask_ratio)
42
  ft_model, pt_model, label_map = get_model(dataset_type)
43
  ft_model.trainable = False
44
  pt_model.trainable = False
 
70
  return confidences, combined_gif
71
 
72
 
73
+ def main():
74
+ MODELS = {
75
+ 'K400': [
76
+ './TFVideoMAE_S_K400_16x224_FT',
77
+ './TFVideoMAE_S_K400_16x224_PT'
78
+ ],
79
+ 'SSv2': [
80
+ './TFVideoMAE_S_K400_16x224_FT',
81
+ './TFVideoMAE_S_K400_16x224_PT'
82
+ ],
83
+ 'UCF' : [
84
+ 'innat/videomae/TFVideoMAE_S_K400_16x224_FT',
85
+ './TFVideoMAE_S_K400_16x224_PT'
86
+ ]
87
+ }
88
+
89
+ BENCHMARK_DATASETS = ['K400', 'SSv2', 'UCF']
90
+ SAMPLE_EXAMPLES = [
91
+ ["examples/k400.mp4", 'Kintetics-400'],
92
+ ["examples/k400.mp4", 'SSv2'],
93
+ ["examples/k400.mp4", 'UCF']
94
+ ]
95
+
96
+ iface = gr.Interface(
97
+ fn=inference,
98
+ inputs=[
99
+ gr.Video(type="file", label="Input Video"),
100
+ gr.Radio(
101
+ BENCHMARK_DATASETS,
102
+ type='value',
103
+ default=BENCHMARK_DATASETS[0],
104
+ label='Dataset',
105
+ ),
106
+ gr.inputs.Slider(
107
+ 0.5,
108
+ 1.0,
109
+ step=0.1,
110
+ default=0.7,
111
+ label='Mask Ratio'
112
+ )
113
+ ],
114
+ outputs=[
115
+ gr.Label(num_top_classes=3, label='scores'),
116
+ gr.Image(type="filepath", label='reconstructed')
117
+ ],
118
+ examples=SAMPLE_EXAMPLES,
119
+ title="VideoMAE",
120
+ description="Keras reimplementation of <a href='https://github.com/innat/VideoMAE'>VideoMAE</a> is presented here."
121
+ )
122
+
123
+ iface.launch()
124
+
125
+ if __name__ == '__main__':
126
+ main()