chaojiemao commited on
Commit
3c39b17
·
verified ·
1 Parent(s): 4ae4bd3

Update ace_inference.py

Browse files
Files changed (1) hide show
  1. ace_inference.py +9 -2
ace_inference.py CHANGED
@@ -146,8 +146,15 @@ class ACEInference(DiffusionInference):
146
  self.dynamic_load(self.first_stage_model, 'first_stage_model')
147
  self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
148
  if self.ref_cond_stage_model is not None: self.dynamic_load(self.ref_cond_stage_model, 'ref_cond_stage_model')
149
- self.dynamic_load(self.diffusion_model, 'diffusion_model')
150
- self.diffusion_model["model"].to(torch.bfloat16)
 
 
 
 
 
 
 
151
 
152
  def upscale_resize(self, image, interpolation=T.InterpolationMode.BILINEAR):
153
  c, H, W = image.shape
 
146
  self.dynamic_load(self.first_stage_model, 'first_stage_model')
147
  self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
148
  if self.ref_cond_stage_model is not None: self.dynamic_load(self.ref_cond_stage_model, 'ref_cond_stage_model')
149
+ # self.dynamic_load(self.diffusion_model, 'diffusion_model')
150
+ # self.diffusion_model["model"].to(torch.bfloat16)
151
+ with torch.device("meta"):
152
+ pretrained_model = self.diffusion_model['cfg'].PRETRAINED_MODEL
153
+ self.diffusion_model['cfg'].PRETRAINED_MODEL = None
154
+ self.diffusion_model['model'] = BACKBONES.build(self.diffusion_model['cfg'], logger=self.logger).eval()
155
+ # self.dynamic_load(self.diffusion_model, 'diffusion_model')
156
+ self.diffusion_model['model'].load_pretrained_model(pretrained_model)
157
+ self.diffusion_model['device'] = we.device_id
158
 
159
  def upscale_resize(self, image, interpolation=T.InterpolationMode.BILINEAR):
160
  c, H, W = image.shape