rhfeiyang committed on
Commit
b851c53
1 Parent(s): c62a333
Files changed (1) hide show
  1. inference.py +2 -1
inference.py CHANGED
@@ -354,6 +354,7 @@ def inference(network: LoRANetwork, tokenizer: CLIPTokenizer, text_encoder: CLIP
354
  latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t).to(weight_dtype)
355
  # predict the noise residual
356
  with network:
 
357
  noise_pred = unet(latent_model_input, t , encoder_hidden_states=text_embedding).sample
358
 
359
  # perform guidance
@@ -373,7 +374,7 @@ def inference(network: LoRANetwork, tokenizer: CLIPTokenizer, text_encoder: CLIP
373
  with torch.no_grad():
374
  image = vae.decode(latents).sample
375
  image = (image / 2 + 0.5).clamp(0, 1)
376
- image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
377
  images = (image * 255).round().astype("uint8")
378
 
379
 
 
354
  latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t).to(weight_dtype)
355
  # predict the noise residual
356
  with network:
357
+ print(f"dtype: {latent_model_input.dtype}, {text_embedding.dtype}, t={t}")
358
  noise_pred = unet(latent_model_input, t , encoder_hidden_states=text_embedding).sample
359
 
360
  # perform guidance
 
374
  with torch.no_grad():
375
  image = vae.decode(latents).sample
376
  image = (image / 2 + 0.5).clamp(0, 1)
377
+ image = image.detach().cpu().permute(0, 2, 3, 1).to(torch.float32).numpy()
378
  images = (image * 255).round().astype("uint8")
379
 
380