Spaces:
Runtime error
Runtime error
Laishram Pongthangamba Meitei
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -17,8 +17,8 @@ torch.cuda.empty_cache()
|
|
17 |
|
18 |
## Load autoencoder
|
19 |
|
20 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
21 |
-
|
22 |
autoencoderkl = AutoencoderKL(
|
23 |
spatial_dims=2,
|
24 |
in_channels=1,
|
@@ -33,7 +33,7 @@ autoencoderkl = AutoencoderKL(
|
|
33 |
root_dir = "models"
|
34 |
PATH_auto = f'{root_dir}/auto_encoder_model.pt'
|
35 |
|
36 |
-
autoencoderkl.load_state_dict(torch.load(PATH_auto))
|
37 |
autoencoderkl = autoencoderkl.to(device)
|
38 |
|
39 |
#### Load unet and embedings
|
@@ -60,8 +60,8 @@ embed = torch.nn.Embedding(num_embeddings=6, embedding_dim=embedding_dimension,
|
|
60 |
PATH_unet_condition = f'{root_dir}/unet_latent_space_model_condition.pt'
|
61 |
PATH_embed_condition = f'{root_dir}/embed_latent_space_model_condition.pt'
|
62 |
|
63 |
-
unet.load_state_dict(torch.load(PATH_unet_condition))
|
64 |
-
embed.load_state_dict(torch.load(PATH_embed_condition))
|
65 |
|
66 |
# unet.load_state_dict(checkpoint['model_state_dict'])
|
67 |
# embed.load_state_dict(checkpoint['embed_state_dict'])
|
|
|
17 |
|
18 |
## Load autoencoder
|
19 |
|
20 |
+
#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
21 |
+
device = torch.device('cpu')
|
22 |
autoencoderkl = AutoencoderKL(
|
23 |
spatial_dims=2,
|
24 |
in_channels=1,
|
|
|
33 |
root_dir = "models"
|
34 |
PATH_auto = f'{root_dir}/auto_encoder_model.pt'
|
35 |
|
36 |
+
autoencoderkl.load_state_dict(torch.load(PATH_auto,map_location=device))
|
37 |
autoencoderkl = autoencoderkl.to(device)
|
38 |
|
39 |
#### Load unet and embedings
|
|
|
60 |
PATH_unet_condition = f'{root_dir}/unet_latent_space_model_condition.pt'
|
61 |
PATH_embed_condition = f'{root_dir}/embed_latent_space_model_condition.pt'
|
62 |
|
63 |
+
unet.load_state_dict(torch.load(PATH_unet_condition,map_location=device))
|
64 |
+
embed.load_state_dict(torch.load(PATH_embed_condition,map_location=device))
|
65 |
|
66 |
# unet.load_state_dict(checkpoint['model_state_dict'])
|
67 |
# embed.load_state_dict(checkpoint['embed_state_dict'])
|