|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import re
from dataclasses import dataclass
from typing import Optional

import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import AutoTokenizer, Qwen2ForCausalLM

from tok.mm_autoencoder import MMAutoEncoder
|
|
|
|
|
@dataclass
class T2IConfig:
    """Settings for the text-to-image pipeline.

    Holds the language-model location, the visual de-tokenizer checkpoints,
    the target device/dtype, and sampling/guidance settings.
    """

    # Hugging Face repo id (or local directory) of the Qwen2 causal LM.
    model_path: str = "ByteDance-Seed/Tar-1.5B"

    # Visual de-tokenizer checkpoints, forwarded to MMAutoEncoder.
    # ``ar_path`` must carry an annotation: without one, @dataclass treats it
    # as a plain class attribute, so it could not be set via the constructor.
    # Presumably a mapping of resolution -> AR checkpoint path (it is passed
    # as ``ar_path_dict``) — TODO confirm against MMAutoEncoder.
    ar_path: Optional[dict] = None
    encoder_path: str = 'ta_tok.pth'
    decoder_path: str = 'vq_ds16_t2i.pt'

    # Where the models live and which floating-point dtype they use.
    device: str = "cuda:0"
    dtype: torch.dtype = torch.bfloat16

    # Index used in the "<S{scale}>" prompt token.
    scale: int = 0
    # Number of image codes sampled from the LM and fed to the de-tokenizer.
    seq_len: int = 729
    # Sampling / guidance settings.
    temperature: float = 1.0
    top_p: float = 0.95
    top_k: int = 1200
    cfg_scale: float = 4.0
|
|
|
class TextToImageInference:
    """Text-to-image pipeline: an LLM samples image codes, a de-tokenizer renders them.

    The Qwen2 language model is prompted with a chat template followed by
    ``<im_start><S{scale}>``; the ``<I{n}>`` tokens it emits are parsed into
    integer codes and passed to ``MMAutoEncoder.decode_from_encoder_indices``.
    """

    # Matches one image-code token emitted by the language model, e.g. "<I42>".
    _IMAGE_CODE_RE = re.compile(r'<I(\d+)>')

    def __init__(self, config: T2IConfig):
        """Store ``config`` and load all models onto ``config.device``.

        Args:
            config: Model paths, device/dtype and sampling settings.
        """
        self.config = config
        self.device = torch.device(config.device)
        self._load_models()

    def _load_models(self):
        """Load the language model, its tokenizer and the visual de-tokenizer."""
        self.model = Qwen2ForCausalLM.from_pretrained(
            self.config.model_path, torch_dtype=self.config.dtype
        ).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_path)

        # Named vt_kwargs (not ``config``) to avoid confusion with self.config.
        vt_kwargs = dict(
            ar_path_dict=self.config.ar_path,
            encoder_path=self.config.encoder_path,
            decoder_path=self.config.decoder_path,
            encoder_args={'input_type': 'rec'},
            decoder_args={},
        )
        self.visual_tokenizer = MMAutoEncoder(**vt_kwargs).eval().to(
            dtype=self.config.dtype, device=self.device
        )
        # Set every AR de-tokenizer's cls_token_num to seq_len — presumably the
        # number of conditioning codes it consumes, which has to match the
        # fixed-length code sequence built in generate_image (TODO confirm).
        for ar_model in self.visual_tokenizer.ar_model.values():
            ar_model.cls_token_num = self.config.seq_len
        # NOTE(review): presumably selects the pooling factor that matches the
        # <S{scale}> prompt token; confirm against MMAutoEncoder.
        self.visual_tokenizer.encoder.pool_scale = self.config.scale + 1

    def _build_input_text(self, prompt):
        """Return the chat-formatted prompt ending in the image-start prefix."""
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt},
        ]
        input_text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True)
        # Prime the model to start emitting image tokens at the configured scale.
        return input_text + f"<im_start><S{self.config.scale}>"

    def _extract_image_codes(self, text):
        """Parse ``<I{n}>`` tokens in ``text`` into exactly ``config.seq_len`` ints.

        Surplus codes are dropped; missing ones are padded with 0 so the
        de-tokenizer always receives a fixed-length sequence.
        NOTE(review): 0 is presumably a valid codebook index, so a short
        generation is silently padded with real image content.
        """
        seq_len = self.config.seq_len
        codes = [int(m) for m in self._IMAGE_CODE_RE.findall(text)][:seq_len]
        return codes + [0] * (seq_len - len(codes))

    def generate_image(self, prompt, resolution, top_p=None, top_k=None,
                       cfg_scale=None) -> Image.Image:
        """Generate one image for ``prompt``.

        Args:
            prompt: User text describing the image.
            resolution: Forwarded to the de-tokenizer under the
                ``'resolution'`` key; presumably selects one of the AR models
                in ``config.ar_path`` (TODO confirm).
            top_p: Nucleus-sampling threshold; ``None`` uses ``config.top_p``.
            top_k: Top-k sampling cutoff; ``None`` uses ``config.top_k``.
            cfg_scale: Guidance scale forwarded to the de-tokenizer; ``None``
                uses ``config.cfg_scale``.

        Returns:
            A ``PIL.Image.Image`` built from the first decoded sample.
        """
        top_p = self.config.top_p if top_p is None else top_p
        top_k = self.config.top_k if top_k is None else top_k
        cfg_scale = self.config.cfg_scale if cfg_scale is None else cfg_scale

        inputs = self.tokenizer(self._build_input_text(prompt), return_tensors="pt")
        input_ids = inputs.input_ids.to(self.device)
        # Forward the attention mask when the tokenizer provides one, so
        # generate() does not have to guess it.
        attention_mask = inputs.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)

        gen_ids = self.model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=self.config.seq_len,
            do_sample=True,
            temperature=self.config.temperature,
            top_p=top_p,
            top_k=top_k)

        # For decoder-only models generate() returns prompt + continuation.
        # Decode only the continuation so "<I..>" text typed in the user
        # prompt cannot be mistaken for image codes.
        new_text = self.tokenizer.batch_decode(gen_ids[:, input_ids.shape[1]:])[0]
        gen_code = torch.tensor(self._extract_image_codes(new_text)).unsqueeze(0).to(self.device)

        # Pure inference: no autograd graph is needed for decoding.
        with torch.no_grad():
            gen_tensor = self.visual_tokenizer.decode_from_encoder_indices(
                gen_code,
                {'cfg_scale': cfg_scale, 'resolution': resolution},
            )
        # .cpu() returns the same tensor when it is already on the CPU and
        # makes .numpy() safe otherwise.
        return Image.fromarray(gen_tensor[0].cpu().numpy())