Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
from transformers import T5EncoderModel, T5Config | |
from .sd_text_encoder import SDTextEncoder | |
class FluxTextEncoder1(SDTextEncoder):
    """CLIP-style text encoder used as the first FLUX text encoder.

    Reuses the modules built by ``SDTextEncoder`` (token embedding, position
    embeddings, attention mask, encoder stack, final layer norm) and returns
    both per-token embeddings and one pooled embedding per batch row.
    """

    def __init__(self, vocab_size=49408):
        """Build the encoder with the given token vocabulary size."""
        super().__init__(vocab_size=vocab_size)

    def forward(self, input_ids, clip_skip=2):
        """Encode a batch of token ids.

        Args:
            input_ids: Integer token-id tensor; presumably shaped
                (batch, seq_len) — TODO confirm against callers.
            clip_skip: Accepted for signature compatibility only. It does not
                affect the result: the output always comes from the last
                encoder layer followed by the final layer norm.

        Returns:
            ``(embeds, pooled_embeds)`` where ``embeds`` is the final
            layer-normed hidden state and ``pooled_embeds`` holds, for each
            batch row, the hidden state at the position of the largest token
            id in that row.
        """
        embeds = self.token_embedding(input_ids) + self.position_embeds
        # The mask lives on the module; match the activations' device/dtype.
        attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
        for encoder in self.encoders:
            embeds = encoder(embeds, attn_mask=attn_mask)
        embeds = self.final_layer_norm(embeds)
        # Pool at the argmax token id per row. This is presumably the
        # end-of-text token, which has the largest id in the CLIP vocabulary.
        # Confirm against the tokenizer in use.
        pool_positions = input_ids.to(dtype=torch.int).argmax(dim=-1)
        batch_index = torch.arange(embeds.shape[0], device=embeds.device)
        pooled_embeds = embeds[batch_index, pool_positions]
        return embeds, pooled_embeds

    @staticmethod
    def state_dict_converter():
        """Return the converter that renames checkpoint keys for this model.

        Declared static so it works both as ``FluxTextEncoder1.state_dict_converter()``
        and when called on an instance.
        """
        return FluxTextEncoder1StateDictConverter()
class FluxTextEncoder2(T5EncoderModel):
    """T5 encoder used as the second FLUX text encoder.

    Thin wrapper around ``transformers.T5EncoderModel``. It is switched to
    eval mode on construction, and its ``forward`` returns only the last
    hidden state instead of the full model output object.
    """

    def __init__(self, config):
        """Build the encoder from a transformers T5 config and switch to eval mode."""
        super().__init__(config)
        # Inference use: eval() turns off training-mode behavior such as dropout.
        self.eval()

    def forward(self, input_ids):
        """Run the T5 encoder on ``input_ids`` and return ``last_hidden_state``."""
        outputs = super().forward(input_ids=input_ids)
        return outputs.last_hidden_state

    @staticmethod
    def state_dict_converter():
        """Return the converter that maps checkpoint keys for this model.

        Declared static so it works both on the class and on an instance.
        """
        return FluxTextEncoder2StateDictConverter()
class FluxTextEncoder1StateDictConverter:
    """Rename CLIP text-model checkpoint keys to ``FluxTextEncoder1`` parameter names."""

    def __init__(self):
        pass

    @staticmethod
    def _rename_layer_key(key):
        """Map ``text_model.encoder.layers.<i>.<module>.<param>`` to ``encoders.<i>.<new module>.<param>``.

        Raises ``KeyError`` if ``<module>`` is not one of the known sub-modules.
        """
        sub_modules = {
            "self_attn.q_proj": "attn.to_q",
            "self_attn.k_proj": "attn.to_k",
            "self_attn.v_proj": "attn.to_v",
            "self_attn.out_proj": "attn.to_out",
            "layer_norm1": "layer_norm1",
            "layer_norm2": "layer_norm2",
            "mlp.fc1": "fc1",
            "mlp.fc2": "fc2",
        }
        parts = key.split(".")
        layer_index = parts[3]
        module_path = ".".join(parts[4:-1])
        return "encoders.{}.{}.{}".format(layer_index, sub_modules[module_path], parts[-1])

    def from_diffusers(self, state_dict):
        """Return a new dict whose keys are renamed for ``FluxTextEncoder1``.

        The token embedding, position embedding and final layer norm are
        renamed one-to-one, and the position embedding gets a leading axis of
        size 1. Every ``text_model.encoder.layers.`` key is renamed via
        ``_rename_layer_key``. All other keys are dropped. The returned values
        are the original objects, apart from the reshaped position embedding.
        """
        top_level = {
            "text_model.embeddings.token_embedding.weight": "token_embedding.weight",
            "text_model.embeddings.position_embedding.weight": "position_embeds",
            "text_model.final_layer_norm.weight": "final_layer_norm.weight",
            "text_model.final_layer_norm.bias": "final_layer_norm.bias",
        }
        layer_prefix = "text_model.encoder.layers."
        converted = {}
        for key in state_dict:
            new_key = top_level.get(key)
            if new_key is not None:
                value = state_dict[key]
                if new_key == "position_embeds":
                    # (rows, cols) -> (1, rows, cols)
                    value = value.reshape((1, value.shape[0], value.shape[1]))
                converted[new_key] = value
            elif key.startswith(layer_prefix):
                converted[self._rename_layer_key(key)] = state_dict[key]
        return converted

    def from_civitai(self, state_dict):
        """Same conversion as :meth:`from_diffusers`."""
        return self.from_diffusers(state_dict)
class FluxTextEncoder2StateDictConverter:
    """Pass-through converter for ``FluxTextEncoder2`` checkpoints.

    Both entry points return the given dict itself. No copy is made and no
    key is renamed.
    """

    def __init__(self):
        """No state is needed."""

    def from_diffusers(self, state_dict):
        """Return ``state_dict`` unchanged (the very same object)."""
        return state_dict

    def from_civitai(self, state_dict):
        """Same as :meth:`from_diffusers`."""
        return self.from_diffusers(state_dict)