Spaces:
Saad0KH
/
Running on Zero

Saad0KH commited on
Commit
c9a29b0
·
verified ·
1 Parent(s): 17d2d66

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -2
app.py CHANGED
@@ -32,7 +32,7 @@ models_loaded = False
32
  def load_models():
33
  global unet, tokenizer_one, tokenizer_two, noise_scheduler, text_encoder_one, text_encoder_two
34
  global image_encoder, vae, UNet_Encoder, parsing_model, openpose_model, pipe
35
- global models_loaded # Déclarer la variable globale ici
36
 
37
  if not models_loaded:
38
  base_path = 'yisol/IDM-VTON'
@@ -47,7 +47,15 @@ def load_models():
47
  text_encoder_two = CLIPTextModelWithProjection.from_pretrained(base_path, subfolder="text_encoder_2", torch_dtype=torch.float16, force_download=False)
48
  image_encoder = CLIPVisionModelWithProjection.from_pretrained(base_path, subfolder="image_encoder", torch_dtype=torch.float16, force_download=False)
49
  vae = AutoencoderKL.from_pretrained(base_path, subfolder="vae", torch_dtype=torch.float16, force_download=False)
50
- UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(base_path, subfolder="unet_encoder", torch_dtype=torch.float16, force_download=False)
 
 
 
 
 
 
 
 
51
 
52
  parsing_model = Parsing(0)
53
  openpose_model = OpenPose(0)
 
32
  def load_models():
33
  global unet, tokenizer_one, tokenizer_two, noise_scheduler, text_encoder_one, text_encoder_two
34
  global image_encoder, vae, UNet_Encoder, parsing_model, openpose_model, pipe
35
+ global models_loaded
36
 
37
  if not models_loaded:
38
  base_path = 'yisol/IDM-VTON'
 
47
  text_encoder_two = CLIPTextModelWithProjection.from_pretrained(base_path, subfolder="text_encoder_2", torch_dtype=torch.float16, force_download=False)
48
  image_encoder = CLIPVisionModelWithProjection.from_pretrained(base_path, subfolder="image_encoder", torch_dtype=torch.float16, force_download=False)
49
  vae = AutoencoderKL.from_pretrained(base_path, subfolder="vae", torch_dtype=torch.float16, force_download=False)
50
+
51
+ # Set the correct encoder_hid_dim_type here
52
+ UNet_Encoder = UNet2DConditionModel.from_pretrained(
53
+ base_path,
54
+ subfolder="unet_encoder",
55
+ torch_dtype=torch.float16,
56
+ encoder_hid_dim_type="text_proj", # Update based on model type
57
+ force_download=False
58
+ )
59
 
60
  parsing_model = Parsing(0)
61
  openpose_model = OpenPose(0)