Mairaaa commited on
Commit
468cc1d
·
verified ·
1 Parent(s): d3837fc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -25,16 +25,16 @@ def load_models(pretrained_model_path, device):
25
  scheduler.set_timesteps(50)
26
 
27
  # Handle torch.hub checkpoint loading for CPU-only environments
28
- map_location = "cpu" if device.type == "cpu" else None
29
 
30
- # Load the UNet model
31
  unet = torch.hub.load(
32
  repo_or_dir="aimagelab/multimodal-garment-designer",
33
  source="github",
34
  model="mgd",
35
  pretrained=True,
36
  dataset="dresscode", # Change to "vitonhd" if needed
37
- map_location=map_location,
38
  )
39
 
40
  # Move UNet to the appropriate device
 
25
  scheduler.set_timesteps(50)
26
 
27
  # Handle torch.hub checkpoint loading for CPU-only environments
28
+ map_location = torch.device("cpu") if device.type == "cpu" else None
29
 
30
+ # Load the UNet model and map to CPU if necessary
31
  unet = torch.hub.load(
32
  repo_or_dir="aimagelab/multimodal-garment-designer",
33
  source="github",
34
  model="mgd",
35
  pretrained=True,
36
  dataset="dresscode", # Change to "vitonhd" if needed
37
+ map_location=map_location, # Ensure the model loads on CPU if needed
38
  )
39
 
40
  # Move UNet to the appropriate device