caohy666 committed
Commit 8e7659c · Parent: 8748eeb

<fix> move transformer init to process_image_and_text.

Files changed (1)
  1. app.py +7 -5
app.py CHANGED
@@ -52,9 +52,6 @@ def init_basemodel():
     global transformer, scheduler, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2, image_processor
 
     # init models
-    transformer = HunyuanVideoTransformer3DModel.from_pretrained('hunyuanvideo-community/HunyuanVideo-I2V',
-                                                                  subfolder="transformer",
-                                                                  inference_subject_driven=task in ['subject_driven'])
     scheduler = diffusers.FlowMatchEulerDiscreteScheduler()
     vae = diffusers.AutoencoderKLHunyuanVideo.from_pretrained('hunyuanvideo-community/HunyuanVideo-I2V',
                                                               subfolder="vae")
@@ -72,11 +69,9 @@ def init_basemodel():
     device = "cuda" if torch.cuda.is_available() else "cpu"
     weight_dtype = torch.bfloat16
 
-    transformer.requires_grad_(False)
     vae.requires_grad_(False).to(device, dtype=weight_dtype)
     text_encoder.requires_grad_(False).to(device, dtype=weight_dtype)
     text_encoder_2.requires_grad_(False).to(device, dtype=weight_dtype)
-    transformer.to(device, dtype=weight_dtype)
     vae.enable_tiling()
     vae.enable_slicing()
 
@@ -85,6 +80,13 @@ def init_basemodel():
 def process_image_and_text(condition_image, target_prompt, condition_image_prompt, task, random_seed, inpainting, fill_x1, fill_x2, fill_y1, fill_y2):
     # set up the model
     if pipe is None or current_task != task:
+        # init transformer
+        transformer = HunyuanVideoTransformer3DModel.from_pretrained('hunyuanvideo-community/HunyuanVideo-I2V',
+                                                                      subfolder="transformer",
+                                                                      inference_subject_driven=task in ['subject_driven'])
+        transformer.requires_grad_(False)
+        transformer.to("cuda" if torch.cuda.is_available() else "cpu", dtype=torch.bfloat16)
+
         # insert LoRA
         lora_config = LoraConfig(
             r=16,
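
For readers skimming the diff: the change applies a lazy-initialization pattern, loading the transformer the first time process_image_and_text runs (or when the requested task changes) instead of eagerly in init_basemodel. Below is a minimal sketch of that pattern, not the Space's exact code: the cache variables and the get_transformer helper are illustrative, and the import assumes the diffusers HunyuanVideoTransformer3DModel, whereas app.py appears to use a variant that accepts an inference_subject_driven argument (stock diffusers does not expose it).

# Sketch of the lazy-init pattern this commit applies (illustrative names,
# not the Space's actual globals).
import torch
from diffusers import HunyuanVideoTransformer3DModel

_transformer = None   # cached model instance
_loaded_task = None   # task the cached instance was built for

def get_transformer(task):
    """Return the transformer, loading it on first use or when the task changes."""
    global _transformer, _loaded_task
    if _transformer is None or _loaded_task != task:
        _transformer = HunyuanVideoTransformer3DModel.from_pretrained(
            'hunyuanvideo-community/HunyuanVideo-I2V',
            subfolder="transformer",
            # custom flag used by this Space's model variant
            inference_subject_driven=task in ['subject_driven'],
        )
        _transformer.requires_grad_(False)
        _transformer.to("cuda" if torch.cuda.is_available() else "cpu",
                        dtype=torch.bfloat16)
        _loaded_task = task
    return _transformer

In app.py itself the same logic lives inline inside the `if pipe is None or current_task != task:` branch, so the Space starts up without loading the transformer and only pays the load cost once a task is actually requested.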