primerz committed (verified)
Commit 58a5607 · 1 Parent(s): bd63113

Update cog_sdxl_dataset_and_utils.py

Files changed (1):
  1. cog_sdxl_dataset_and_utils.py +80 -325

cog_sdxl_dataset_and_utils.py CHANGED
@@ -1,4 +1,4 @@
- # dataset_and_utils.py file taken from https://github.com/replicate/cog-sdxl/blob/main/dataset_and_utils.py
  import os
  from typing import Dict, List, Optional, Tuple

@@ -15,26 +15,22 @@ from torch.utils.data import Dataset
  from transformers import AutoTokenizer, PretrainedConfig


- def prepare_image(
-     pil_image: PIL.Image.Image, w: int = 512, h: int = 512
- ) -> torch.Tensor:
-     pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
-     arr = np.array(pil_image.convert("RGB"))
-     arr = arr.astype(np.float32) / 127.5 - 1
-     arr = np.transpose(arr, [2, 0, 1])
-     image = torch.from_numpy(arr).unsqueeze(0)
-     return image


- def prepare_mask(
-     pil_image: PIL.Image.Image, w: int = 512, h: int = 512
- ) -> torch.Tensor:
-     pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
-     arr = np.array(pil_image.convert("L"))
-     arr = arr.astype(np.float32) / 255.0
-     arr = np.expand_dims(arr, 0)
-     image = torch.from_numpy(arr).unsqueeze(0)
-     return image


  class PreprocessedDataset(Dataset):
@@ -50,373 +46,132 @@ class PreprocessedDataset(Dataset):
          size: int = 512,
          text_dropout: float = 0.0,
          scale_vae_latents: bool = True,
-         substitute_caption_map: Dict[str, str] = {},
      ):
          super().__init__()
-
          self.data = pd.read_csv(csv_path)
          self.csv_path = csv_path

-         self.caption = self.data["caption"]
-         # make it lowercase
-         self.caption = self.caption.str.lower()
-         for key, value in substitute_caption_map.items():
-             self.caption = self.caption.str.replace(key.lower(), value)

-         self.image_path = self.data["image_path"]

-         if "mask_path" not in self.data.columns:
-             self.mask_path = None
-         else:
-             self.mask_path = self.data["mask_path"]

-         if text_encoder_1 is None:
-             self.return_text_embeddings = False
-         else:
              self.text_encoder_1 = text_encoder_1
              self.text_encoder_2 = text_encoder_2
              self.return_text_embeddings = True
-             assert (
-                 NotImplementedError
-             ), "Preprocessing Text Encoder is not implemented yet"
-
-         self.tokenizer_1 = tokenizer_1
-         self.tokenizer_2 = tokenizer_2
-
-         self.vae_encoder = vae_encoder
-         self.scale_vae_latents = scale_vae_latents
-         self.text_dropout = text_dropout
-
-         self.size = size

-         if do_cache:
              self.vae_latents = []
              self.tokens_tuple = []
              self.masks = []
-
-             self.do_cache = True
-
-             print("Captions to train on: ")
              for idx in range(len(self.data)):
                  token, vae_latent, mask = self._process(idx)
-                 self.vae_latents.append(vae_latent)
                  self.tokens_tuple.append(token)
                  self.masks.append(mask)
-
-             del self.vae_encoder
-
-         else:
-             self.do_cache = False

  @torch.no_grad()
-     def _process(
-         self, idx: int
-     ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
-         image_path = self.image_path[idx]
-         image_path = os.path.join(os.path.dirname(self.csv_path), image_path)
-
-         image = PIL.Image.open(image_path).convert("RGB")
-         image = prepare_image(image, self.size, self.size).to(
              dtype=self.vae_encoder.dtype, device=self.vae_encoder.device
          )

          caption = self.caption[idx]
-
-         print(caption)
-
-         # tokenizer_1
-         ti1 = self.tokenizer_1(
-             caption,
-             padding="max_length",
-             max_length=77,
-             truncation=True,
-             add_special_tokens=True,
-             return_tensors="pt",
-         ).input_ids
-
-         ti2 = self.tokenizer_2(
-             caption,
-             padding="max_length",
-             max_length=77,
-             truncation=True,
-             add_special_tokens=True,
-             return_tensors="pt",
-         ).input_ids

          vae_latent = self.vae_encoder.encode(image).latent_dist.sample()
-
          if self.scale_vae_latents:
-             vae_latent = vae_latent * self.vae_encoder.config.scaling_factor

          if self.mask_path is None:
-             mask = torch.ones_like(
-                 vae_latent, dtype=self.vae_encoder.dtype, device=self.vae_encoder.device
-             )
-
          else:
-             mask_path = self.mask_path[idx]
-             mask_path = os.path.join(os.path.dirname(self.csv_path), mask_path)
-
-             mask = PIL.Image.open(mask_path)
-             mask = prepare_mask(mask, self.size, self.size).to(
                  dtype=self.vae_encoder.dtype, device=self.vae_encoder.device
              )
-
-             mask = torch.nn.functional.interpolate(
-                 mask, size=(vae_latent.shape[-2], vae_latent.shape[-1]), mode="nearest"
-             )
              mask = mask.repeat(1, vae_latent.shape[1], 1, 1)

-         assert len(mask.shape) == 4 and len(vae_latent.shape) == 4

          return (ti1.squeeze(), ti2.squeeze()), vae_latent.squeeze(), mask.squeeze()

      def __len__(self) -> int:
          return len(self.data)

-     def atidx(
-         self, idx: int
-     ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
-         if self.do_cache:
-             return self.tokens_tuple[idx], self.vae_latents[idx], self.masks[idx]
-         else:
-             return self._process(idx)

-     def __getitem__(
-         self, idx: int
-     ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
-         token, vae_latent, mask = self.atidx(idx)
-         return token, vae_latent, mask


- def import_model_class_from_model_name_or_path(
-     pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
- ):
-     text_encoder_config = PretrainedConfig.from_pretrained(
-         pretrained_model_name_or_path, subfolder=subfolder, revision=revision
-     )
-     model_class = text_encoder_config.architectures[0]

      if model_class == "CLIPTextModel":
          from transformers import CLIPTextModel
-
          return CLIPTextModel
      elif model_class == "CLIPTextModelWithProjection":
          from transformers import CLIPTextModelWithProjection
-
          return CLIPTextModelWithProjection
      else:
-         raise ValueError(f"{model_class} is not supported.")


  def load_models(pretrained_model_name_or_path, revision, device, weight_dtype):
-     tokenizer_one = AutoTokenizer.from_pretrained(
-         pretrained_model_name_or_path,
-         subfolder="tokenizer",
-         revision=revision,
-         use_fast=False,
-     )
-     tokenizer_two = AutoTokenizer.from_pretrained(
-         pretrained_model_name_or_path,
-         subfolder="tokenizer_2",
-         revision=revision,
-         use_fast=False,
-     )
-
-     # Load scheduler and models
-     noise_scheduler = DDPMScheduler.from_pretrained(
-         pretrained_model_name_or_path, subfolder="scheduler"
-     )
-     # import correct text encoder classes
-     text_encoder_cls_one = import_model_class_from_model_name_or_path(
-         pretrained_model_name_or_path, revision
-     )
-     text_encoder_cls_two = import_model_class_from_model_name_or_path(
-         pretrained_model_name_or_path, revision, subfolder="text_encoder_2"
-     )
-     text_encoder_one = text_encoder_cls_one.from_pretrained(
-         pretrained_model_name_or_path, subfolder="text_encoder", revision=revision
-     )
-     text_encoder_two = text_encoder_cls_two.from_pretrained(
-         pretrained_model_name_or_path, subfolder="text_encoder_2", revision=revision
-     )
-
-     vae = AutoencoderKL.from_pretrained(
-         pretrained_model_name_or_path, subfolder="vae", revision=revision
-     )
-     unet = UNet2DConditionModel.from_pretrained(
-         pretrained_model_name_or_path, subfolder="unet", revision=revision
-     )
-
-     vae.requires_grad_(False)
-     text_encoder_one.requires_grad_(False)
-     text_encoder_two.requires_grad_(False)
-
-     unet.to(device, dtype=weight_dtype)
-     vae.to(device, dtype=torch.float32)
-     text_encoder_one.to(device, dtype=weight_dtype)
-     text_encoder_two.to(device, dtype=weight_dtype)
-
-     return (
-         tokenizer_one,
-         tokenizer_two,
-         noise_scheduler,
-         text_encoder_one,
-         text_encoder_two,
-         vae,
-         unet,
-     )
-
-
- def unet_attn_processors_state_dict(unet) -> Dict[str, torch.tensor]:
      """
-     Returns:
-         a state dict containing just the attention processor parameters.
      """
-     attn_processors = unet.attn_processors
-
-     attn_processors_state_dict = {}
-
-     for attn_processor_key, attn_processor in attn_processors.items():
-         for parameter_key, parameter in attn_processor.state_dict().items():
-             attn_processors_state_dict[
-                 f"{attn_processor_key}.{parameter_key}"
-             ] = parameter
-
-     return attn_processors_state_dict
-
-
- class TokenEmbeddingsHandler:
-     def __init__(self, text_encoders, tokenizers):
-         self.text_encoders = text_encoders
-         self.tokenizers = tokenizers

-         self.train_ids: Optional[torch.Tensor] = None
-         self.inserting_toks: Optional[List[str]] = None
-         self.embeddings_settings = {}

-     def initialize_new_tokens(self, inserting_toks: List[str]):
-         idx = 0
-         for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
-             assert isinstance(
-                 inserting_toks, list
-             ), "inserting_toks should be a list of strings."
-             assert all(
-                 isinstance(tok, str) for tok in inserting_toks
-             ), "All elements in inserting_toks should be strings."

-             self.inserting_toks = inserting_toks
-             special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
-             tokenizer.add_special_tokens(special_tokens_dict)
-             text_encoder.resize_token_embeddings(len(tokenizer))

-             self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)

-             # random initialization of new tokens

-             std_token_embedding = (
-                 text_encoder.text_model.embeddings.token_embedding.weight.data.std()
-             )
-
-             print(f"{idx} text encodedr's std_token_embedding: {std_token_embedding}")
-
-             text_encoder.text_model.embeddings.token_embedding.weight.data[
-                 self.train_ids
-             ] = (
-                 torch.randn(
-                     len(self.train_ids), text_encoder.text_model.config.hidden_size
-                 )
-                 .to(device=self.device)
-                 .to(dtype=self.dtype)
-                 * std_token_embedding
-             )
-             self.embeddings_settings[
-                 f"original_embeddings_{idx}"
-             ] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
-             self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
-
-             inu = torch.ones((len(tokenizer),), dtype=torch.bool)
-             inu[self.train_ids] = False
-
-             self.embeddings_settings[f"index_no_updates_{idx}"] = inu
-
-             print(self.embeddings_settings[f"index_no_updates_{idx}"].shape)
-
-             idx += 1
-
-     def save_embeddings(self, file_path: str):
-         assert (
-             self.train_ids is not None
-         ), "Initialize new tokens before saving embeddings."
-         tensors = {}
-         for idx, text_encoder in enumerate(self.text_encoders):
-             assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[
-                 0
-             ] == len(self.tokenizers[0]), "Tokenizers should be the same."
-             new_token_embeddings = (
-                 text_encoder.text_model.embeddings.token_embedding.weight.data[
-                     self.train_ids
-                 ]
-             )
-             tensors[f"text_encoders_{idx}"] = new_token_embeddings
-
-         save_file(tensors, file_path)
-
-     @property
-     def dtype(self):
-         return self.text_encoders[0].dtype
-
-     @property
-     def device(self):
-         return self.text_encoders[0].device
-
-     def _load_embeddings(self, loaded_embeddings, tokenizer, text_encoder):
-         # Assuming new tokens are of the format <s_i>
-         self.inserting_toks = [f"<s{i}>" for i in range(loaded_embeddings.shape[0])]
-         special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
-         tokenizer.add_special_tokens(special_tokens_dict)
-         text_encoder.resize_token_embeddings(len(tokenizer))
-
-         self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)
-         assert self.train_ids is not None, "New tokens could not be converted to IDs."
-         text_encoder.text_model.embeddings.token_embedding.weight.data[
-             self.train_ids
-         ] = loaded_embeddings.to(device=self.device).to(dtype=self.dtype)
-
-     @torch.no_grad()
-     def retract_embeddings(self):
-         for idx, text_encoder in enumerate(self.text_encoders):
-             index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
-             text_encoder.text_model.embeddings.token_embedding.weight.data[
-                 index_no_updates
-             ] = (
-                 self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates]
-                 .to(device=text_encoder.device)
-                 .to(dtype=text_encoder.dtype)
-             )
-
-             # for the parts that were updated, we need to normalize them
-             # to have the same std as before
-             std_token_embedding = self.embeddings_settings[f"std_token_embedding_{idx}"]
-
-             index_updates = ~index_no_updates
-             new_embeddings = (
-                 text_encoder.text_model.embeddings.token_embedding.weight.data[
-                     index_updates
-                 ]
-             )
-             off_ratio = std_token_embedding / new_embeddings.std()
-
-             new_embeddings = new_embeddings * (off_ratio**0.1)
-             text_encoder.text_model.embeddings.token_embedding.weight.data[
-                 index_updates
-             ] = new_embeddings
-
-     def load_embeddings(self, file_path: str):
-         with safe_open(file_path, framework="pt", device=self.device.type) as f:
-             for idx in range(len(self.text_encoders)):
-                 text_encoder = self.text_encoders[idx]
-                 tokenizer = self.tokenizers[idx]

-                 loaded_embeddings = f.get_tensor(f"text_encoders_{idx}")
-                 self._load_embeddings(loaded_embeddings, tokenizer, text_encoder)

+ # dataset_and_utils.py - Optimized and Improved Version
  import os
  from typing import Dict, List, Optional, Tuple

  from transformers import AutoTokenizer, PretrainedConfig


+ def prepare_image(image: PIL.Image.Image, width: int = 512, height: int = 512) -> torch.Tensor:
+     """
+     Prepares an image for model input by resizing and normalizing it.
+     """
+     image = image.resize((width, height), resample=Image.BICUBIC, reducing_gap=1)
+     arr = np.array(image.convert("RGB"), dtype=np.float32) / 127.5 - 1
+     return torch.from_numpy(np.transpose(arr, (2, 0, 1))).unsqueeze(0)


+ def prepare_mask(mask: PIL.Image.Image, width: int = 512, height: int = 512) -> torch.Tensor:
+     """
+     Prepares a mask image for model input by resizing and normalizing it.
+     """
+     mask = mask.resize((width, height), resample=Image.BICUBIC, reducing_gap=1)
+     arr = np.array(mask.convert("L"), dtype=np.float32) / 255.0
+     return torch.from_numpy(np.expand_dims(arr, 0)).unsqueeze(0)


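(For orientation: a minimal usage sketch of the two rewritten helpers above. It assumes the file is importable as dataset_and_utils; the image file names are hypothetical and not part of the commit.)

    from PIL import Image
    from dataset_and_utils import prepare_image, prepare_mask

    pixels = prepare_image(Image.open("photo.jpg"))    # (1, 3, 512, 512) float32 tensor, values in [-1, 1]
    mask = prepare_mask(Image.open("photo_mask.png"))  # (1, 1, 512, 512) float32 tensor, values in [0, 1]
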
  class PreprocessedDataset(Dataset):
 
          size: int = 512,
          text_dropout: float = 0.0,
          scale_vae_latents: bool = True,
+         substitute_caption_map: Dict[str, str] = None,
      ):
+         """
+         Dataset class that pre-processes images, masks, and text data for training.
+         """
          super().__init__()
          self.data = pd.read_csv(csv_path)
+         self.size = size
+         self.scale_vae_latents = scale_vae_latents
+         self.text_dropout = text_dropout
          self.csv_path = csv_path
+         self.tokenizer_1 = tokenizer_1
+         self.tokenizer_2 = tokenizer_2
+         self.vae_encoder = vae_encoder
+         self.do_cache = do_cache

+         self.caption = self.data["caption"].str.lower()

+         if substitute_caption_map:
+             for key, value in substitute_caption_map.items():
+                 self.caption = self.caption.str.replace(key.lower(), value)

+         self.image_path = self.data["image_path"]
+         self.mask_path = self.data["mask_path"] if "mask_path" in self.data.columns else None

+         if text_encoder_1:
              self.text_encoder_1 = text_encoder_1
              self.text_encoder_2 = text_encoder_2
              self.return_text_embeddings = True
+             raise NotImplementedError("Preprocessing for text encoder is not implemented yet.")
+         else:
+             self.return_text_embeddings = False

+         if self.do_cache:
              self.vae_latents = []
              self.tokens_tuple = []
              self.masks = []
+             print("Caching dataset...")
              for idx in range(len(self.data)):
                  token, vae_latent, mask = self._process(idx)
                  self.tokens_tuple.append(token)
+                 self.vae_latents.append(vae_latent)
                  self.masks.append(mask)
+             del self.vae_encoder  # Free up memory

  @torch.no_grad()
+     def _process(self, idx: int) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
+         """
+         Internal function to process images, text, and masks for a given index.
+         """
+         image_path = os.path.join(os.path.dirname(self.csv_path), self.image_path[idx])
+         image = prepare_image(Image.open(image_path).convert("RGB"), self.size, self.size).to(
              dtype=self.vae_encoder.dtype, device=self.vae_encoder.device
          )

          caption = self.caption[idx]
+         ti1 = self.tokenizer_1(caption, padding="max_length", max_length=77, truncation=True, return_tensors="pt").input_ids
+         ti2 = self.tokenizer_2(caption, padding="max_length", max_length=77, truncation=True, return_tensors="pt").input_ids

          vae_latent = self.vae_encoder.encode(image).latent_dist.sample()
          if self.scale_vae_latents:
+             vae_latent *= self.vae_encoder.config.scaling_factor

          if self.mask_path is None:
+             mask = torch.ones_like(vae_latent, dtype=self.vae_encoder.dtype, device=self.vae_encoder.device)
          else:
+             mask_path = os.path.join(os.path.dirname(self.csv_path), self.mask_path[idx])
+             mask = prepare_mask(Image.open(mask_path), self.size, self.size).to(
                  dtype=self.vae_encoder.dtype, device=self.vae_encoder.device
              )
+             mask = torch.nn.functional.interpolate(mask, size=(vae_latent.shape[-2], vae_latent.shape[-1]), mode="nearest")
              mask = mask.repeat(1, vae_latent.shape[1], 1, 1)

+         assert mask.shape == vae_latent.shape, "Mask and latent dimensions must match."

          return (ti1.squeeze(), ti2.squeeze()), vae_latent.squeeze(), mask.squeeze()

      def __len__(self) -> int:
          return len(self.data)

+     def __getitem__(self, idx: int) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
+         return self.atidx(idx)

+     def atidx(self, idx: int) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
+         return self._process(idx) if not self.do_cache else (self.tokens_tuple[idx], self.vae_latents[idx], self.masks[idx])


+ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"):
+     """
+     Dynamically imports a model class based on configuration.
+     """
+     config = PretrainedConfig.from_pretrained(pretrained_model_name_or_path, subfolder=subfolder, revision=revision)
+     model_class = config.architectures[0]

      if model_class == "CLIPTextModel":
          from transformers import CLIPTextModel
          return CLIPTextModel
      elif model_class == "CLIPTextModelWithProjection":
          from transformers import CLIPTextModelWithProjection
          return CLIPTextModelWithProjection
      else:
+         raise ValueError(f"Unsupported model class: {model_class}")

  def load_models(pretrained_model_name_or_path, revision, device, weight_dtype):
      """
+     Loads required models from a given pretrained path.
      """
+     tokenizer_1 = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer", revision=revision, use_fast=False)
+     tokenizer_2 = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer_2", revision=revision, use_fast=False)

+     noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")

+     text_encoder_cls_one = import_model_class_from_model_name_or_path(pretrained_model_name_or_path, revision)
+     text_encoder_cls_two = import_model_class_from_model_name_or_path(pretrained_model_name_or_path, revision, subfolder="text_encoder_2")

+     text_encoder_1 = text_encoder_cls_one.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder", revision=revision)
+     text_encoder_2 = text_encoder_cls_two.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder_2", revision=revision)

+     vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae", revision=revision)
+     unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet", revision=revision)

+     for model in [vae, text_encoder_1, text_encoder_2]:
+         model.requires_grad_(False)
+         model.to(device, dtype=weight_dtype)

+     unet.to(device, dtype=weight_dtype)

+     return tokenizer_1, tokenizer_2, noise_scheduler, text_encoder_1, text_encoder_2, vae, unet
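
(For orientation: a minimal sketch of how the refactored pieces are typically wired together after this commit. The model id, CSV path, device, and dtype are hypothetical placeholders, not part of the commit; parameter names follow the signatures visible in the diff.)

    import torch
    from torch.utils.data import DataLoader
    from dataset_and_utils import load_models, PreprocessedDataset

    # Load tokenizers, scheduler, text encoders, VAE, and UNet from one SDXL checkpoint.
    tok1, tok2, scheduler, te1, te2, vae, unet = load_models(
        "stabilityai/stable-diffusion-xl-base-1.0",  # hypothetical model id
        revision=None,
        device="cuda",
        weight_dtype=torch.float16,
    )

    # Cache VAE latents, token ids, and masks up front; text encoders are omitted
    # because text-embedding preprocessing is still unimplemented in this version.
    dataset = PreprocessedDataset(
        csv_path="training_data.csv",  # hypothetical CSV with caption/image_path columns
        tokenizer_1=tok1,
        tokenizer_2=tok2,
        vae_encoder=vae,
        do_cache=True,
    )
    loader = DataLoader(dataset, batch_size=1, shuffle=True)

    (ti1, ti2), latent, mask = dataset[0]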