LPX55 commited on
Commit
0eafa0e
·
verified ·
1 Parent(s): 24d345b

Update app_optimized.py

Browse files
Files changed (1) hide show
  1. app_optimized.py +2 -1
app_optimized.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import gradio as gr
2
  import safetensors.torch
3
  import torchvision.transforms.v2 as transforms
@@ -62,7 +63,7 @@ with torch.no_grad(): # enable image inputs
62
  new_img_in.proj.bias.copy_(pipe.transformer.x_embedder.proj.bias)
63
  pipe.transformer.x_embedder = new_img_in
64
 
65
- lora_state_dict = safetensors.torch.load_file(lora_path, device="cpu")
66
  transformer_lora_state_dict = {f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") and "lora" in k}
67
  pipe.load_lora_into_transformer(transformer_lora_state_dict, transformer=pipe.transformer, adapter_name="i2v", _pipeline=pipe)
68
  pipe.set_adapters(["i2v"], adapter_weights=[1.0])
 
1
+ import os
2
  import gradio as gr
3
  import safetensors.torch
4
  import torchvision.transforms.v2 as transforms
 
63
  new_img_in.proj.bias.copy_(pipe.transformer.x_embedder.proj.bias)
64
  pipe.transformer.x_embedder = new_img_in
65
 
66
+ lora_state_dict = safetensors.torch.load_file(lora_path, device="cuda")
67
  transformer_lora_state_dict = {f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") and "lora" in k}
68
  pipe.load_lora_into_transformer(transformer_lora_state_dict, transformer=pipe.transformer, adapter_name="i2v", _pipeline=pipe)
69
  pipe.set_adapters(["i2v"], adapter_weights=[1.0])