Fly-ShuAI committed on
Commit
7b39fe7
·
verified ·
1 Parent(s): c0231fa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -52
app.py CHANGED
@@ -14,25 +14,26 @@ from diffsynth import ModelManager, WanVideoPipeline, save_video
14
 
15
 
16
  num_frames, width, height = 49, 832, 480
17
- gpu_id = 0
18
- device = f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu'
 
19
  # pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124
20
 
21
- # from modelscope import snapshot_download
22
- # model_dir = snapshot_download( # https://www.modelscope.cn/models/AI-ModelScope/RMBG-2.0
23
- # model_id = 'AI-ModelScope/RMBG-2.0',
24
- # local_dir = 'ckpt/RMBG-2.0',
25
- # ignore_file_pattern = ['onnx*'],
26
- # )
27
 
28
- # from huggingface_hub import snapshot_download, hf_hub_download
29
- # snapshot_download( # 下载整个仓库; 下briaai/RMBG-2.0需要token
30
- # repo_id="alibaba-pai/Wan2.1-Fun-1.3B-Control",
31
- # local_dir="ckpt/Wan2.1-Fun-1.3B-Control",
32
- # local_dir_use_symlinks=False,
33
- # resume_download=True,
34
- # repo_type="model"
35
- # )
36
 
37
  # hf_hub_download(
38
  # repo_id="Kunbyte/Lumen",
@@ -42,37 +43,37 @@ device = f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu'
42
  # resume_download=True,
43
  # )
44
 
45
- # rmbg_model = AutoModelForImageSegmentation.from_pretrained('ckpt/RMBG-2.0', trust_remote_code=True) # ckpt/RMBG-2.0
46
- # torch.set_float32_matmul_precision(['high', 'highest'][0])
47
- # rmbg_model.to(device)
48
- # rmbg_model.eval()
49
 
50
- # model_manager = ModelManager(device="cpu") # 1.3b: device=cpu: uses 6G VRAM, device=device: uses 16G VRAM; about 1-2 min per video
51
- # wan_dit_path = 'train_res/wan1.3b_zh/full_wc0.5_f1gt0.5_real1_2_zh_en_l_s/lightning_logs/version_0/checkpoints/step-step=30000.ckpt'
52
 
53
- # if 'wan14b' in wan_dit_path.lower(): # 14B: uses about 36G, about 10 min per video
54
- # model_manager.load_models(
55
- # [
56
- # wan_dit_path if wan_dit_path else 'ckpt/Wan2.1-Fun-14B-Control/diffusion_pytorch_model.safetensors',
57
- # 'ckpt/Wan2.1-Fun-1.3B-Control/Wan2.1_VAE.pth',
58
- # 'ckpt/Wan2.1-Fun-1.3B-Control/models_t5_umt5-xxl-enc-bf16.pth',
59
- # 'ckpt/Wan2.1-Fun-1.3B-Control/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth',
60
- # ],
61
- # torch_dtype=torch.bfloat16, # float8_e4m3fn fp8量化; bfloat16
62
- # )
63
- # else:
64
- # wan_dit_path = None
65
- # model_manager.load_models(
66
- # [
67
- # wan_dit_path if wan_dit_path else 'ckpt/Wan2.1-Fun-1.3B-Control/diffusion_pytorch_model.safetensors',
68
- # 'ckpt/Wan2.1-Fun-1.3B-Control/Wan2.1_VAE.pth',
69
- # 'ckpt/Wan2.1-Fun-1.3B-Control/models_t5_umt5-xxl-enc-bf16.pth',
70
- # 'ckpt/Wan2.1-Fun-1.3B-Control/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth',
71
- # ],
72
- # torch_dtype=torch.bfloat16,
73
- # )
74
- # wan_pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device=device)
75
- # wan_pipe.enable_vram_management(num_persistent_param_in_dit=None)
76
 
77
  gr_info_duration = 2 # gradio popup information duration
78
 
@@ -196,7 +197,7 @@ video_dir = 'test/pachong_test/video/single'
196
  relight_dir = ''
197
 
