rhfeiyang commited on
Commit
a7876f7
·
1 Parent(s): e04ccbd
Files changed (1) hide show
  1. inference.py +2 -3
inference.py CHANGED
@@ -14,7 +14,6 @@ import sys
14
  import gc
15
  from transformers import CLIPTextModel, CLIPTokenizer, BertModel, BertTokenizer
16
 
17
- from hf_demo import dtype
18
  # import train_util
19
 
20
  from utils.train_util import get_noisy_image, encode_prompts
@@ -320,8 +319,8 @@ def inference(network: LoRANetwork, tokenizer: CLIPTokenizer, text_encoder: CLIP
320
  uncond_embeddings = uncond_embed.repeat(bcz, 1, 1)
321
  else:
322
  uncond_embeddings = uncond_embed
323
- style_text_embeddings = torch.cat([uncond_embeddings, style_embeddings], dtype=weight_dtype)
324
- original_embeddings = torch.cat([uncond_embeddings, original_embeddings], dtype=weight_dtype)
325
 
326
  generator = torch.manual_seed(single_seed) if single_seed is not None else None
327
  noise_scheduler.set_timesteps(steps)
 
14
  import gc
15
  from transformers import CLIPTextModel, CLIPTokenizer, BertModel, BertTokenizer
16
 
 
17
  # import train_util
18
 
19
  from utils.train_util import get_noisy_image, encode_prompts
 
319
  uncond_embeddings = uncond_embed.repeat(bcz, 1, 1)
320
  else:
321
  uncond_embeddings = uncond_embed
322
+ style_text_embeddings = torch.cat([uncond_embeddings, style_embeddings]).to(weight_dtype)
323
+ # original_embeddings = torch.cat([uncond_embeddings, original_embeddings]).to(weight_dtype)
324
 
325
  generator = torch.manual_seed(single_seed) if single_seed is not None else None
326
  noise_scheduler.set_timesteps(steps)