Jannat24 commited on
Commit
c9b0750
·
verified ·
1 Parent(s): e7605fc

Update modules/finetunedvqgan.py

Browse files
Files changed (1) hide show
  1. modules/finetunedvqgan.py +7 -4
modules/finetunedvqgan.py CHANGED
@@ -7,8 +7,9 @@ from taming.models.vqgan import GumbelVQ
7
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
 
9
  class Generator:
10
- def __init__(self, config_path, device=device):
11
  self.config_path = config_path
 
12
  self.device = device
13
 
14
  def load_models(self):
@@ -25,7 +26,9 @@ class Generator:
25
  kl_weight=vq_params.kl_weight,
26
  temperature_scheduler_config=vq_params.temperature_scheduler_config,
27
  ).to(self.device)
 
 
 
 
28
 
29
- return model_vaq
30
-
31
-
 
7
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
 
9
  class Generator:
10
+ def __init__(self, config_path, checkpoint_path, device=device):
11
  self.config_path = config_path
12
+ self.checkpoint_path = checkpoint_path
13
  self.device = device
14
 
15
  def load_models(self):
 
26
  kl_weight=vq_params.kl_weight,
27
  temperature_scheduler_config=vq_params.temperature_scheduler_config,
28
  ).to(self.device)
29
+ # Load model checkpoints
30
+ checkpoint = torch.load(self.checkpoint_path, map_location=self.device)
31
+ # Load the state dictionary into the models
32
+ model_vaq.load_state_dict(checkpoint, strict=True)
33
 
34
+ return model_vaq