198
  header = """
199
- # 💡Lumen: Consistent Video Relighting and Harmonious Background Replacement\n # <center>with Video Generative Models </center>
200
 
201
  <div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
202
  <a href="https://lumen-relight.github.io"><img src="https://img.shields.io/badge/Project%20Page-Lumen-blue" alt="Project"></a>
@@ -324,9 +325,4 @@ with gr.Blocks(title="Lumen: Video Relighting Model").queue() as demo:
324
 
325
  # Launch application
326
  if __name__ == "__main__":
327
- demo.launch()
328
- # demo.launch(
329
- # server_name='0.0.0.0',
330
- # debug=True,
331
- # ssr_mode=False,
332
- # )
 
14
 
15
 
16
  num_frames, width, height = 49, 832, 480
17
+ # gpu_id = 3
18
+ # device = f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu'
19
+ device = f'cuda' if torch.cuda.is_available() else 'cpu'
20
  # pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124
21
 
22
+ from modelscope import snapshot_download
23
+ model_dir = snapshot_download( # https://www.modelscope.cn/models/AI-ModelScope/RMBG-2.0
24
+ model_id = 'AI-ModelScope/RMBG-2.0',
25
+ local_dir = 'ckpt/RMBG-2.0',
26
+ ignore_file_pattern = ['onnx*'],
27
+ )
28
 
29
+ from huggingface_hub import snapshot_download, hf_hub_download
30
+ snapshot_download( # 下载整个仓库; 下briaai/RMBG-2.0需要token
31
+ repo_id="alibaba-pai/Wan2.1-Fun-1.3B-Control",
32
+ local_dir="ckpt/Wan2.1-Fun-1.3B-Control",
33
+ local_dir_use_symlinks=False,
34
+ resume_download=True,
35
+ repo_type="model"
36
+ )
37
 
38
  # hf_hub_download(
39
  # repo_id="Kunbyte/Lumen",
 
43
  # resume_download=True,
44
  # )
45
 
46
+ rmbg_model = AutoModelForImageSegmentation.from_pretrained('ckpt/RMBG-2.0', trust_remote_code=True) # ckpt/RMBG-2.0
47
+ torch.set_float32_matmul_precision(['high', 'highest'][0])
48
+ rmbg_model.to(device)
49
+ rmbg_model.eval()
50
 
51
+ model_manager = ModelManager(device="cpu") # 1.3b: device=cpu: uses 6G VRAM, device=device: uses 16G VRAM; about 1-2 min per video
52
+ wan_dit_path = 'train_res/wan1.3b_zh/full_wc0.5_f1gt0.5_real1_2_zh_en_l_s/lightning_logs/version_0/checkpoints/step-step=30000.ckpt'
53
 
54
+ if 'wan14b' in wan_dit_path.lower(): # 14B: uses about 36G, about 10 min per video
55
+ model_manager.load_models(
56
+ [
57
+ wan_dit_path if wan_dit_path else 'ckpt/Wan2.1-Fun-14B-Control/diffusion_pytorch_model.safetensors',
58
+ 'ckpt/Wan2.1-Fun-1.3B-Control/Wan2.1_VAE.pth',
59
+ 'ckpt/Wan2.1-Fun-1.3B-Control/models_t5_umt5-xxl-enc-bf16.pth',
60
+ 'ckpt/Wan2.1-Fun-1.3B-Control/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth',
61
+ ],
62
+ torch_dtype=torch.bfloat16, # float8_e4m3fn fp8量化; bfloat16
63
+ )
64
+ else:
65
+ wan_dit_path = None
66
+ model_manager.load_models(
67
+ [
68
+ wan_dit_path if wan_dit_path else 'ckpt/Wan2.1-Fun-1.3B-Control/diffusion_pytorch_model.safetensors',
69
+ 'ckpt/Wan2.1-Fun-1.3B-Control/Wan2.1_VAE.pth',
70
+ 'ckpt/Wan2.1-Fun-1.3B-Control/models_t5_umt5-xxl-enc-bf16.pth',
71
+ 'ckpt/Wan2.1-Fun-1.3B-Control/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth',
72
+ ],
73
+ torch_dtype=torch.bfloat16,
74
+ )
75
+ wan_pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device=device)
76
+ wan_pipe.enable_vram_management(num_persistent_param_in_dit=None)
77
 
78
  gr_info_duration = 2 # gradio popup information duration
79
 
 
197
  relight_dir = ''
198
 
199
  header = """
200
+ # <center>💡Lumen: Consistent Video Relighting and Harmonious Background Replacement with Video Generative Models </center>
201
 
202
  <div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
203
  <a href="https://lumen-relight.github.io"><img src="https://img.shields.io/badge/Project%20Page-Lumen-blue" alt="Project"></a>
 
325
 
326
  # Launch application
327
  if __name__ == "__main__":
328
+ demo.launch() # max_threads