Mairaaa commited on
Commit
ecedcde
·
verified ·
1 Parent(s): 468cc1d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -2
app.py CHANGED
@@ -27,16 +27,22 @@ def load_models(pretrained_model_path, device):
27
  # Handle torch.hub checkpoint loading for CPU-only environments
28
  map_location = torch.device("cpu") if device.type == "cpu" else None
29
 
30
- # Load the UNet model and map to CPU if necessary
31
  unet = torch.hub.load(
32
  repo_or_dir="aimagelab/multimodal-garment-designer",
33
  source="github",
34
  model="mgd",
35
  pretrained=True,
36
  dataset="dresscode", # Change to "vitonhd" if needed
37
- map_location=map_location, # Ensure the model loads on CPU if needed
38
  )
39
 
 
 
 
 
 
 
 
40
  # Move UNet to the appropriate device
41
  unet = unet.to(device)
42
 
 
27
  # Handle torch.hub checkpoint loading for CPU-only environments
28
  map_location = torch.device("cpu") if device.type == "cpu" else None
29
 
30
+ # Load the UNet model and force map_location for state_dict loading
31
  unet = torch.hub.load(
32
  repo_or_dir="aimagelab/multimodal-garment-designer",
33
  source="github",
34
  model="mgd",
35
  pretrained=True,
36
  dataset="dresscode", # Change to "vitonhd" if needed
 
37
  )
38
 
39
+ # Ensure the model state dict is mapped correctly to the CPU if needed
40
+ if device.type == "cpu":
41
+ checkpoint_url = unet.config.get("checkpoint")
42
+ if checkpoint_url:
43
+ state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")
44
+ unet.load_state_dict(state_dict)
45
+
46
  # Move UNet to the appropriate device
47
  unet = unet.to(device)
48