innat committed on
Commit
222819d
1 Parent(s): 6022c8d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -89
app.py CHANGED
@@ -6,114 +6,59 @@ import imageio
6
  import tensorflow as tf
7
  from tensorflow import keras
8
 
9
- from utils import TubeMaskingGenerator
10
- from utils import read_video, frame_sampling, denormalize, reconstrunction
11
- from utils import IMAGENET_MEAN, IMAGENET_STD, num_frames, patch_size, input_size
12
- from labels import K400_label_map, SSv2_label_map, UCF_label_map
13
 
14
 
15
  LABEL_MAPS = {
16
  'K400': K400_label_map,
17
- 'SSv2': SSv2_label_map,
18
- 'UCF' : UCF_label_map
19
  }
20
 
21
  ALL_MODELS = [
22
- 'TFVideoMAE_L_K400_16x224',
23
- 'TFVideoMAE_B_SSv2_16x224',
24
- 'TFVideoMAE_B_UCF_16x224',
25
  ]
26
 
27
  sample_example = [
28
- ["examples/k400.mp4", ALL_MODELS[0], 0.9],
29
- ["examples/ssv2.mp4", ALL_MODELS[1], 0.8],
30
- ["examples/ucf.mp4", ALL_MODELS[2], 0.7],
31
  ]
32
 
33
- def tube_mask_generator(mask_ratio):
34
- window_size = (
35
- num_frames // 2,
36
- input_size // patch_size[0],
37
- input_size // patch_size[1]
38
- )
39
- tube_mask = TubeMaskingGenerator(
40
- input_size=window_size,
41
- mask_ratio=mask_ratio
42
- )
43
- make_bool = tube_mask()
44
- bool_masked_pos_tf = tf.constant(make_bool, dtype=tf.int32)
45
- bool_masked_pos_tf = tf.expand_dims(bool_masked_pos_tf, axis=0)
46
- bool_masked_pos_tf = tf.cast(bool_masked_pos_tf, tf.bool)
47
- return bool_masked_pos_tf
48
-
49
-
50
  def get_model(model_type):
51
- ft_path = keras.utils.get_file(
52
- origin=f'https://github.com/innat/VideoMAE/releases/download/v1.1/{model_type}_FT.zip',
53
- )
54
- pt_path = keras.utils.get_file(
55
- origin=f'https://github.com/innat/VideoMAE/releases/download/v1.1/{model_type}_PT.zip',
56
  )
57
-
58
- with zipfile.ZipFile(ft_path, 'r') as zip_ref:
59
  zip_ref.extractall('./')
60
 
61
- with zipfile.ZipFile(pt_path, 'r') as zip_ref:
62
- zip_ref.extractall('./')
63
-
64
- ft_model = keras.models.load_model(model_type + '_FT')
65
- pt_model = keras.models.load_model(model_type + '_PT')
66
 
67
  if 'K400' in model_type:
68
  data_type = 'K400'
69
- elif 'SSv2' in model_type:
70
- data_type = 'SSv2'
71
  else:
72
- data_type = 'UCF'
73
 
74
  label_map = LABEL_MAPS.get(data_type)
75
  label_map = {v: k for k, v in label_map.items()}
76
 
77
- return ft_model, pt_model, label_map
78
 
79
 
80
- def inference(video_file, model_type, mask_ratio):
81
  # get sample data
82
  container = read_video(video_file)
83
  frames = frame_sampling(container, num_frames=num_frames)
84
 
85
  # get models
86
- bool_masked_pos_tf = tube_mask_generator(mask_ratio)
87
- ft_model, pt_model, label_map = get_model(model_type)
88
- ft_model.trainable = False
89
- pt_model.trainable = False
90
-
91
- # inference on fine-tune model
92
- outputs_ft = ft_model(frames[None, ...], training=False)
93
- probabilities = tf.nn.softmax(outputs_ft).numpy().squeeze(0)
94
  confidences = {
95
  label_map[i]: float(probabilities[i]) for i in np.argsort(probabilities)[::-1]
96
  }
97
-
98
- # inference on pre-trained model
99
- outputs_pt = pt_model(frames[None, ...], bool_masked_pos_tf, training=False)
100
- reconstruct_output, mask = reconstrunction(
101
- frames[None, ...], bool_masked_pos_tf, outputs_pt
102
- )
103
-
104
- # post process
105
- input_frame = denormalize(frames)
106
- input_mask = denormalize(mask[0] * frames)
107
- output_frame = denormalize(reconstruct_output)
108
-
109
- frames = []
110
- for frame_a, frame_b, frame_c in zip(input_frame, input_mask, output_frame):
111
- combined_frame = np.hstack([frame_a, frame_b, frame_c])
112
- frames.append(combined_frame)
113
-
114
- combined_gif = 'combined.gif'
115
- imageio.mimsave(combined_gif, frames, duration=300, loop=0)
116
- return confidences, combined_gif
117
 
118
 
