import torch


def load_text_encoders(args, class_one, class_two):
    text_encoder_one = class_one.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=args.revision,
        variant=args.variant,
    )
    text_encoder_two = class_two.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="text_encoder_2",
        revision=args.revision,
        variant=args.variant,
    )
    return text_encoder_one, text_encoder_two


def tokenize_prompt(tokenizer, prompt, max_sequence_length):
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        return_length=False,
        return_overflowing_tokens=False,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    return text_input_ids


def tokenize_prompt_clip(tokenizer, prompt):
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=77,
        truncation=True,
        return_length=False,
        return_overflowing_tokens=False,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    return text_input_ids


def tokenize_prompt_t5(tokenizer, prompt):
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=512,
        truncation=True,
        return_length=False,
        return_overflowing_tokens=False,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    return text_input_ids


def _encode_prompt_with_t5(
    text_encoder,
    tokenizer,
    max_sequence_length=512,
    prompt=None,
    num_images_per_prompt=1,
    device=None,
    text_input_ids=None,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    if tokenizer is not None:
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
    else:
        if text_input_ids is None:
            raise ValueError("text_input_ids must be provided when the tokenizer is not specified")

    prompt_embeds = text_encoder(text_input_ids.to(device))[0]

    dtype = text_encoder.dtype
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    _, seq_len, _ = prompt_embeds.shape

    # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    return prompt_embeds


def _encode_prompt_with_clip(
    text_encoder,
    tokenizer,
    prompt: str,
    device=None,
    text_input_ids=None,
    num_images_per_prompt: int = 1,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    if tokenizer is not None:
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=77,
            truncation=True,
            return_overflowing_tokens=False,
            return_length=False,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
    else:
        if text_input_ids is None:
            raise ValueError("text_input_ids must be provided when the tokenizer is not specified")

    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)

    # Use pooled output of CLIPTextModel
    prompt_embeds = prompt_embeds.pooler_output
    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)

    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)

    return prompt_embeds


def encode_prompt(
    text_encoders,
    tokenizers,
    prompt: str,
    max_sequence_length,
    device=None,
    num_images_per_prompt: int = 1,
    text_input_ids_list=None,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    dtype = text_encoders[0].dtype

    pooled_prompt_embeds = _encode_prompt_with_clip(
        text_encoder=text_encoders[0],
        tokenizer=tokenizers[0],
        prompt=prompt,
        device=device if device is not None else text_encoders[0].device,
        num_images_per_prompt=num_images_per_prompt,
        text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
    )

    prompt_embeds = _encode_prompt_with_t5(
        text_encoder=text_encoders[1],
        tokenizer=tokenizers[1],
        max_sequence_length=max_sequence_length,
        prompt=prompt,
        num_images_per_prompt=num_images_per_prompt,
        device=device if device is not None else text_encoders[1].device,
        text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
    )

    # dummy (all-zero) position ids for the text tokens
    text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

    return prompt_embeds, pooled_prompt_embeds, text_ids


def encode_token_ids(text_encoders, tokens, accelerator, num_images_per_prompt=1, device=None):
    text_encoder_clip = text_encoders[0]
    text_encoder_t5 = text_encoders[1]
    tokens_clip, tokens_t5 = tokens[0], tokens[1]
    batch_size = tokens_clip.shape[0]

    # default to the accelerator device unless the caller explicitly asked for CPU
    if device != "cpu":
        device = accelerator.device

    # clip
    prompt_embeds = text_encoder_clip(tokens_clip.to(device), output_hidden_states=False)

    # Use pooled output of CLIPTextModel
    prompt_embeds = prompt_embeds.pooler_output
    prompt_embeds = prompt_embeds.to(dtype=text_encoder_clip.dtype, device=accelerator.device)

    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    pooled_prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
    pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=text_encoder_clip.dtype, device=accelerator.device)

    # t5
    prompt_embeds = text_encoder_t5(tokens_t5.to(device))[0]
    dtype = text_encoder_t5.dtype
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=accelerator.device)

    _, seq_len, _ = prompt_embeds.shape

    # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    # dummy (all-zero) position ids for the text tokens
    text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=accelerator.device, dtype=dtype)

    return prompt_embeds, pooled_prompt_embeds, text_ids
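

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original script): shows
# how the helpers above might be wired together for a checkpoint laid out with
# "text_encoder"/"text_encoder_2" subfolders as assumed by load_text_encoders.
# The checkpoint id, the "tokenizer"/"tokenizer_2" subfolder names, and the
# fields on `args` are assumptions; adapt them to the actual training setup.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

    # stand-in for the training script's parsed arguments (hypothetical values)
    args = SimpleNamespace(
        pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev",  # assumed checkpoint id
        revision=None,
        variant=None,
    )

    tokenizer_one = CLIPTokenizer.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
    )
    tokenizer_two = T5TokenizerFast.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision
    )
    text_encoder_one, text_encoder_two = load_text_encoders(args, CLIPTextModel, T5EncoderModel)

    # CLIP provides the pooled embedding, T5 the per-token sequence embedding
    prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
        text_encoders=[text_encoder_one, text_encoder_two],
        tokenizers=[tokenizer_one, tokenizer_two],
        prompt="a photo of a cat",
        max_sequence_length=512,
        device="cpu",
        num_images_per_prompt=1,
    )
    print(prompt_embeds.shape, pooled_prompt_embeds.shape, text_ids.shape)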