ACE-Plus/models/embedder.py
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import warnings
from contextlib import nullcontext
import torch
import torch.nn.functional as F
import torch.utils.dlpack
import transformers
from scepter.modules.model.embedder.base_embedder import BaseEmbedder
from scepter.modules.model.registry import EMBEDDERS
from scepter.modules.model.tokenizer.tokenizer_component import (
basic_clean, canonicalize, heavy_clean, whitespace_clean)
from scepter.modules.utils.config import dict_to_yaml
from scepter.modules.utils.distribute import we
from scepter.modules.utils.file_system import FS
try:
from transformers import AutoTokenizer, T5EncoderModel
except Exception as e:
warnings.warn(
f'Import transformers error, please deal with this problem: {e}')
@EMBEDDERS.register_class()
class ACETextEmbedder(BaseEmbedder):
"""
Uses the OpenCLIP transformer encoder for text
"""
"""
Uses the OpenCLIP transformer encoder for text
"""
para_dict = {
'PRETRAINED_MODEL': {
'value':
'google/umt5-small',
'description':
'Pretrained Model for umt5, modelcard path or local path.'
},
'TOKENIZER_PATH': {
'value': 'google/umt5-small',
'description':
'Tokenizer Path for umt5, modelcard path or local path.'
},
'FREEZE': {
'value': True,
            'description': 'Whether to freeze the text encoder parameters.'
},
'USE_GRAD': {
'value': False,
'description': 'Compute grad or not.'
},
'CLEAN': {
'value':
'whitespace',
'description':
            'Set the clean strategy for the tokenizer, used when TOKENIZER_PATH is not None.'
},
'LAYER': {
'value': 'last',
            'description': 'Which encoder layer output to use; currently only the last hidden state is returned.'
},
'LEGACY': {
'value':
True,
'description':
            'Whether to use the legacy returned feature or not, default True.'
}
}
def __init__(self, cfg, logger=None):
super().__init__(cfg, logger=logger)
pretrained_path = cfg.get('PRETRAINED_MODEL', None)
self.t5_dtype = cfg.get('T5_DTYPE', 'float32')
assert pretrained_path
with FS.get_dir_to_local_dir(pretrained_path,
wait_finish=True) as local_path:
self.model = T5EncoderModel.from_pretrained(
local_path,
torch_dtype=getattr(
torch,
'float' if self.t5_dtype == 'float32' else self.t5_dtype))
tokenizer_path = cfg.get('TOKENIZER_PATH', None)
self.length = cfg.get('LENGTH', 77)
self.use_grad = cfg.get('USE_GRAD', False)
self.clean = cfg.get('CLEAN', 'whitespace')
self.added_identifier = cfg.get('ADDED_IDENTIFIER', None)
if tokenizer_path:
self.tokenize_kargs = {'return_tensors': 'pt'}
with FS.get_dir_to_local_dir(tokenizer_path,
wait_finish=True) as local_path:
                if self.added_identifier is not None and isinstance(
                        self.added_identifier, list):
                    # register the extra identifiers as additional special tokens
                    # so the tokenizer keeps them intact (mirrors ACEHFEmbedder)
                    self.tokenizer = AutoTokenizer.from_pretrained(
                        local_path,
                        additional_special_tokens=self.added_identifier)
                else:
                    self.tokenizer = AutoTokenizer.from_pretrained(local_path)
if self.length is not None:
self.tokenize_kargs.update({
'padding': 'max_length',
'truncation': True,
'max_length': self.length
})
self.eos_token = self.tokenizer(
self.tokenizer.eos_token)['input_ids'][0]
else:
self.tokenizer = None
self.tokenize_kargs = {}
def freeze(self):
self.model = self.model.eval()
for param in self.parameters():
param.requires_grad = False
# encode && encode_text
def forward(self, tokens, return_mask=False, use_mask=True):
        # text-encoder forward pass (tokenization is done in encode())
embedding_context = nullcontext if self.use_grad else torch.no_grad
with embedding_context():
if use_mask:
x = self.model(tokens.input_ids.to(we.device_id),
tokens.attention_mask.to(we.device_id))
else:
x = self.model(tokens.input_ids.to(we.device_id))
x = x.last_hidden_state
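            # `.detach() + 0.0` yields a fresh tensor cut from the encoder's
            # autograd graph before it is returned to the caller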
if return_mask:
return x.detach() + 0.0, tokens.attention_mask.to(we.device_id)
else:
return x.detach() + 0.0, None
def _clean(self, text):
if self.clean == 'whitespace':
text = whitespace_clean(basic_clean(text))
elif self.clean == 'lower':
text = whitespace_clean(basic_clean(text)).lower()
elif self.clean == 'canonicalize':
text = canonicalize(basic_clean(text))
elif self.clean == 'heavy':
text = heavy_clean(basic_clean(text))
return text
def encode(self, text, return_mask=False, use_mask=True):
if isinstance(text, str):
text = [text]
if self.clean:
text = [self._clean(u) for u in text]
assert self.tokenizer is not None
cont, mask = [], []
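        # autocast is enabled only when a half-precision T5 dtype (float16 /
        # bfloat16) is configured; with float32 the context is a no-op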
with torch.autocast(device_type='cuda',
enabled=self.t5_dtype in ('float16', 'bfloat16'),
dtype=getattr(torch, self.t5_dtype)):
for tt in text:
tokens = self.tokenizer([tt], **self.tokenize_kargs)
one_cont, one_mask = self(tokens,
return_mask=return_mask,
use_mask=use_mask)
cont.append(one_cont)
mask.append(one_mask)
if return_mask:
return torch.cat(cont, dim=0), torch.cat(mask, dim=0)
else:
return torch.cat(cont, dim=0)
def encode_list(self, text_list, return_mask=True):
cont_list = []
mask_list = []
for pp in text_list:
cont, cont_mask = self.encode(pp, return_mask=return_mask)
cont_list.append(cont)
mask_list.append(cont_mask)
if return_mask:
return cont_list, mask_list
else:
return cont_list
@staticmethod
def get_config_template():
return dict_to_yaml('MODELS',
__class__.__name__,
ACETextEmbedder.para_dict,
set_name=True)
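# Hedged usage sketch (illustrative only, not executed): how ACETextEmbedder is
# typically built and called. The scepter Config construction below is an
# assumption about its API; the model card and prompt are placeholders.
#
#   from scepter.modules.utils.config import Config
#   cfg = Config(cfg_dict={
#       'NAME': 'ACETextEmbedder',
#       'PRETRAINED_MODEL': 'google/umt5-small',
#       'TOKENIZER_PATH': 'google/umt5-small',
#       'LENGTH': 77,
#   }, load=False)
#   embedder = ACETextEmbedder(cfg)  # or EMBEDDERS.build(cfg)
#   ctx, mask = embedder.encode('a photo of a cat', return_mask=True)
#   # ctx: [1, LENGTH, hidden_dim], mask: [1, LENGTH]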
@EMBEDDERS.register_class()
class ACEHFEmbedder(BaseEmbedder):
para_dict = {
"HF_MODEL_CLS": {
"value": None,
"description": "huggingface cls in transfomer"
},
"MODEL_PATH": {
"value": None,
"description": "model folder path"
},
"HF_TOKENIZER_CLS": {
"value": None,
"description": "huggingface cls in transfomer"
},
"TOKENIZER_PATH": {
"value": None,
"description": "tokenizer folder path"
},
"MAX_LENGTH": {
"value": 77,
"description": "max length of input"
},
"OUTPUT_KEY": {
"value": "last_hidden_state",
"description": "output key"
},
"D_TYPE": {
"value": "float",
"description": "dtype"
},
"BATCH_INFER": {
"value": False,
"description": "batch infer"
}
}
para_dict.update(BaseEmbedder.para_dict)
def __init__(self, cfg, logger=None):
super().__init__(cfg, logger=logger)
hf_model_cls = cfg.get('HF_MODEL_CLS', None)
model_path = cfg.get("MODEL_PATH", None)
hf_tokenizer_cls = cfg.get('HF_TOKENIZER_CLS', None)
tokenizer_path = cfg.get('TOKENIZER_PATH', None)
self.max_length = cfg.get('MAX_LENGTH', 77)
self.output_key = cfg.get("OUTPUT_KEY", "last_hidden_state")
self.d_type = cfg.get("D_TYPE", "float")
self.clean = cfg.get("CLEAN", "whitespace")
self.batch_infer = cfg.get("BATCH_INFER", False)
self.added_identifier = cfg.get('ADDED_IDENTIFIER', None)
torch_dtype = getattr(torch, self.d_type)
assert hf_model_cls is not None and hf_tokenizer_cls is not None
assert model_path is not None and tokenizer_path is not None
with FS.get_dir_to_local_dir(tokenizer_path, wait_finish=True) as local_path:
self.tokenizer = getattr(transformers, hf_tokenizer_cls).from_pretrained(local_path,
max_length = self.max_length,
torch_dtype = torch_dtype,
additional_special_tokens=self.added_identifier)
with FS.get_dir_to_local_dir(model_path, wait_finish=True) as local_path:
self.hf_module = getattr(transformers, hf_model_cls).from_pretrained(local_path, torch_dtype = torch_dtype)
self.hf_module = self.hf_module.eval().requires_grad_(False)
def forward(self, text: list[str], return_mask = False):
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=False,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)
outputs = self.hf_module(
input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
attention_mask=None,
output_hidden_states=False,
)
if return_mask:
return outputs[self.output_key], batch_encoding['attention_mask'].to(self.hf_module.device)
else:
return outputs[self.output_key], None
def encode(self, text, return_mask = False):
if isinstance(text, str):
text = [text]
if self.clean:
text = [self._clean(u) for u in text]
if not self.batch_infer:
cont, mask = [], []
for tt in text:
one_cont, one_mask = self([tt], return_mask=return_mask)
cont.append(one_cont)
mask.append(one_mask)
if return_mask:
return torch.cat(cont, dim=0), torch.cat(mask, dim=0)
else:
return torch.cat(cont, dim=0)
else:
ret_data = self(text, return_mask = return_mask)
if return_mask:
return ret_data
else:
return ret_data[0]
def encode_list(self, text_list, return_mask=True):
cont_list = []
mask_list = []
for pp in text_list:
cont = self.encode(pp, return_mask=return_mask)
            cont_list.append(cont[0] if return_mask else cont)
            mask_list.append(cont[1] if return_mask else None)
if return_mask:
return cont_list, mask_list
else:
return cont_list
def encode_list_of_list(self, text_list, return_mask=True):
cont_list = []
mask_list = []
for pp in text_list:
cont = self.encode_list(pp, return_mask=return_mask)
            cont_list.append(cont[0] if return_mask else cont)
            mask_list.append(cont[1] if return_mask else None)
if return_mask:
return cont_list, mask_list
else:
return cont_list
def _clean(self, text):
if self.clean == 'whitespace':
text = whitespace_clean(basic_clean(text))
elif self.clean == 'lower':
text = whitespace_clean(basic_clean(text)).lower()
elif self.clean == 'canonicalize':
text = canonicalize(basic_clean(text))
return text
@staticmethod
def get_config_template():
return dict_to_yaml('EMBEDDER',
__class__.__name__,
ACEHFEmbedder.para_dict,
set_name=True)
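# Hedged usage sketch (illustrative only, not executed): ACEHFEmbedder wraps an
# arbitrary transformers model/tokenizer pair by class name. The CLIP paths and
# OUTPUT_KEY below are assumptions for illustration, not shipped defaults.
#
#   cfg = Config(cfg_dict={
#       'NAME': 'ACEHFEmbedder',
#       'HF_MODEL_CLS': 'CLIPTextModel',
#       'MODEL_PATH': 'openai/clip-vit-large-patch14',
#       'HF_TOKENIZER_CLS': 'CLIPTokenizer',
#       'TOKENIZER_PATH': 'openai/clip-vit-large-patch14',
#       'MAX_LENGTH': 77,
#       'OUTPUT_KEY': 'pooler_output',
#   }, load=False)
#   clip_embedder = ACEHFEmbedder(cfg)
#   pooled = clip_embedder.encode('a photo of a cat')  # [1, hidden] for pooler_output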
@EMBEDDERS.register_class()
class T5ACEPlusClipFluxEmbedder(BaseEmbedder):
"""
Uses the OpenCLIP transformer encoder for text
"""
para_dict = {
'T5_MODEL': {},
'CLIP_MODEL': {}
}
def __init__(self, cfg, logger=None):
super().__init__(cfg, logger=logger)
self.t5_model = EMBEDDERS.build(cfg.T5_MODEL, logger=logger)
self.clip_model = EMBEDDERS.build(cfg.CLIP_MODEL, logger=logger)
def encode(self, text, return_mask = False):
t5_embeds = self.t5_model.encode(text, return_mask = return_mask)
clip_embeds = self.clip_model.encode(text, return_mask = return_mask)
# change embedding strategy here
return {
'context': t5_embeds,
'y': clip_embeds,
}
def encode_list(self, text, return_mask = False):
t5_embeds = self.t5_model.encode_list(text, return_mask = return_mask)
clip_embeds = self.clip_model.encode_list(text, return_mask = return_mask)
# change embedding strategy here
return {
'context': t5_embeds,
'y': clip_embeds,
}
def encode_list_of_list(self, text, return_mask = False):
t5_embeds = self.t5_model.encode_list_of_list(text, return_mask = return_mask)
clip_embeds = self.clip_model.encode_list_of_list(text, return_mask = return_mask)
# change embedding strategy here
return {
'context': t5_embeds,
'y': clip_embeds,
}
@staticmethod
def get_config_template():
return dict_to_yaml('EMBEDDER',
__class__.__name__,
T5ACEPlusClipFluxEmbedder.para_dict,
set_name=True)
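# Hedged usage sketch (illustrative only, not executed): the combined embedder
# returns a dict with 'context' (T5 sequence features) and 'y' (CLIP features),
# the two text-conditioning inputs of FLUX-style models. T5_MODEL / CLIP_MODEL
# are nested embedder configs such as the two sketched above; the `...` below
# stand in for those configs.
#
#   cfg = Config(cfg_dict={
#       'NAME': 'T5ACEPlusClipFluxEmbedder',
#       'T5_MODEL': {...},    # e.g. an ACEHFEmbedder config for a T5 encoder
#       'CLIP_MODEL': {...},  # e.g. an ACEHFEmbedder config for a CLIP text encoder
#   }, load=False)
#   embedder = T5ACEPlusClipFluxEmbedder(cfg)
#   cond = embedder.encode('a photo of a cat')
#   # cond['context'] -> T5 token features, cond['y'] -> CLIP features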