119
  def main():
@@ -123,26 +68,14 @@ def main():
123
  gr.Video(type="file", label="Input Video"),
124
  gr.Dropdown(
125
  choices=ALL_MODELS,
126
- default="TFVideoMAE_L_K400_16x224",
127
  label="Model"
128
- ),
129
- gr.Slider(
130
- 0.5,
131
- 1.0,
132
- step=0.1,
133
- default=0.5,
134
- label='Mask Ratio'
135
  )
136
  ],
137
- outputs=[
138
- gr.Label(num_top_classes=3, label='scores'),
139
- gr.Image(type="filepath", label='reconstructed')
140
- ],
141
  examples=sample_example,
142
- title="VideoMAE",
143
- description="Keras reimplementation of <a href='https://github.com/innat/VideoMAE'>VideoMAE</a> is presented here."
144
  )
145
-
146
  iface.launch()
147
 
148
  if __name__ == '__main__':
 
6
  import tensorflow as tf
7
  from tensorflow import keras
8
 
9
+ from utils import read_video, frame_sampling
10
+ from utils import num_frames, patch_size, input_size
11
+ from labels import K400_label_map
 
12
 
13
 
# Label lookup tables keyed by dataset tag; get_model() selects one by tag.
LABEL_MAPS = dict(K400=K400_label_map)

# Model names offered in the "Model" dropdown; get_model() also uses the
# name to build the download URL and the load path.
ALL_MODELS = ['TFVideoFocalNetB_K400_8x224']

# Example rows for the interface: [video path, model name].
sample_example = [["examples/k400.mp4", ALL_MODELS[0]]]
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
def get_model(model_type):
    """Download, unpack and load a released model plus its label lookup.

    Args:
        model_type: Name of the release asset without the ``.zip`` suffix
            (e.g. an entry of ``ALL_MODELS``). It is used to build the
            download URL and as the path passed to ``load_model``.

    Returns:
        A ``(model, label_map)`` tuple. ``label_map`` is the registered
        label map with keys and values swapped.

    Raises:
        ValueError: If ``LABEL_MAPS`` has no entry for the dataset tag
            derived from ``model_type``.
    """
    # Resolve the label map first so an unsupported model fails fast with a
    # clear message, instead of crashing on `None.items()` after the
    # download and model load have already been paid for.
    data_type = 'K400' if 'K400' in model_type else 'SSv2'
    label_map = LABEL_MAPS.get(data_type)
    if label_map is None:
        raise ValueError(
            f"No label map registered for dataset '{data_type}' "
            f"(model '{model_type}'); available: {sorted(LABEL_MAPS)}"
        )

    model_path = keras.utils.get_file(
        origin=f'https://github.com/innat/Video-FocalNets/releases/download/v1.1/{model_type}.zip',
    )
    # The archive is unpacked into the working directory; the model is then
    # loaded from the path named after `model_type`.
    with zipfile.ZipFile(model_path, 'r') as zip_ref:
        zip_ref.extractall('./')

    model = keras.models.load_model(model_type)

    # Swap keys and values so callers can look entries up by the original values.
    label_map = {v: k for k, v in label_map.items()}

    return model, label_map
44
 
45
 
def inference(video_file, model_type):
    """Run the selected model on a video and return per-label confidences.

    Args:
        video_file: Video input handed to ``read_video``.
        model_type: Model name forwarded to ``get_model``.

    Returns:
        A dict mapping each label to its softmax score as a Python float,
        inserted from the highest score to the lowest.
    """
    # Read the video and sample `num_frames` frames from it.
    video = read_video(video_file)
    clip = frame_sampling(video, num_frames=num_frames)

    # Fetch the model with its label lookup and freeze the weights.
    classifier, index_to_label = get_model(model_type)
    classifier.trainable = False

    # Add a leading batch axis, run the forward pass, then drop the batch
    # axis again after the softmax.
    batch = clip[None, ...]
    raw_outputs = classifier(batch, training=False)
    scores = tf.nn.softmax(raw_outputs).numpy().squeeze(0)

    # Walk the indices from the largest score to the smallest.
    ranking = np.argsort(scores)[::-1]
    confidences = {}
    for idx in ranking:
        confidences[index_to_label[idx]] = float(scores[idx])
    return confidences
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
 
64
  def main():
 
68
  gr.Video(type="file", label="Input Video"),
69
  gr.Dropdown(
70
  choices=ALL_MODELS,
 
71
  label="Model"
 
 
 
 
 
 
 
72
  )
73
  ],
74
+ outputs=gr.Label(num_top_classes=3, label='scores'),
 
 
 
75
  examples=sample_example,
76
+ title="Video-FocalNets: Spatio-Temporal Focal Modulation.",
77
+ description="Keras reimplementation of <a href='https://github.com/innat/Video-FocalNets'>Video-FocalNets</a> is presented here."
78
  )
 
79
  iface.launch()
80
 
81
  if __name__ == '__main__':