# slides_generator/Kandinsky-3/kandinsky3/condition_encoders.py
# (file-viewer header residue: nesterus — "moved contents of presentations repo" — commit d90acf0 — 1.58 kB)
import torch
from torch import nn
from transformers import T5EncoderModel
from typing import Optional, Union
class T5TextConditionEncoder(nn.Module):
    """Encode text with a pretrained T5 encoder stack and project it to ``context_dim``.

    The encoder stack is loaded through ``T5EncoderModel.from_pretrained``. It is
    followed by a bias-free ``nn.Linear`` (``d_model -> context_dim``) and an
    ``nn.LayerNorm(context_dim)``.
    """

    def __init__(
        self,
        model_path,
        context_dim,
        low_cpu_mem_usage: bool = True,
        device: Optional[str] = None,
        dtype: Union[str, torch.dtype] = torch.float32,
        load_in_4bit: bool = False,
        load_in_8bit: bool = False,
    ):
        """Load the T5 encoder and build the output projection.

        Args:
            model_path: Checkpoint name or path passed to ``T5EncoderModel.from_pretrained``.
            context_dim: Width of the projected context vectors.
            low_cpu_mem_usage: Forwarded unchanged to ``from_pretrained``.
            device: Forwarded as ``device_map`` to ``from_pretrained``.
            dtype: Forwarded as ``torch_dtype``. This applies to the T5 encoder only;
                the projection layers are created in PyTorch's default dtype.
            load_in_4bit: Forwarded unchanged to ``from_pretrained``.
            load_in_8bit: Forwarded unchanged to ``from_pretrained``.
        """
        super().__init__()
        # Only the inner encoder stack (``.encoder``) is stored on this module.
        # NOTE(review): newer transformers releases prefer
        # ``quantization_config=BitsAndBytesConfig(...)`` over ``load_in_4bit`` /
        # ``load_in_8bit`` kwargs; confirm against the installed version.
        self.encoder = T5EncoderModel.from_pretrained(
            model_path,
            low_cpu_mem_usage=low_cpu_mem_usage,
            device_map=device,
            torch_dtype=dtype,
            load_in_8bit=load_in_8bit,
            load_in_4bit=load_in_4bit,
        ).encoder
        self.projection = nn.Sequential(
            nn.Linear(self.encoder.config.d_model, context_dim, bias=False),
            nn.LayerNorm(context_dim),
        )

    def forward(self, model_input):
        """Encode ``model_input`` and return ``(context, context_mask)``.

        Args:
            model_input: Mapping unpacked as keyword arguments into the T5 encoder
                stack. It is presumably tokenizer output with ``input_ids`` and an
                optional ``attention_mask``; confirm with callers.

        Returns:
            context: Projected hidden states. When ``attention_mask`` is present,
                positions where it is 0 are zeroed. The sequence axis is also cut
                to (max number of attended tokens per row) + 1.
            context_mask: The ``attention_mask`` truncated the same way. When no
                mask is supplied, an all-ones ``torch.long`` tensor matching the
                hidden states' leading dimensions.
        """
        embeddings = self.encoder(**model_input).last_hidden_state
        # The encoder can be loaded in a different dtype (``dtype`` / 4-bit / 8-bit)
        # than the projection, which is created in the default dtype. Aligning the
        # input with the projection weights avoids a dtype-mismatch error in Linear.
        projection_dtype = self.projection[0].weight.dtype
        context = self.projection(embeddings.to(projection_dtype))

        if 'attention_mask' not in model_input:
            context_mask = torch.ones(*embeddings.shape[:-1], dtype=torch.long, device=embeddings.device)
            return context, context_mask

        context_mask = model_input['attention_mask']
        # Zero every masked position. This is out-of-place, so the projection
        # output tensor is not mutated.
        context = context.masked_fill((context_mask == 0).unsqueeze(-1), 0.0)
        # Keep one position beyond the largest per-row token count. The "+1" is
        # preserved from the original behaviour; presumably the downstream model
        # expects it, so confirm before changing. Slicing clamps, so an unpadded
        # batch keeps its full length.
        # NOTE(review): truncating leading positions by a token *count* assumes
        # right-padded masks.
        max_seq_length = int(context_mask.sum(-1).max()) + 1
        context = context[:, :max_seq_length]
        context_mask = context_mask[:, :max_seq_length]
        return context, context_mask
def get_condition_encoder(conf):
    """Instantiate the text condition encoder described by ``conf``.

    Every key of ``conf`` is passed as a keyword argument to
    ``T5TextConditionEncoder``, so the keys must match that constructor's
    parameter names.
    """
    encoder_cls = T5TextConditionEncoder
    return encoder_cls(**conf)