Update worker_runpod.py
Browse files- worker_runpod.py +12 -5
worker_runpod.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
import os, json, requests, random, runpod
|
2 |
-
|
3 |
import torch
|
4 |
from diffusers import AutoencoderKLCogVideoX, CogVideoXImageToVideoPipeline, CogVideoXTransformer3DModel
|
5 |
from cogvideox.utils.lora_utils import merge_lora, unmerge_lora
|
@@ -8,14 +7,22 @@ from transformers import T5EncoderModel, T5Tokenizer
|
|
8 |
|
9 |
with torch.inference_mode():
|
10 |
model_id = "/content/model"
|
11 |
-
|
12 |
-
|
13 |
-
|
|
|
|
|
14 |
tokenizer = T5Tokenizer.from_pretrained(model_id, subfolder="tokenizer")
|
|
|
|
|
15 |
pipe = CogVideoXImageToVideoPipeline.from_pretrained(model_id, tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, vae=vae, torch_dtype=torch.float16).to("cuda")
|
|
|
16 |
lora_path = "/content/shirtlift.safetensors"
|
17 |
lora_weight = 1.0
|
18 |
-
|
|
|
|
|
|
|
19 |
# pipe.enable_model_cpu_offload()
|
20 |
|
21 |
def download_file(url, save_dir, file_name):
|
|
|
1 |
import os, json, requests, random, runpod
|
|
|
2 |
import torch
|
3 |
from diffusers import AutoencoderKLCogVideoX, CogVideoXImageToVideoPipeline, CogVideoXTransformer3DModel
|
4 |
from cogvideox.utils.lora_utils import merge_lora, unmerge_lora
|
|
|
7 |
|
8 |
with torch.inference_mode():
|
9 |
model_id = "/content/model"
|
10 |
+
|
11 |
+
# Load models and ensure they are placed on CUDA device
|
12 |
+
transformer = CogVideoXTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.float16).to("cuda")
|
13 |
+
text_encoder = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.float16).to("cuda")
|
14 |
+
vae = AutoencoderKLCogVideoX.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float16).to("cuda")
|
15 |
tokenizer = T5Tokenizer.from_pretrained(model_id, subfolder="tokenizer")
|
16 |
+
|
17 |
+
# Ensure the pipeline is on the same device (CUDA)
|
18 |
pipe = CogVideoXImageToVideoPipeline.from_pretrained(model_id, tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, vae=vae, torch_dtype=torch.float16).to("cuda")
|
19 |
+
|
20 |
lora_path = "/content/shirtlift.safetensors"
|
21 |
lora_weight = 1.0
|
22 |
+
|
23 |
+
# Merge Lora model and ensure it's on the same device
|
24 |
+
pipe = merge_lora(pipe, lora_path, lora_weight).to("cuda")
|
25 |
+
|
26 |
# pipe.enable_model_cpu_offload()
|
27 |
|
28 |
def download_file(url, save_dir, file_name):
|