lorocksUMD committed
Commit 14f8d04 · verified
1 Parent(s): e12b827

Update app.py

Files changed (1)
  1. app.py +147 -148
app.py CHANGED
@@ -16,184 +16,183 @@ from denseav.plotting import plot_attention_video, plot_2head_attention_video, p
  from denseav.shared import norm, crop_to_divisor, blur_dim
  from os.path import join

- if __name__ == "__main__":
-
-     mode = "hf"
-
-     if mode == "local":
-         sample_videos_dir = "samples"
-     else:
-         os.environ['TORCH_HOME'] = '/tmp/.cache'
-         os.environ['HF_HOME'] = '/tmp/.cache'
-         os.environ['HF_DATASETS_CACHE'] = '/tmp/.cache'
-         os.environ['TRANSFORMERS_CACHE'] = '/tmp/.cache'
-         os.environ['GRADIO_EXAMPLES_CACHE'] = '/tmp/gradio_cache'
-         sample_videos_dir = "/tmp/samples"
-
-
-     def download_video(url, save_path):
-         response = requests.get(url)
-         with open(save_path, 'wb') as file:
-             file.write(response.content)
-
-
-     base_url = "https://marhamilresearch4.blob.core.windows.net/denseav-public/samples/"
-     sample_videos_urls = {
-         "puppies.mp4": base_url + "puppies.mp4",
-         "peppers.mp4": base_url + "peppers.mp4",
-         "boat.mp4": base_url + "boat.mp4",
-         "elephant2.mp4": base_url + "elephant2.mp4",
-     }
-
-     # Ensure the directory for sample videos exists
-     os.makedirs(sample_videos_dir, exist_ok=True)
-
-     # Download each sample video
-     for filename, url in sample_videos_urls.items():
-         save_path = os.path.join(sample_videos_dir, filename)
-         # Download the video if it doesn't already exist
-         if not os.path.exists(save_path):
-             print(f"Downloading {filename}...")
-             download_video(url, save_path)
-         else:
-             print(f"{filename} already exists. Skipping download.")
-
-     csv.field_size_limit(100000000)
-     options = ['language', "sound-language", "sound"]
-     load_size = 224
-     plot_size = 224
-
-     video_input = gr.Video(label="Choose a video to featurize", height=480)
-     model_option = gr.Radio(options, value="language", label='Choose a model')
-
-     video_output1 = gr.Video(label="Audio Video Attention", height=480)
-     video_output2 = gr.Video(label="Multi-Head Audio Video Attention (Only Availible for sound_and_language)",
-                              height=480)
-     video_output3 = gr.Video(label="Visual Features", height=480)
-
-     models = {o: LitAVAligner.from_pretrained(f"mhamilton723/DenseAV-{o}") for o in options}
-
-     def process_video(video, model_option):
-         # model = models[model_option].cuda()
-         model = models[model_option]
-
-         original_frames, audio, info = torchvision.io.read_video(video, end_pts=10, pts_unit='sec')
-         sample_rate = 16000
-
-         if info["audio_fps"] != sample_rate:
-             audio = resample(audio, info["audio_fps"], sample_rate)
-         audio = audio[0].unsqueeze(0)
-
-         img_transform = T.Compose([
-             T.Resize(load_size, Image.BILINEAR),
-             lambda x: crop_to_divisor(x, 8),
-             lambda x: x.to(torch.float32) / 255,
-             norm])
-
-         frames = torch.cat([img_transform(f.permute(2, 0, 1)).unsqueeze(0) for f in original_frames], axis=0)
-
-         plotting_img_transform = T.Compose([
-             T.Resize(plot_size, Image.BILINEAR),
-             lambda x: crop_to_divisor(x, 8),
-             lambda x: x.to(torch.float32) / 255])
-
-         frames_to_plot = plotting_img_transform(original_frames.permute(0, 3, 1, 2))
-
-         with torch.no_grad():
-             # audio_feats = model.forward_audio({"audio": audio.cuda()})
-             audio_feats = model.forward_audio({"audio": audio})
-             audio_feats = {k: v.cpu() for k, v in audio_feats.items()}
-             # image_feats = model.forward_image({"frames": frames.unsqueeze(0).cuda()}, max_batch_size=2)
-             image_feats = model.forward_image({"frames": frames.unsqueeze(0)}, max_batch_size=2)
-             image_feats = {k: v.cpu() for k, v in image_feats.items()}
-
-             sim_by_head = model.sim_agg.get_pairwise_sims(
-                 {**image_feats, **audio_feats},
-                 raw=False,
-                 agg_sim=False,
-                 agg_heads=False
-             ).mean(dim=-2).cpu()
-
-             sim_by_head = blur_dim(sim_by_head, window=3, dim=-1)
-             print(sim_by_head.shape)
-
-         temp_video_path_1 = tempfile.mktemp(suffix='.mp4')
-
-         plot_attention_video(
-             sim_by_head,
-             frames_to_plot,
-             audio,
-             info["video_fps"],
-             sample_rate,
-             temp_video_path_1)
-
-         if model_option == "sound_and_language":
-             temp_video_path_2 = tempfile.mktemp(suffix='.mp4')
-
-             plot_2head_attention_video(
-                 sim_by_head,
-                 frames_to_plot,
-                 audio,
-                 info["video_fps"],
-                 sample_rate,
-                 temp_video_path_2)
-
-         else:
-             temp_video_path_2 = None
-
-         temp_video_path_3 = tempfile.mktemp(suffix='.mp4')
-         temp_video_path_4 = tempfile.mktemp(suffix='.mp4')
-
-         plot_feature_video(
-             image_feats["image_feats"].cpu(),
-             audio_feats['audio_feats'].cpu(),
              frames_to_plot,
              audio,
              info["video_fps"],
              sample_rate,
-             temp_video_path_3,
-             temp_video_path_4,
-         )
-         # return temp_video_path_1, temp_video_path_2, temp_video_path_3, temp_video_path_4
-
-         return temp_video_path_1, temp_video_path_2, temp_video_path_3
-
-
-     with gr.Blocks() as demo:
-         with gr.Column():
-             gr.Markdown("## Visualizing Sound and Language with DenseAV")
-             gr.Markdown(
-                 "This demo allows you to explore the inner attention maps of DenseAV's dense multi-head contrastive operator.")
-             with gr.Row():
-                 with gr.Column(scale=1):
-                     model_option.render()
-                 with gr.Column(scale=3):
-                     video_input.render()
-             with gr.Row():
-                 submit_button = gr.Button("Submit")
-             with gr.Row():
-                 gr.Examples(
-                     examples=[
-                         [join(sample_videos_dir, "puppies.mp4"), "sound_and_language"],
-                         [join(sample_videos_dir, "peppers.mp4"), "language"],
-                         [join(sample_videos_dir, "elephant2.mp4"), "language"],
-                         [join(sample_videos_dir, "boat.mp4"), "language"]
-                     ],
-                     inputs=[video_input, model_option]
-                 )
-             with gr.Row():
-                 video_output1.render()
-                 video_output2.render()
-                 video_output3.render()
-
-         submit_button.click(fn=process_video, inputs=[video_input, model_option],
-                             outputs=[video_output1, video_output2, video_output3])
-
-
-     if mode == "local":
-         demo.launch(server_name="0.0.0.0", server_port=6006, debug=True)
      else:
