mhamilton723 commited on
Commit
bb26736
·
verified ·
1 Parent(s): 1b98c12

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -6
app.py CHANGED
@@ -11,21 +11,19 @@ from PIL import Image
11
  from featup.util import norm
12
  from torchaudio.functional import resample
13
 
 
14
  from denseav.plotting import plot_attention_video, plot_2head_attention_video, plot_feature_video
15
  from denseav.shared import norm, crop_to_divisor, blur_dim
16
  from os.path import join
17
 
18
  if __name__ == "__main__":
19
 
20
- mode = "hf2"
21
 
22
  if mode == "local":
23
  sample_videos_dir = "samples"
24
  else:
25
  os.environ['TORCH_HOME'] = '/tmp/.cache'
26
- os.environ['HF_HOME'] = '/tmp/.cache'
27
- os.environ['HF_DATASETS_CACHE'] = '/tmp/.cache'
28
- os.environ['TRANSFORMERS_CACHE'] = '/tmp/.cache'
29
  os.environ['GRADIO_EXAMPLES_CACHE'] = '/tmp/gradio_cache'
30
  sample_videos_dir = "/tmp/samples"
31
 
@@ -59,7 +57,7 @@ if __name__ == "__main__":
59
  print(f"{filename} already exists. Skipping download.")
60
 
61
  csv.field_size_limit(100000000)
62
- options = ['language', "sound_and_language", "sound"]
63
  load_size = 224
64
  plot_size = 224
65
 
@@ -71,7 +69,7 @@ if __name__ == "__main__":
71
  height=480)
72
  video_output3 = gr.Video(label="Visual Features", height=480)
73
 
74
- models = {o: torch.hub.load("mhamilton723/DenseAV", o) for o in options}
75
 
76
 
77
  def process_video(video, model_option):
 
11
  from featup.util import norm
12
  from torchaudio.functional import resample
13
 
14
+ from denseav.train import LitAVAligner
15
  from denseav.plotting import plot_attention_video, plot_2head_attention_video, plot_feature_video
16
  from denseav.shared import norm, crop_to_divisor, blur_dim
17
  from os.path import join
18
 
19
  if __name__ == "__main__":
20
 
21
+ mode = "hf"
22
 
23
  if mode == "local":
24
  sample_videos_dir = "samples"
25
  else:
26
  os.environ['TORCH_HOME'] = '/tmp/.cache'
 
 
 
27
  os.environ['GRADIO_EXAMPLES_CACHE'] = '/tmp/gradio_cache'
28
  sample_videos_dir = "/tmp/samples"
29
 
 
57
  print(f"{filename} already exists. Skipping download.")
58
 
59
  csv.field_size_limit(100000000)
60
+ options = ['language', "sound-language", "sound"]
61
  load_size = 224
62
  plot_size = 224
63
 
 
69
  height=480)
70
  video_output3 = gr.Video(label="Visual Features", height=480)
71
 
72
+ models = {o: LitAVAligner.from_pretrained(f"mhamilton723/DenseAV-{o}") for o in options}
73
 
74
 
75
  def process_video(video, model_option):