import csv import os import tempfile import gradio as gr import requests import torch import torchvision import torchvision.transforms as T from PIL import Image from featup.util import norm from torchaudio.functional import resample from denseav.train import LitAVAligner from denseav.plotting import plot_attention_video, plot_2head_attention_video, plot_feature_video from denseav.shared import norm, crop_to_divisor, blur_dim from os.path import join if __name__ == "__main__": mode = "hf" if mode == "local": sample_videos_dir = "samples" else: os.environ['TORCH_HOME'] = '/tmp/.cache' os.environ['HF_HOME'] = '/tmp/.cache' os.environ['HF_DATASETS_CACHE'] = '/tmp/.cache' os.environ['TRANSFORMERS_CACHE'] = '/tmp/.cache' os.environ['GRADIO_EXAMPLES_CACHE'] = '/tmp/gradio_cache' sample_videos_dir = "/tmp/samples" def download_video(url, save_path): response = requests.get(url) with open(save_path, 'wb') as file: file.write(response.content) base_url = "https://marhamilresearch4.blob.core.windows.net/denseav-public/samples/" sample_videos_urls = { "puppies.mp4": base_url + "puppies.mp4", "peppers.mp4": base_url + "peppers.mp4", "boat.mp4": base_url + "boat.mp4", "elephant2.mp4": base_url + "elephant2.mp4", } # Ensure the directory for sample videos exists os.makedirs(sample_videos_dir, exist_ok=True) # Download each sample video for filename, url in sample_videos_urls.items(): save_path = os.path.join(sample_videos_dir, filename) # Download the video if it doesn't already exist if not os.path.exists(save_path): print(f"Downloading {filename}...") download_video(url, save_path) else: print(f"{filename} already exists. Skipping download.") csv.field_size_limit(100000000) options = ['language', "sound-language", "sound"] load_size = 224 plot_size = 224 video_input = gr.Video(label="Choose a video to featurize", height=480) model_option = gr.Radio(options, value="language", label='Choose a model') video_output1 = gr.Video(label="Audio Video Attention", height=480) video_output2 = gr.Video(label="Multi-Head Audio Video Attention (Only Availible for sound_and_language)", height=480) video_output3 = gr.Video(label="Visual Features", height=480) models = {o: LitAVAligner.from_pretrained(f"mhamilton723/DenseAV-{o}") for o in options} def process_video(video, model_option): model = models[model_option].cuda() original_frames, audio, info = torchvision.io.read_video(video, end_pts=10, pts_unit='sec') sample_rate = 16000 if info["audio_fps"] != sample_rate: audio = resample(audio, info["audio_fps"], sample_rate) audio = audio[0].unsqueeze(0) img_transform = T.Compose([ T.Resize(load_size, Image.BILINEAR), lambda x: crop_to_divisor(x, 8), lambda x: x.to(torch.float32) / 255, norm]) frames = torch.cat([img_transform(f.permute(2, 0, 1)).unsqueeze(0) for f in original_frames], axis=0) plotting_img_transform = T.Compose([ T.Resize(plot_size, Image.BILINEAR), lambda x: crop_to_divisor(x, 8), lambda x: x.to(torch.float32) / 255]) frames_to_plot = plotting_img_transform(original_frames.permute(0, 3, 1, 2)) with torch.no_grad(): audio_feats = model.forward_audio({"audio": audio.cuda()}) audio_feats = {k: v.cpu() for k, v in audio_feats.items()} image_feats = model.forward_image({"frames": frames.unsqueeze(0).cuda()}, max_batch_size=2) image_feats = {k: v.cpu() for k, v in image_feats.items()} sim_by_head = model.sim_agg.get_pairwise_sims( {**image_feats, **audio_feats}, raw=False, agg_sim=False, agg_heads=False ).mean(dim=-2).cpu() sim_by_head = blur_dim(sim_by_head, window=3, dim=-1) print(sim_by_head.shape) temp_video_path_1 = tempfile.mktemp(suffix='.mp4') plot_attention_video( sim_by_head, frames_to_plot, audio, info["video_fps"], sample_rate, temp_video_path_1) if model_option == "sound_and_language": temp_video_path_2 = tempfile.mktemp(suffix='.mp4') plot_2head_attention_video( sim_by_head, frames_to_plot, audio, info["video_fps"], sample_rate, temp_video_path_2) else: temp_video_path_2 = None temp_video_path_3 = tempfile.mktemp(suffix='.mp4') temp_video_path_4 = tempfile.mktemp(suffix='.mp4') plot_feature_video( image_feats["image_feats"].cpu(), audio_feats['audio_feats'].cpu(), frames_to_plot, audio, info["video_fps"], sample_rate, temp_video_path_3, temp_video_path_4, ) # return temp_video_path_1, temp_video_path_2, temp_video_path_3, temp_video_path_4 return temp_video_path_1, temp_video_path_2, temp_video_path_3 with gr.Blocks() as demo: with gr.Column(): gr.Markdown("## Visualizing Sound and Language with DenseAV") gr.Markdown( "This demo allows you to explore the inner attention maps of DenseAV's dense multi-head contrastive operator.") with gr.Row(): with gr.Column(scale=1): model_option.render() with gr.Column(scale=3): video_input.render() with gr.Row(): submit_button = gr.Button("Submit") with gr.Row(): gr.Examples( examples=[ [join(sample_videos_dir, "puppies.mp4"), "sound_and_language"], [join(sample_videos_dir, "peppers.mp4"), "language"], [join(sample_videos_dir, "elephant2.mp4"), "language"], [join(sample_videos_dir, "boat.mp4"), "language"] ], inputs=[video_input, model_option] ) with gr.Row(): video_output1.render() video_output2.render() video_output3.render() submit_button.click(fn=process_video, inputs=[video_input, model_option], outputs=[video_output1, video_output2, video_output3]) if mode == "local": demo.launch(server_name="0.0.0.0", server_port=6006, debug=True) else: demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)