mhamilton723 committed
Commit c5d5ef0 · verified · 1 Parent(s): 523ffdf

Update app.py

Files changed (1):
  1. app.py +141 -98
app.py CHANGED
@@ -1,122 +1,165 @@
- import matplotlib.pyplot as plt
  import torch
  import torchvision.transforms as T
  from PIL import Image
- import gradio as gr
- from featup.util import norm, unnorm, pca, remove_axes
- from pytorch_lightning import seed_everything
- import os
- import requests
- import os
- import csv
-
- def plot_feats(image, lr, hr):
-     assert len(image.shape) == len(lr.shape) == len(hr.shape) == 3
-     seed_everything(0)
-     [lr_feats_pca, hr_feats_pca], _ = pca([lr.unsqueeze(0), hr.unsqueeze(0)], dim=9)
-     fig, ax = plt.subplots(3, 3, figsize=(15, 15))
-     ax[0, 0].imshow(image.permute(1, 2, 0).detach().cpu())
-     ax[1, 0].imshow(image.permute(1, 2, 0).detach().cpu())
-     ax[2, 0].imshow(image.permute(1, 2, 0).detach().cpu())
-
-     ax[0, 0].set_title("Image", fontsize=22)
-     ax[0, 1].set_title("Original", fontsize=22)
-     ax[0, 2].set_title("Upsampled Features", fontsize=22)
-
-     ax[0, 1].imshow(lr_feats_pca[0, :3].permute(1, 2, 0).detach().cpu())
-     ax[0, 0].set_ylabel("PCA Components 1-3", fontsize=22)
-     ax[0, 2].imshow(hr_feats_pca[0, :3].permute(1, 2, 0).detach().cpu())
-
-     ax[1, 1].imshow(lr_feats_pca[0, 3:6].permute(1, 2, 0).detach().cpu())
-     ax[1, 0].set_ylabel("PCA Components 4-6", fontsize=22)
-     ax[1, 2].imshow(hr_feats_pca[0, 3:6].permute(1, 2, 0).detach().cpu())
-
-     ax[2, 1].imshow(lr_feats_pca[0, 6:9].permute(1, 2, 0).detach().cpu())
-     ax[2, 0].set_ylabel("PCA Components 7-9", fontsize=22)
-     ax[2, 2].imshow(hr_feats_pca[0, 6:9].permute(1, 2, 0).detach().cpu())
-
-     remove_axes(ax)
-     plt.tight_layout()
-     plt.close(fig)  # Close plt to avoid additional empty plots
-     return fig
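Note: the joint PCA call above is what makes the low- and high-resolution maps visually comparable; featup.util.pca appears to fit one 9-component basis across both inputs, and the 3x3 grid renders the components as three RGB triplets (1-3, 4-6, 7-9). The function expects unnormalized CHW tensors with the batch dimension stripped, as in the call from upsample_features further down:

    fig = plot_feats(unnorm(image_tensor)[0], lr_feats[0], hr_feats[0])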
-
-
- if __name__ == "__main__":
-
-     def download_image(url, save_path):
          response = requests.get(url)
          with open(save_path, 'wb') as file:
              file.write(response.content)
-
-     base_url = "https://marhamilresearch4.blob.core.windows.net/feature-upsampling-public/sample_images/"
-     sample_images_urls = {
-         "skate.jpg": base_url + "skate.jpg",
-         "car.jpg": base_url + "car.jpg",
-         "plant.png": base_url + "plant.png",
-     }
-
-     sample_images_dir = "/tmp/sample_images"
-
-     # Ensure the directory for sample images exists
-     os.makedirs(sample_images_dir, exist_ok=True)
-
-     # Download each sample image
-     for filename, url in sample_images_urls.items():
-         save_path = os.path.join(sample_images_dir, filename)
-         # Download the image if it doesn't already exist
          if not os.path.exists(save_path):
              print(f"Downloading {filename}...")
-             download_image(url, save_path)
          else:
              print(f"{filename} already exists. Skipping download.")
-
-     os.environ['TORCH_HOME'] = '/tmp/.cache'
-     os.environ['GRADIO_EXAMPLES_CACHE'] = '/tmp/gradio_cache'
      csv.field_size_limit(100000000)
-     options = ['dino16', 'vit', 'dinov2', 'clip', 'resnet50']
-
-     image_input = gr.Image(label="Choose an image to featurize",
-                            height=480,
-                            type="pil",
-                            image_mode='RGB',
-                            sources=['upload', 'webcam', 'clipboard']
-                            )
-     model_option = gr.Radio(options, value="dino16", label='Choose a backbone to upsample')
-
-     models = {o: torch.hub.load("mhamilton723/FeatUp", o) for o in options}
-
-
-     def upsample_features(image, model_option):
-         # Image preprocessing
-         input_size = 224
-         transform = T.Compose([
-             T.Resize(input_size),
-             T.CenterCrop((input_size, input_size)),
-             T.ToTensor(),
-             norm
-         ])
-         image_tensor = transform(image).unsqueeze(0).cuda()
-
-         # Load the selected model
-         upsampler = models[model_option].cuda()
-         hr_feats = upsampler(image_tensor)
-         lr_feats = upsampler.model(image_tensor)
-         upsampler.cpu()
-
-         return plot_feats(unnorm(image_tensor)[0], lr_feats[0], hr_feats[0])
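Note: both feature maps come from a single hub module here: the FeatUp upsampler is callable for high-resolution features, and its .model attribute is the wrapped backbone that produces the original low-resolution ones. A condensed sketch of the flow above, using names from this file:

    upsampler = models["dino16"].cuda()
    hr_feats = upsampler(image_tensor)        # upsampled feature map
    lr_feats = upsampler.model(image_tensor)  # backbone's native low-res features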
-
-
-     demo = gr.Interface(fn=upsample_features,
-                         inputs=[image_input, model_option],
-                         outputs="plot",
-                         title="Feature Upsampling Demo",
-                         description="This demo allows you to upsample features of an image using selected models.",
-                         examples=[
-                             ["/tmp/sample_images/skate.jpg", "dino16"],
-                             ["/tmp/sample_images/car.jpg", "dinov2"],
-                             ["/tmp/sample_images/plant.png", "dino16"],
-                         ]
-                         )
-
      demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)
+ import csv
+ import os
+ import tempfile
+
+ import gradio as gr
+ import requests
  import torch
+ import torchvision
  import torchvision.transforms as T
  from PIL import Image
+ from torchaudio.functional import resample
+
+ from denseav.plotting import plot_attention_video, plot_2head_attention_video, plot_feature_video
+ from denseav.shared import norm, crop_to_divisor, blur_dim
+ from os.path import join

+ if __name__ == "__main__":
+
+     os.environ['TORCH_HOME'] = '/tmp/.cache'
+     os.environ['GRADIO_EXAMPLES_CACHE'] = '/tmp/gradio_cache'
+     sample_videos_dir = "/tmp/samples"
+
+     def download_video(url, save_path):
          response = requests.get(url)
          with open(save_path, 'wb') as file:
              file.write(response.content)
+
+     base_url = "https://marhamilresearch4.blob.core.windows.net/denseav-public/samples/"
+     sample_videos_urls = {
+         "puppies.mp4": base_url + "puppies.mp4",
+     }
+
+     # Ensure the directory for sample videos exists
+     os.makedirs(sample_videos_dir, exist_ok=True)
+
+     # Download each sample video
+     for filename, url in sample_videos_urls.items():
+         save_path = os.path.join(sample_videos_dir, filename)
+         # Download the video if it doesn't already exist
          if not os.path.exists(save_path):
              print(f"Downloading {filename}...")
+             download_video(url, save_path)
          else:
              print(f"{filename} already exists. Skipping download.")

      csv.field_size_limit(100000000)
+     options = ['language', "sound", "sound_and_language"]
+     load_size = 224
+     plot_size = 224
+
+     video_input = gr.Video(label="Choose a video to featurize", height=480)
+     model_option = gr.Radio(options, value="language", label='Choose a model')
+
+     video_output1 = gr.Video(label="Audio Video Attention", height=480)
+     video_output2 = gr.Video(label="Multi-Head Audio Video Attention (Only Available for sound_and_language)",
+                              height=480)
+     video_output3 = gr.Video(label="Visual Features", height=480)
+     video_output4 = gr.Video(label="Audio Features", height=480)
+
+     models = {o: torch.hub.load("mhamilton723/DenseAV", o) for o in options}
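Note: this dict comprehension eagerly downloads and instantiates all three DenseAV checkpoints at startup; torch.hub caches them under TORCH_HOME, which the script points at /tmp/.cache above. If memory is a concern, a single variant can be loaded on demand the same way:

    model = torch.hub.load("mhamilton723/DenseAV", "language")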
+
+
+     def process_video(video, model_option):
+         model = models[model_option].cuda()
+
+         original_frames, audio, info = torchvision.io.read_video(video, end_pts=10, pts_unit='sec')
+         sample_rate = 16000
+
+         if info["audio_fps"] != sample_rate:
+             audio = resample(audio, info["audio_fps"], sample_rate)
+         audio = audio[0].unsqueeze(0)
+
+         img_transform = T.Compose([
+             T.Resize(load_size, Image.BILINEAR),
+             lambda x: crop_to_divisor(x, 8),
+             lambda x: x.to(torch.float32) / 255,
+             norm])
+
+         frames = torch.cat([img_transform(f.permute(2, 0, 1)).unsqueeze(0) for f in original_frames], dim=0)
+
+         plotting_img_transform = T.Compose([
+             T.Resize(plot_size, Image.BILINEAR),
+             lambda x: crop_to_divisor(x, 8),
+             lambda x: x.to(torch.float32) / 255])
+
+         frames_to_plot = plotting_img_transform(original_frames.permute(0, 3, 1, 2))
+
+         with torch.no_grad():
+             audio_feats = model.forward_audio({"audio": audio.cuda()})
+             audio_feats = {k: v.cpu() for k, v in audio_feats.items()}
+             image_feats = model.forward_image({"frames": frames.unsqueeze(0).cuda()}, max_batch_size=2)
+             image_feats = {k: v.cpu() for k, v in image_feats.items()}
+
+             sim_by_head = model.sim_agg.get_pairwise_sims(
+                 {**image_feats, **audio_feats},
+                 raw=False,
+                 agg_sim=False,
+                 agg_heads=False
+             ).mean(dim=-2).cpu()
+
+         sim_by_head = blur_dim(sim_by_head, window=3, dim=-1)
+         print(sim_by_head.shape)
+
+         temp_video_path_1 = tempfile.mktemp(suffix='.mp4')
+
+         plot_attention_video(
+             sim_by_head,
+             frames_to_plot,
+             audio,
+             info["video_fps"],
+             sample_rate,
+             temp_video_path_1)
+
+         if model_option == "sound_and_language":
+             temp_video_path_2 = tempfile.mktemp(suffix='.mp4')
+             plot_2head_attention_video(
+                 sim_by_head,
+                 frames_to_plot,
+                 audio,
+                 info["video_fps"],
+                 sample_rate,
+                 temp_video_path_2)
+         else:
+             temp_video_path_2 = None
+
+         temp_video_path_3 = tempfile.mktemp(suffix='.mp4')
+         temp_video_path_4 = tempfile.mktemp(suffix='.mp4')
+
+         plot_feature_video(
+             image_feats["image_feats"].cpu(),
+             audio_feats['audio_feats'].cpu(),
+             frames_to_plot,
+             audio,
+             info["video_fps"],
+             sample_rate,
+             temp_video_path_3,
+             temp_video_path_4,
+         )
+         return temp_video_path_1, temp_video_path_2, temp_video_path_3, temp_video_path_4
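Note on process_video: torchvision.io.read_video returns uint8 frames shaped (T, H, W, C) plus the raw audio track, so the function resamples the audio to 16 kHz and keeps one channel, normalizes CHW frames for the backbone, and keeps a separate unnormalized copy for plotting. get_pairwise_sims(..., agg_heads=False) preserves one audio-visual similarity volume per attention head (which is what plot_2head_attention_video visualizes), and blur_dim then smooths those scores along their last dimension before rendering. A minimal way to exercise the function outside the UI, using the sample downloaded above:

    paths = process_video(join(sample_videos_dir, "puppies.mp4"), "sound_and_language")
    attention_mp4, two_head_mp4, image_feats_mp4, audio_feats_mp4 = paths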
+
+
+     with gr.Blocks() as demo:
+         with gr.Column():
+             video_input.render()
+             model_option.render()
+         with gr.Row():
+             video_output1.render()
+             video_output2.render()
+         with gr.Row():
+             video_output3.render()
+             video_output4.render()
+
+     demo.examples = [
+         [join(sample_videos_dir, "puppies.mp4"), "language"],
+     ]
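Note: as written, nothing in the layout invokes process_video, and assigning demo.examples is not Gradio's examples mechanism for Blocks (gr.Interface takes an examples= argument; Blocks uses gr.Examples). A minimal sketch of the conventional wiring, placed inside the with gr.Blocks() as demo: context:

    go_button = gr.Button("Featurize")
    go_button.click(process_video,
                    inputs=[video_input, model_option],
                    outputs=[video_output1, video_output2, video_output3, video_output4])
    gr.Examples(examples=[[join(sample_videos_dir, "puppies.mp4"), "language"]],
                inputs=[video_input, model_option])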
+
+     # demo.launch(server_name="0.0.0.0", server_port=6006, debug=True)
      demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)