fffiloni committed on
Commit 5694315
1 Parent(s): 374b969

Update app.py

Files changed (1)
  1. app.py +26 -40
app.py CHANGED
@@ -14,6 +14,7 @@ from huggingface_hub import hf_hub_download
 # Ensure 'checkpoint' directory exists
 os.makedirs("checkpoints", exist_ok=True)
 
+# Download LoRA weights
 hf_hub_download(
     repo_id="wenqsun/DimensionX",
     filename="orbit_left_lora_weights.safetensors",
@@ -26,93 +27,78 @@ hf_hub_download(
     local_dir="checkpoints"
 )
 
+# Load models in the global scope
 model_id = "THUDM/CogVideoX-5b-I2V"
-transformer = CogVideoXTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.float16)
-text_encoder = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.float16)
-vae = AutoencoderKLCogVideoX.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float16)
+transformer = CogVideoXTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.float16).to("cpu")
+text_encoder = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.float16).to("cpu")
+vae = AutoencoderKLCogVideoX.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float16).to("cpu")
 tokenizer = T5Tokenizer.from_pretrained(model_id, subfolder="tokenizer")
-
+pipe = CogVideoXImageToVideoPipeline.from_pretrained(model_id, tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, vae=vae, torch_dtype=torch.float16)
 
 def find_and_move_object_to_cpu():
     for obj in gc.get_objects():
         try:
-            # Check if the object is a PyTorch model
             if isinstance(obj, torch.nn.Module):
-                # Check if any parameter of the model is on CUDA
                 if any(param.is_cuda for param in obj.parameters()):
-                    print(f"Found PyTorch model on CUDA: {type(obj).__name__}")
-                    # Move the model to CPU
                     obj.to('cpu')
-                    print(f"Moved {type(obj).__name__} to CPU.")
-
-                # Optionally check if buffers are on CUDA
                 if any(buf.is_cuda for buf in obj.buffers()):
-                    print(f"Found buffer on CUDA in {type(obj).__name__}")
                     obj.to('cpu')
-                    print(f"Moved buffers of {type(obj).__name__} to CPU.")
-
         except Exception as e:
-            # Handle any exceptions if obj is not a torch model
             pass
 
-
 def clear_gpu():
-    """Clear GPU memory by removing tensors, freeing cache, and moving data to CPU."""
-    # List memory usage before clearing
-    print(f"Memory allocated before clearing: {torch.cuda.memory_allocated() / (1024 ** 2)} MB")
-    print(f"Memory reserved before clearing: {torch.cuda.memory_reserved() / (1024 ** 2)} MB")
-
-    # Move any bound tensors back to CPU if needed
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-        torch.cuda.synchronize()  # Ensure that all operations are completed
-        print("GPU memory cleared.")
-
-    print(f"Memory allocated after clearing: {torch.cuda.memory_allocated() / (1024 ** 2)} MB")
-    print(f"Memory reserved after clearing: {torch.cuda.memory_reserved() / (1024 ** 2)} MB")
-
+    torch.cuda.empty_cache()
+    torch.cuda.synchronize()
+    gc.collect()
 
 def infer(image_path, prompt, orbit_type, progress=gr.Progress(track_tqdm=True)):
 
-    pipe = CogVideoXImageToVideoPipeline.from_pretrained(model_id, tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, vae=vae, torch_dtype=torch.float16)
+
     lora_path = "checkpoints/"
-    adapter_name = None
     if orbit_type == "Left":
        weight_name = "orbit_left_lora_weights.safetensors"
    elif orbit_type == "Up":
        weight_name = "orbit_up_lora_weights.safetensors"
    lora_rank = 256
 
+    pipe.unload_lora_weights()
+
     # Generate a timestamp for adapter_name
     adapter_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
 
+    # Load LoRA weights on CPU, move to GPU afterward
     pipe.load_lora_weights(lora_path, weight_name=weight_name, adapter_name=f"{adapter_timestamp}")
     pipe.fuse_lora(lora_scale=1 / lora_rank)
+
+    # Move the pipeline to GPU for inference
     pipe.to("cuda")
-
-
+
+    # Set the inference prompt
     prompt = f"{prompt}. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
     image = load_image(image_path)
     seed = random.randint(0, 2**8 - 1)
 
+
     video = pipe(
         image,
         prompt,
-        num_inference_steps=50, # NOT Changed
-        guidance_scale=7.0, # NOT Changed
+        num_inference_steps=50,
+        guidance_scale=7.0,
         use_dynamic_cfg=True,
         generator=torch.Generator(device="cpu").manual_seed(seed)
     )
-
-    find_and_move_object_to_cpu()
-    clear_gpu()
 
-    # Generate a timestamp for the output filename
+    # Generate and save output video
     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
     export_to_video(video.frames[0], f"output_{timestamp}.mp4", fps=8)
+
+    # Move objects to CPU and clear GPU memory immediately after inference
+    find_and_move_object_to_cpu()
+    clear_gpu()
 
     return f"output_{timestamp}.mp4"
 
+# Set up Gradio UI
 with gr.Blocks(analytics_enabled=False) as demo:
     with gr.Column(elem_id="col-container"):
         gr.Markdown("# DimensionX")