import contextlib
from typing import List, Optional, Union

import torch
from transformers import T5EncoderModel, T5Tokenizer


def _get_t5_prompt_embeds(
    tokenizer: Optional[T5Tokenizer],
    text_encoder: T5EncoderModel,
    prompt: Optional[Union[str, List[str]]],
    num_videos_per_prompt: int = 1,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    text_input_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Encode prompts with a T5 encoder and duplicate each one per requested video.

    Args:
        tokenizer: Tokenizer that turns ``prompt`` into token ids. May be ``None``,
            in which case ``text_input_ids`` must be supplied and ``prompt`` is
            not used.
        text_encoder: Encoder module; element ``[0]`` of its output is taken as
            the embeddings (presumably the last hidden state — confirm for
            non-standard encoders).
        prompt: A single prompt or a list of prompts.
        num_videos_per_prompt: Number of copies of each prompt's embedding to emit.
        max_sequence_length: Length every prompt is padded/truncated to when
            ``tokenizer`` is given.
        device: Device the token ids and resulting embeddings are moved to.
        dtype: dtype the resulting embeddings are cast to.
        text_input_ids: Pre-tokenized ids. Only used when ``tokenizer`` is
            ``None``; otherwise they are replaced by the tokenizer's output.

    Returns:
        Tensor of shape ``(batch * num_videos_per_prompt, seq_len, hidden)`` in
        which the copies belonging to the same prompt are adjacent.

    Raises:
        ValueError: If ``tokenizer`` is ``None`` and ``text_input_ids`` is not given.
    """
    prompt = [prompt] if isinstance(prompt, str) else prompt

    if tokenizer is not None:
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
    elif text_input_ids is None:
        raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.")

    prompt_embeds = text_encoder(text_input_ids.to(device))[0]
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    # Take the batch size from the encoder output rather than from ``prompt``.
    # With pre-tokenized ``text_input_ids``, ``prompt`` may be None (len() would
    # raise) or may disagree with the ids' batch dimension, in which case the
    # ``view`` below could silently fold extra rows into the hidden dimension.
    batch_size, seq_len, _ = prompt_embeds.shape

    # duplicate text embeddings for each generation per prompt, using mps friendly method:
    # (B, S, D) -> repeat along the sequence axis -> (B, N*S, D) -> view (B*N, S, D),
    # so the N copies of each prompt end up adjacent along the batch axis.
    prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
    return prompt_embeds


def encode_prompt(
    tokenizer: Optional[T5Tokenizer],
    text_encoder: T5EncoderModel,
    prompt: Optional[Union[str, List[str]]],
    num_videos_per_prompt: int = 1,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    text_input_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Return T5 prompt embeddings; thin public wrapper over ``_get_t5_prompt_embeds``.

    All arguments are forwarded unchanged. ``_get_t5_prompt_embeds`` itself
    wraps a single string prompt into a list, so no normalization is needed here.

    Returns:
        Tensor of shape ``(batch * num_videos_per_prompt, seq_len, hidden)``.

    Raises:
        ValueError: If ``tokenizer`` is ``None`` and ``text_input_ids`` is not given.
    """
    return _get_t5_prompt_embeds(
        tokenizer,
        text_encoder,
        prompt=prompt,
        num_videos_per_prompt=num_videos_per_prompt,
        max_sequence_length=max_sequence_length,
        device=device,
        dtype=dtype,
        text_input_ids=text_input_ids,
    )


def compute_prompt_embeddings(
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    prompt: Union[str, List[str]],
    max_sequence_length: int,
    device: torch.device,
    dtype: torch.dtype,
    requires_grad: bool = False,
) -> torch.Tensor:
    """Compute one embedding per prompt, optionally without autograd tracking.

    Args:
        tokenizer: Tokenizer used to tokenize ``prompt``.
        text_encoder: T5 encoder producing the embeddings.
        prompt: Prompt text (a single string or a list of strings).
        max_sequence_length: Length prompts are padded/truncated to.
        device: Device the embeddings are placed on.
        dtype: dtype the embeddings are cast to.
        requires_grad: When ``False`` (default), encoding runs under
            ``torch.no_grad()``. When ``True``, the caller's current grad mode
            is left as-is, so gradients can flow through the encoder if enabled.

    Returns:
        Tensor of shape ``(batch, seq_len, hidden)`` (one copy per prompt).
    """
    # nullcontext() leaves the ambient grad mode untouched, matching the
    # behavior of calling encode_prompt with no context manager at all.
    grad_context = contextlib.nullcontext() if requires_grad else torch.no_grad()
    with grad_context:
        prompt_embeds = encode_prompt(
            tokenizer,
            text_encoder,
            prompt,
            num_videos_per_prompt=1,
            max_sequence_length=max_sequence_length,
            device=device,
            dtype=dtype,
        )
    return prompt_embeds