Mairaaa commited on
Commit
5b2f1b4
·
verified ·
1 Parent(s): f26e6e9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -3
app.py CHANGED
@@ -11,16 +11,37 @@ from src.mgd_pipelines.mgd_pipe import MGDPipe
11
  def load_model():
12
  try:
13
  # Define your model loading logic
 
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
- vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
16
- print("VAE model loaded successfully.")
17
 
 
 
 
 
 
 
 
18
  tokenizer = CLIPTokenizer.from_pretrained("microsoft/xclip-base-patch32", subfolder="tokenizer")
 
 
 
 
19
  text_encoder = CLIPTextModel.from_pretrained("microsoft/xclip-base-patch32", subfolder="text_encoder")
 
 
 
 
20
  unet = torch.hub.load("aimagelab/multimodal-garment-designer", model="mgd", pretrained=True)
 
 
 
 
21
  scheduler = DDIMScheduler.from_pretrained("stabilityai/sd-scheduler", subfolder="scheduler")
 
22
 
23
  # Initialize the pipeline
 
24
  pipe = MGDPipe(
25
  text_encoder=text_encoder,
26
  vae=vae,
@@ -28,7 +49,8 @@ def load_model():
28
  tokenizer=tokenizer,
29
  scheduler=scheduler,
30
  ).to(device)
31
- pipe.enable_attention_slicing() # Enable memory-efficient inference
 
32
  return pipe
33
  except Exception as e:
34
  print(f"Error loading the model: {e}")
 
11
  def load_model():
12
  try:
13
  # Define your model loading logic
14
+ print("Initializing model loading...")
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+ print(f"Device selected: {device}")
 
17
 
18
+ # Load the VAE
19
+ print("Loading VAE...")
20
+ vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
21
+ print("VAE loaded successfully.")
22
+
23
+ # Load the tokenizer
24
+ print("Loading tokenizer...")
25
  tokenizer = CLIPTokenizer.from_pretrained("microsoft/xclip-base-patch32", subfolder="tokenizer")
26
+ print("Tokenizer loaded successfully.")
27
+
28
+ # Load the text encoder
29
+ print("Loading text encoder...")
30
  text_encoder = CLIPTextModel.from_pretrained("microsoft/xclip-base-patch32", subfolder="text_encoder")
31
+ print("Text encoder loaded successfully.")
32
+
33
+ # Load the UNet model
34
+ print("Loading UNet...")
35
  unet = torch.hub.load("aimagelab/multimodal-garment-designer", model="mgd", pretrained=True)
36
+ print("UNet loaded successfully.")
37
+
38
+ # Load the scheduler
39
+ print("Loading scheduler...")
40
  scheduler = DDIMScheduler.from_pretrained("stabilityai/sd-scheduler", subfolder="scheduler")
41
+ print("Scheduler loaded successfully.")
42
 
43
  # Initialize the pipeline
44
+ print("Initializing pipeline...")
45
  pipe = MGDPipe(
46
  text_encoder=text_encoder,
47
  vae=vae,
 
49
  tokenizer=tokenizer,
50
  scheduler=scheduler,
51
  ).to(device)
52
+ pipe.enable_attention_slicing()
53
+ print("Pipeline initialized successfully.")
54
  return pipe
55
  except Exception as e:
56
  print(f"Error loading the model: {e}")