Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -27,16 +27,22 @@ def load_models(pretrained_model_path, device):
|
|
27 |
# Handle torch.hub checkpoint loading for CPU-only environments
|
28 |
map_location = torch.device("cpu") if device.type == "cpu" else None
|
29 |
|
30 |
-
# Load the UNet model and
|
31 |
unet = torch.hub.load(
|
32 |
repo_or_dir="aimagelab/multimodal-garment-designer",
|
33 |
source="github",
|
34 |
model="mgd",
|
35 |
pretrained=True,
|
36 |
dataset="dresscode", # Change to "vitonhd" if needed
|
37 |
-
map_location=map_location, # Ensure the model loads on CPU if needed
|
38 |
)
|
39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
# Move UNet to the appropriate device
|
41 |
unet = unet.to(device)
|
42 |
|
|
|
27 |
# Handle torch.hub checkpoint loading for CPU-only environments
|
28 |
map_location = torch.device("cpu") if device.type == "cpu" else None
|
29 |
|
30 |
+
# Load the UNet model and force map_location for state_dict loading
|
31 |
unet = torch.hub.load(
|
32 |
repo_or_dir="aimagelab/multimodal-garment-designer",
|
33 |
source="github",
|
34 |
model="mgd",
|
35 |
pretrained=True,
|
36 |
dataset="dresscode", # Change to "vitonhd" if needed
|
|
|
37 |
)
|
38 |
|
39 |
+
# Ensure the model state dict is mapped correctly to the CPU if needed
|
40 |
+
if device.type == "cpu":
|
41 |
+
checkpoint_url = unet.config.get("checkpoint")
|
42 |
+
if checkpoint_url:
|
43 |
+
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")
|
44 |
+
unet.load_state_dict(state_dict)
|
45 |
+
|
46 |
# Move UNet to the appropriate device
|
47 |
unet = unet.to(device)
|
48 |
|