-         demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)

  from denseav.shared import norm, crop_to_divisor, blur_dim
  from os.path import join

+ mode = "hf"
+
+ if mode == "local":
+     sample_videos_dir = "samples"
+ else:
+     os.environ['TORCH_HOME'] = '/tmp/.cache'
+     os.environ['HF_HOME'] = '/tmp/.cache'
+     os.environ['HF_DATASETS_CACHE'] = '/tmp/.cache'
+     os.environ['TRANSFORMERS_CACHE'] = '/tmp/.cache'
+     os.environ['GRADIO_EXAMPLES_CACHE'] = '/tmp/gradio_cache'
+     sample_videos_dir = "/tmp/samples"
+
+
+ def download_video(url, save_path):
+     response = requests.get(url)
+     with open(save_path, 'wb') as file:
+         file.write(response.content)
+
+
+ base_url = "https://marhamilresearch4.blob.core.windows.net/denseav-public/samples/"
+ sample_videos_urls = {
+     "puppies.mp4": base_url + "puppies.mp4",
+     "peppers.mp4": base_url + "peppers.mp4",
+     "boat.mp4": base_url + "boat.mp4",
+     "elephant2.mp4": base_url + "elephant2.mp4",
+ }
+
+ # Ensure the directory for sample videos exists
+ os.makedirs(sample_videos_dir, exist_ok=True)
+
+ # Download each sample video
+ for filename, url in sample_videos_urls.items():
+     save_path = os.path.join(sample_videos_dir, filename)
+     # Download the video if it doesn't already exist
+     if not os.path.exists(save_path):
+         print(f"Downloading {filename}...")
+         download_video(url, save_path)
+     else:
+         print(f"{filename} already exists. Skipping download.")
+
+ csv.field_size_limit(100000000)
+ options = ['language', "sound-language", "sound"]
+ load_size = 224
+ plot_size = 224
+
+ video_input = gr.Video(label="Choose a video to featurize", height=480)
+ model_option = gr.Radio(options, value="language", label='Choose a model')
+
+ video_output1 = gr.Video(label="Audio Video Attention", height=480)
+ video_output2 = gr.Video(label="Multi-Head Audio Video Attention (Only Availible for sound_and_language)",
+                          height=480)
+ video_output3 = gr.Video(label="Visual Features", height=480)
+
+ models = {o: LitAVAligner.from_pretrained(f"mhamilton723/DenseAV-{o}") for o in options}
+
+ def process_video(video, model_option):
+     # model = models[model_option].cuda()
+     model = models[model_option]
+
+     original_frames, audio, info = torchvision.io.read_video(video, end_pts=10, pts_unit='sec')
+     sample_rate = 16000
+
+     if info["audio_fps"] != sample_rate:
+         audio = resample(audio, info["audio_fps"], sample_rate)
+     audio = audio[0].unsqueeze(0)
+
+     img_transform = T.Compose([
+         T.Resize(load_size, Image.BILINEAR),
+         lambda x: crop_to_divisor(x, 8),
+         lambda x: x.to(torch.float32) / 255,
+         norm])
+
+     frames = torch.cat([img_transform(f.permute(2, 0, 1)).unsqueeze(0) for f in original_frames], axis=0)
+
+     plotting_img_transform = T.Compose([
+         T.Resize(plot_size, Image.BILINEAR),
+         lambda x: crop_to_divisor(x, 8),
+         lambda x: x.to(torch.float32) / 255])
+
+     frames_to_plot = plotting_img_transform(original_frames.permute(0, 3, 1, 2))
+
+     with torch.no_grad():
+         # audio_feats = model.forward_audio({"audio": audio.cuda()})
+         audio_feats = model.forward_audio({"audio": audio})
+         audio_feats = {k: v.cpu() for k, v in audio_feats.items()}
+         # image_feats = model.forward_image({"frames": frames.unsqueeze(0).cuda()}, max_batch_size=2)
+         image_feats = model.forward_image({"frames": frames.unsqueeze(0)}, max_batch_size=2)
+         image_feats = {k: v.cpu() for k, v in image_feats.items()}
+
+         sim_by_head = model.sim_agg.get_pairwise_sims(
+             {**image_feats, **audio_feats},
+             raw=False,
+             agg_sim=False,
+             agg_heads=False
+         ).mean(dim=-2).cpu()
+
+         sim_by_head = blur_dim(sim_by_head, window=3, dim=-1)
+         print(sim_by_head.shape)
+
+     temp_video_path_1 = tempfile.mktemp(suffix='.mp4')
+
+     plot_attention_video(
+         sim_by_head,
+         frames_to_plot,
+         audio,
+         info["video_fps"],
+         sample_rate,
+         temp_video_path_1)
+
+     if model_option == "sound_and_language":
+         temp_video_path_2 = tempfile.mktemp(suffix='.mp4')
+
+         plot_2head_attention_video(
+             sim_by_head,
              frames_to_plot,
              audio,
              info["video_fps"],
              sample_rate,
+             temp_video_path_2)
+
      else:
+         temp_video_path_2 = None
+
+     temp_video_path_3 = tempfile.mktemp(suffix='.mp4')
+     temp_video_path_4 = tempfile.mktemp(suffix='.mp4')
+
+     plot_feature_video(
+         image_feats["image_feats"].cpu(),
+         audio_feats['audio_feats'].cpu(),
+         frames_to_plot,
+         audio,
+         info["video_fps"],
+         sample_rate,
+         temp_video_path_3,
+         temp_video_path_4,
+     )
+     # return temp_video_path_1, temp_video_path_2, temp_video_path_3, temp_video_path_4
+
+     return temp_video_path_1, temp_video_path_2, temp_video_path_3
+
+
+ with gr.Blocks() as demo:
+     with gr.Column():
+         gr.Markdown("## Visualizing Sound and Language with DenseAV")
+         gr.Markdown(
+             "This demo allows you to explore the inner attention maps of DenseAV's dense multi-head contrastive operator.")
+         with gr.Row():
+             with gr.Column(scale=1):
+                 model_option.render()
+             with gr.Column(scale=3):
+                 video_input.render()
+         with gr.Row():
+             submit_button = gr.Button("Submit")
+         with gr.Row():
+             gr.Examples(
+                 examples=[
+                     [join(sample_videos_dir, "puppies.mp4"), "sound_and_language"],
+                     [join(sample_videos_dir, "peppers.mp4"), "language"],
+                     [join(sample_videos_dir, "elephant2.mp4"), "language"],
+                     [join(sample_videos_dir, "boat.mp4"), "language"]
+                 ],
+                 inputs=[video_input, model_option]
+             )
+         with gr.Row():
+             video_output1.render()
+             video_output2.render()
+             video_output3.render()
+
+     submit_button.click(fn=process_video, inputs=[video_input, model_option],
+                         outputs=[video_output1, video_output2, video_output3])
+
+
+ if mode == "local":
+     demo.launch(server_name="0.0.0.0", server_port=6006, debug=True)
+ else:
+     demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)