fffiloni commited on
Commit
da0d3cc
·
verified ·
1 Parent(s): 09db832

change caching method and logs subprocess infos

Browse files
Files changed (1) hide show
  1. app.py +48 -15
app.py CHANGED
@@ -9,25 +9,38 @@ from huggingface_hub import snapshot_download
9
 
10
  # Download All Required Models using `snapshot_download`
11
 
12
- def download_and_extract(repo_id, target_dir):
13
  """
14
- Downloads a model repo (cached) and copies its contents to a local target directory.
15
- If the target_dir exists, it will be updated (not re-downloaded if cache is present).
16
  """
17
- print(f"Downloading {repo_id} into cache...")
18
- snapshot_path = snapshot_download(repo_id)
19
-
20
- print(f"Copying files to {target_dir}...")
21
- os.makedirs(target_dir, exist_ok=True)
22
- shutil.copytree(snapshot_path, target_dir, dirs_exist_ok=True)
23
-
24
- print(f"Done: {repo_id} extracted to {target_dir}")
 
 
 
 
 
 
 
 
 
 
25
  return target_dir
26
 
27
 
28
- wan_model_path = download_and_extract("Wan-AI/Wan2.1-I2V-14B-480P", "./weights/Wan2.1-I2V-14B-480P")
29
- wav2vec_path = download_and_extract("TencentGameMate/chinese-wav2vec2-base", "./weights/chinese-wav2vec2-base")
30
- multitalk_path = download_and_extract("MeiGen-AI/MeiGen-MultiTalk", "./weights/MeiGen-MultiTalk")
 
 
 
 
31
 
32
 
33
  # Define paths
@@ -130,7 +143,27 @@ def infer(prompt, cond_image_path, cond_audio_path):
130
  "--save_file", "multi_long_mediumvram_exp"
131
  ]
132
 
133
- subprocess.run(cmd, check=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
  return "multi_long_mediumvra_exp.mp4"
136
 
 
9
 
10
  # Download All Required Models using `snapshot_download`
11
 
12
+ def download_and_extract(repo_id, target_dir, cache_dir=None):
13
  """
14
+ Download from HF Hub to cache, then copy to target_dir if not already present.
 
15
  """
16
+ print(f"Downloading (with cache) {repo_id}...")
17
+
18
+ # Use snapshot_download with optional custom cache
19
+ snapshot_path = snapshot_download(
20
+ repo_id=repo_id,
21
+ cache_dir=cache_dir, # You can pass a shared cache path here
22
+ local_dir=None, # Ensure it's using the actual cache
23
+ local_dir_use_symlinks=False
24
+ )
25
+
26
+ # Copy from cache to target directory
27
+ if not os.path.exists(target_dir) or not os.listdir(target_dir):
28
+ os.makedirs(target_dir, exist_ok=True)
29
+ shutil.copytree(snapshot_path, target_dir, dirs_exist_ok=True)
30
+ print(f"Copied {repo_id} to {target_dir}")
31
+ else:
32
+ print(f"{target_dir} already populated. Skipping copy.")
33
+
34
  return target_dir
35
 
36
 
37
+ # Optional: share one cache location across all models
38
+ custom_cache = "./hf_cache"
39
+
40
+ wan_model_path = download_and_extract("Wan-AI/Wan2.1-I2V-14B-480P", "./weights/Wan2.1-I2V-14B-480P", cache_dir=custom_cache)
41
+ wav2vec_path = download_and_extract("TencentGameMate/chinese-wav2vec2-base", "./weights/chinese-wav2vec2-base", cache_dir=custom_cache)
42
+ multitalk_path = download_and_extract("MeiGen-AI/MeiGen-MultiTalk", "./weights/MeiGen-MultiTalk", cache_dir=custom_cache)
43
+
44
 
45
 
46
  # Define paths
 
143
  "--save_file", "multi_long_mediumvram_exp"
144
  ]
145
 
146
+ # Optional: log file
147
+ log_file_path = "inference.log"
148
+
149
+ # Run and stream logs in real-time
150
+ with open(log_file_path, "w") as log_file:
151
+ process = subprocess.Popen(
152
+ cmd,
153
+ stdout=subprocess.PIPE,
154
+ stderr=subprocess.STDOUT,
155
+ text=True,
156
+ bufsize=1 # Line-buffered
157
+ )
158
+
159
+ for line in process.stdout:
160
+ print(line, end="") # Print to console in real-time
161
+ log_file.write(line) # Save to log file
162
+
163
+ process.wait()
164
+
165
+ if process.returncode != 0:
166
+ raise RuntimeError("Inference failed. Check inference.log for details.")
167
 
168
  return "multi_long_mediumvra_exp.mp4"
169