import torch
from torch import nn
from transformers import T5EncoderModel
from typing import Optional, Union


class T5TextConditionEncoder(nn.Module):
    """Wraps a pretrained T5 encoder and projects its hidden states to the
    conditioning dimension expected by the downstream model."""

    def __init__(
        self,
        model_path,
        context_dim,
        low_cpu_mem_usage: bool = True,
        device: Optional[str] = None,
        dtype: Union[str, torch.dtype] = torch.float32,
        load_in_4bit: bool = False,
        load_in_8bit: bool = False,
    ):
        super().__init__()
        # Load the full T5EncoderModel and keep only its encoder stack.
        self.encoder = T5EncoderModel.from_pretrained(
            model_path,
            low_cpu_mem_usage=low_cpu_mem_usage,
            device_map=device,
            torch_dtype=dtype,
            load_in_8bit=load_in_8bit,
            load_in_4bit=load_in_4bit,
        ).encoder
        # Project T5 hidden states (d_model) to the conditioning dimension.
        self.projection = nn.Sequential(
            nn.Linear(self.encoder.config.d_model, context_dim, bias=False),
            nn.LayerNorm(context_dim),
        )

    def forward(self, model_input):
        embeddings = self.encoder(**model_input).last_hidden_state
        context = self.projection(embeddings)
        if 'attention_mask' in model_input:
            context_mask = model_input['attention_mask']
            # Zero out embeddings at padded positions.
            context[context_mask == 0] = torch.zeros_like(context[context_mask == 0])
            # Trim the padded tail, keeping one extra position beyond the
            # longest sequence in the batch.
            max_seq_length = context_mask.sum(-1).max() + 1
            context = context[:, :max_seq_length]
            context_mask = context_mask[:, :max_seq_length]
        else:
            # No mask provided: treat every position as valid.
            context_mask = torch.ones(*embeddings.shape[:-1], dtype=torch.long, device=embeddings.device)
        return context, context_mask


def get_condition_encoder(conf):
    return T5TextConditionEncoder(**conf)
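

# --- Usage sketch (added for illustration; not part of the original module). ---
# A minimal sketch of how the encoder might be driven from a tokenizer.
# The checkpoint name "google/flan-t5-large", context_dim=1024, and the
# max_length value are assumptions chosen for the example, not values from
# the original code.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model_path = "google/flan-t5-large"  # hypothetical checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    encoder = get_condition_encoder({"model_path": model_path, "context_dim": 1024})

    # Tokenize a prompt; the attention mask lets forward() zero out and trim padding.
    model_input = tokenizer(
        ["a cat sitting on a windowsill"],
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=128,
    )
    with torch.no_grad():
        context, context_mask = encoder(
            {
                "input_ids": model_input["input_ids"],
                "attention_mask": model_input["attention_mask"],
            }
        )
    # context: [batch, trimmed_seq_len, context_dim]; context_mask: [batch, trimmed_seq_len]
    print(context.shape, context_mask.shape)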