Laishram Pongthangamba Meitei commited on
Commit
a6c8077
·
verified ·
1 Parent(s): 4d2aae2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -17,8 +17,8 @@ torch.cuda.empty_cache()
17
 
18
  ## Load autoencoder
19
 
20
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
21
-
22
  autoencoderkl = AutoencoderKL(
23
  spatial_dims=2,
24
  in_channels=1,
@@ -33,7 +33,7 @@ autoencoderkl = AutoencoderKL(
33
  root_dir = "models"
34
  PATH_auto = f'{root_dir}/auto_encoder_model.pt'
35
 
36
- autoencoderkl.load_state_dict(torch.load(PATH_auto))
37
  autoencoderkl = autoencoderkl.to(device)
38
 
39
  #### Load unet and embedings
@@ -60,8 +60,8 @@ embed = torch.nn.Embedding(num_embeddings=6, embedding_dim=embedding_dimension,
60
  PATH_unet_condition = f'{root_dir}/unet_latent_space_model_condition.pt'
61
  PATH_embed_condition = f'{root_dir}/embed_latent_space_model_condition.pt'
62
 
63
- unet.load_state_dict(torch.load(PATH_unet_condition))
64
- embed.load_state_dict(torch.load(PATH_embed_condition))
65
 
66
  # unet.load_state_dict(checkpoint['model_state_dict'])
67
  # embed.load_state_dict(checkpoint['embed_state_dict'])
 
17
 
18
  ## Load autoencoder
19
 
20
+ #device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
21
+ device = torch.device('cpu')
22
  autoencoderkl = AutoencoderKL(
23
  spatial_dims=2,
24
  in_channels=1,
 
33
  root_dir = "models"
34
  PATH_auto = f'{root_dir}/auto_encoder_model.pt'
35
 
36
+ autoencoderkl.load_state_dict(torch.load(PATH_auto,map_location=device))
37
  autoencoderkl = autoencoderkl.to(device)
38
 
39
  #### Load unet and embedings
 
60
  PATH_unet_condition = f'{root_dir}/unet_latent_space_model_condition.pt'
61
  PATH_embed_condition = f'{root_dir}/embed_latent_space_model_condition.pt'
62
 
63
+ unet.load_state_dict(torch.load(PATH_unet_condition,map_location=device))
64
+ embed.load_state_dict(torch.load(PATH_embed_condition,map_location=device))
65
 
66
  # unet.load_state_dict(checkpoint['model_state_dict'])
67
  # embed.load_state_dict(checkpoint['embed_state_dict'])