Text-to-Image
PyTorch
majian0318 committed on
Commit
3cb067a
1 Parent(s): 015a181

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +4 -3
README.md CHANGED
@@ -90,20 +90,21 @@ class StableDiffusionTest():
90
  self.text_encoder, _, preprocess = open_clip.create_model_and_transforms('xlm-roberta-large-ViT-H-14', pretrained=text_encoder_path)
91
  self.tokenizer = open_clip.get_tokenizer('xlm-roberta-large-ViT-H-14')
92
  self.text_encoder.text.output_tokens = True
93
- self.proj = MLP(1024, 1280, 1024,2048, use_residual=False).to(device,dtype=dtype)
94
- self.text_encoder = self.text_encoder.to(device)
95
 
96
  self.vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to(device)
97
  scheduler = DPMSolverMultistepScheduler.from_pretrained(model_id, subfolder="scheduler")
98
  self.pipe = StableDiffusionXLPipeline.from_pretrained(model_id, scheduler=scheduler,torch_dtype=dtype).to(device)
99
  self.image_processor = VaeImageProcessor(vae_scale_factor=self.pipe.vae_scale_factor)
 
 
100
  self.proj.load_state_dict(torch.load(proj_path, map_location="cpu"))
101
 
102
 
103
  def encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
104
  batch_size = len(prompt) if isinstance(prompt, list) else 1
105
 
106
- text_input_ids = self.tokenizer(prompt).to(device,dtype=dtype)
107
  _,text_embeddings = self.text_encoder.encode_text(text_input_ids)
108
 
109
  add_text_embeds,text_embeddings_2048 = self.proj(text_embeddings)
 
90
  self.text_encoder, _, preprocess = open_clip.create_model_and_transforms('xlm-roberta-large-ViT-H-14', pretrained=text_encoder_path)
91
  self.tokenizer = open_clip.get_tokenizer('xlm-roberta-large-ViT-H-14')
92
  self.text_encoder.text.output_tokens = True
93
+ self.text_encoder = self.text_encoder.to(device,dtype=dtype)
 
94
 
95
  self.vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to(device)
96
  scheduler = DPMSolverMultistepScheduler.from_pretrained(model_id, subfolder="scheduler")
97
  self.pipe = StableDiffusionXLPipeline.from_pretrained(model_id, scheduler=scheduler,torch_dtype=dtype).to(device)
98
  self.image_processor = VaeImageProcessor(vae_scale_factor=self.pipe.vae_scale_factor)
99
+
100
+ self.proj = MLP(1024, 1280, 1024,2048, use_residual=False).to(device,dtype=dtype)
101
  self.proj.load_state_dict(torch.load(proj_path, map_location="cpu"))
102
 
103
 
104
  def encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
105
  batch_size = len(prompt) if isinstance(prompt, list) else 1
106
 
107
+ text_input_ids = self.tokenizer(prompt).to(device)
108
  _,text_embeddings = self.text_encoder.encode_text(text_input_ids)
109
 
110
  add_text_embeds,text_embeddings_2048 = self.proj(text_embeddings)