yuhj95 commited on
Commit
ea9d466
1 Parent(s): 65851d8

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. sampler.py +4 -1
sampler.py CHANGED
@@ -93,7 +93,10 @@ class BaseSampler:
93
  log_str = f'Building the diffusion model with length: {self.configs.diffusion.params.steps}...'
94
  self.write_log(log_str)
95
  self.base_diffusion = util_common.instantiate_from_config(self.configs.diffusion)
96
- model = util_common.instantiate_from_config(self.configs.model).cuda()
 
 
 
97
  ckpt_path =self.configs.model.ckpt_path
98
  assert ckpt_path is not None
99
  self.write_log(f'Loading Diffusion model from {ckpt_path}...')
 
93
  log_str = f'Building the diffusion model with length: {self.configs.diffusion.params.steps}...'
94
  self.write_log(log_str)
95
  self.base_diffusion = util_common.instantiate_from_config(self.configs.diffusion)
96
+ model = util_common.instantiate_from_config(self.configs.model)
97
+ # gpu test
98
+ if torch.cuda.is_available():
99
+ model = model.cuda()
100
  ckpt_path =self.configs.model.ckpt_path
101
  assert ckpt_path is not None
102
  self.write_log(f'Loading Diffusion model from {ckpt_path}...')