import torch
from torch import nn
from transformers import T5EncoderModel
from typing import Optional, Union


class T5TextConditionEncoder(nn.Module):
    """Wraps a T5 encoder and projects its hidden states to the model's context dimension."""

    def __init__(
        self, model_path, context_dim,
        low_cpu_mem_usage: bool = True, device: Optional[str] = None,
        dtype: Union[str, torch.dtype] = torch.float32,
        load_in_4bit: bool = False, load_in_8bit: bool = False,
    ):
        super().__init__()
        # Keep only the encoder stack of the pretrained T5 model; the decoder
        # is not needed for producing text-conditioning embeddings.
        self.encoder = T5EncoderModel.from_pretrained(
            model_path, low_cpu_mem_usage=low_cpu_mem_usage, device_map=device,
            torch_dtype=dtype, load_in_8bit=load_in_8bit, load_in_4bit=load_in_4bit,
        ).encoder
        # Project T5 hidden states (d_model) down/up to the target context dimension.
        self.projection = nn.Sequential(
            nn.Linear(self.encoder.config.d_model, context_dim, bias=False),
            nn.LayerNorm(context_dim)
        )

    def forward(self, model_input):
        embeddings = self.encoder(**model_input).last_hidden_state
        context = self.projection(embeddings)
        if 'attention_mask' in model_input:
            context_mask = model_input['attention_mask']
            # Zero out the projected embeddings at padded positions.
            context[context_mask == 0] = torch.zeros_like(context[context_mask == 0])
            # Trim the batch to the longest number of valid tokens,
            # keeping one extra (padded) position.
            max_seq_length = context_mask.sum(-1).max() + 1
            context = context[:, :max_seq_length]
            context_mask = context_mask[:, :max_seq_length]
        else:
            # No mask provided: treat every position as valid.
            context_mask = torch.ones(*embeddings.shape[:-1], dtype=torch.long, device=embeddings.device)
        return context, context_mask


def get_condition_encoder(conf):
    return T5TextConditionEncoder(**conf)
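

# --- Usage sketch (not part of the original module) ---
# A minimal, hedged example of building the encoder from a config dict and
# encoding one prompt. The checkpoint name "google/flan-t5-base", the
# context_dim of 1024, and the max_length of 128 are illustrative
# assumptions, not values prescribed by this file.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    encoder = get_condition_encoder({
        "model_path": "google/flan-t5-base",
        "context_dim": 1024,
    }).eval()
    tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
    model_input = tokenizer(
        ["a watercolor painting of a fox"],
        padding="max_length", max_length=128,
        truncation=True, return_tensors="pt",
    )
    with torch.no_grad():
        context, context_mask = encoder(model_input)
    # context: (batch, trimmed_seq_len, context_dim);
    # context_mask: (batch, trimmed_seq_len)
    print(context.shape, context_mask.shape)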