YulianSa commited on
Commit
3019303
·
1 Parent(s): 3b1d9b4
Files changed (1) hide show
  1. infer_api.py +8 -9
infer_api.py CHANGED
@@ -171,10 +171,11 @@ def process_image(image, totensor, width, height):
171
  @spaces.GPU
172
  @torch.no_grad()
173
  def inference(validation_pipeline, bkg_remover, input_image, vae, feature_extractor, image_encoder, unet, ref_unet, tokenizer,
174
- text_encoder, pretrained_model_path, generator, validation, val_width, val_height, unet_condition_type,
175
  use_noise=True, noise_d=256, crop=False, seed=100, timestep=20):
176
  set_seed(seed)
177
-
 
178
  totensor = transforms.ToTensor()
179
 
180
  prompts = "high quality, best quality"
@@ -278,10 +279,6 @@ def run_multiview_infer(data, pipeline, cfg: TestConfig, num_levels=3):
278
  generator = None
279
  else:
280
  generator = torch.Generator(device=pipeline.unet.device).manual_seed(cfg.seed)
281
-
282
- if torch.cuda.is_available():
283
- pipeline.unet.enable_xformers_memory_efficient_attention()
284
- pipeline.to(device)
285
 
286
  images_cond = []
287
  results = {}
@@ -341,11 +338,14 @@ def run_multiview_infer(data, pipeline, cfg: TestConfig, num_levels=3):
341
  torch.cuda.empty_cache()
342
  return results
343
 
344
-
345
  def load_multiview_pipeline(cfg):
346
  pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(
347
  cfg.pretrained_path,
348
  torch_dtype=torch.float16,)
 
 
 
349
  return pipeline
350
 
351
 
@@ -870,10 +870,9 @@ class InferCanonicalAPI:
870
  self.bkg_remover = BkgRemover()
871
 
872
  def canonicalize(self, image, seed):
873
- generator = torch.Generator(device=device).manual_seed(seed)
874
  return inference(
875
  self.validation_pipeline, self.bkg_remover, image, self.vae, self.feature_extractor, self.image_encoder, self.unet, self.ref_unet, self.tokenizer, self.text_encoder,
876
- self.pretrained_model_path, generator, self.validation, self.width_input, self.height_input, self.unet_condition_type,
877
  use_noise=self.use_noise, noise_d=self.noise_d, crop=True, seed=seed, timestep=self.timestep
878
  )
879
 
 
171
  @spaces.GPU
172
  @torch.no_grad()
173
  def inference(validation_pipeline, bkg_remover, input_image, vae, feature_extractor, image_encoder, unet, ref_unet, tokenizer,
174
+ text_encoder, pretrained_model_path, validation, val_width, val_height, unet_condition_type,
175
  use_noise=True, noise_d=256, crop=False, seed=100, timestep=20):
176
  set_seed(seed)
177
+ generator = torch.Generator(device=device).manual_seed(seed)
178
+
179
  totensor = transforms.ToTensor()
180
 
181
  prompts = "high quality, best quality"
 
279
  generator = None
280
  else:
281
  generator = torch.Generator(device=pipeline.unet.device).manual_seed(cfg.seed)
 
 
 
 
282
 
283
  images_cond = []
284
  results = {}
 
338
  torch.cuda.empty_cache()
339
  return results
340
 
341
+ @spaces.GPU
342
  def load_multiview_pipeline(cfg):
343
  pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(
344
  cfg.pretrained_path,
345
  torch_dtype=torch.float16,)
346
+ pipeline.unet.enable_xformers_memory_efficient_attention()
347
+ if torch.cuda.is_available():
348
+ pipeline.to(device)
349
  return pipeline
350
 
351
 
 
870
  self.bkg_remover = BkgRemover()
871
 
872
  def canonicalize(self, image, seed):
 
873
  return inference(
874
  self.validation_pipeline, self.bkg_remover, image, self.vae, self.feature_extractor, self.image_encoder, self.unet, self.ref_unet, self.tokenizer, self.text_encoder,
875
+ self.pretrained_model_path, self.validation, self.width_input, self.height_input, self.unet_condition_type,
876
  use_noise=self.use_noise, noise_d=self.noise_d, crop=True, seed=seed, timestep=self.timestep
877
  )
878