majian0318 committed • Commit 3cb067a • Parent(s): 015a181
Update README.md

README.md CHANGED
@@ -90,20 +90,21 @@ class StableDiffusionTest():
         self.text_encoder, _, preprocess = open_clip.create_model_and_transforms('xlm-roberta-large-ViT-H-14', pretrained=text_encoder_path)
         self.tokenizer = open_clip.get_tokenizer('xlm-roberta-large-ViT-H-14')
         self.text_encoder.text.output_tokens = True
-        self.
-        self.text_encoder = self.text_encoder.to(device)
+        self.text_encoder = self.text_encoder.to(device,dtype=dtype)
 
         self.vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to(device)
         scheduler = DPMSolverMultistepScheduler.from_pretrained(model_id, subfolder="scheduler")
         self.pipe = StableDiffusionXLPipeline.from_pretrained(model_id, scheduler=scheduler,torch_dtype=dtype).to(device)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.pipe.vae_scale_factor)
+
+        self.proj = MLP(1024, 1280, 1024,2048, use_residual=False).to(device,dtype=dtype)
         self.proj.load_state_dict(torch.load(proj_path, map_location="cpu"))
 
 
     def encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
         batch_size = len(prompt) if isinstance(prompt, list) else 1
 
-        text_input_ids = self.tokenizer(prompt).to(device
+        text_input_ids = self.tokenizer(prompt).to(device)
         _,text_embeddings = self.text_encoder.encode_text(text_input_ids)
 
         add_text_embeds,text_embeddings_2048 = self.proj(text_embeddings)
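The newly added `self.proj = MLP(1024, 1280, 1024,2048, use_residual=False)` line references a projection head that is not defined in this excerpt of the README. The sketch below is only an illustration of the shapes the surrounding code implies: the open_clip XLM-RoBERTa-large text tower (with `output_tokens = True`) yields 1024-dim token embeddings, and `encode_prompt` expects the projection to return a pooled embedding for SDXL's `add_text_embeds` (1280-dim) plus 2048-dim per-token embeddings for cross-attention. The layer layout, argument names, and mean pooling here are assumptions, not the repository's actual `MLP` implementation.

```python
import torch
import torch.nn as nn

class MLP(nn.Module):
    """Hypothetical projection head: 1024-dim CLIP text tokens -> SDXL conditioning."""

    def __init__(self, in_dim=1024, pooled_dim=1280, hidden_dim=1024, out_dim=2048, use_residual=False):
        super().__init__()
        self.use_residual = use_residual
        # Per-token path: project each token embedding to the SDXL cross-attention width.
        self.norm = nn.LayerNorm(in_dim)
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        # Pooled path: produce the 1280-dim pooled embedding SDXL uses as add_text_embeds.
        self.pooled_proj = nn.Linear(in_dim, pooled_dim)

    def forward(self, token_embeds):
        # token_embeds: [batch, seq_len, in_dim], e.g. the second output of encode_text()
        x = self.norm(token_embeds)
        tokens = self.fc2(self.act(self.fc1(x)))              # [batch, seq_len, out_dim]
        if self.use_residual:
            tokens = tokens + token_embeds                    # only sensible when out_dim == in_dim
        pooled = self.pooled_proj(token_embeds.mean(dim=1))   # [batch, pooled_dim]
        return pooled, tokens                                 # (add_text_embeds, text_embeddings_2048)
```

Presumably the two outputs are later passed to the pipeline as `pooled_prompt_embeds` and `prompt_embeds`, the conditioning inputs `StableDiffusionXLPipeline` accepts. Note also that the commit casts both the text encoder and the projection with `dtype=dtype`, keeping their outputs in the same precision the pipeline was loaded with (`torch_dtype=dtype`) and avoiding dtype mismatches in the UNet.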