NIRVANALAN commited on
Commit
cb12c31
1 Parent(s): 945a01c
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -175,8 +175,10 @@ def main(args):
175
  # denoise_model.load_state_dict(
176
  # dist_util.load_state_dict(args.ddpm_model_path, map_location="cpu"))
177
  denoise_model.to(dist_util.dev())
178
- if args.use_fp16:
179
- denoise_model.convert_to_fp16()
 
 
180
  denoise_model.eval()
181
 
182
  # * auto-encoder reconstruction model
 
175
  # denoise_model.load_state_dict(
176
  # dist_util.load_state_dict(args.ddpm_model_path, map_location="cpu"))
177
  denoise_model.to(dist_util.dev())
178
+ denoise_model = denoise_model.to(th.bfloat16)
179
+ auto_encoder = auto_encoder.to(th.bfloat16)
180
+ # if args.use_fp16:
181
+ # denoise_model.convert_to_fp16()
182
  denoise_model.eval()
183
 
184
  # * auto-encoder reconstruction model