Spaces:
Sleeping
Sleeping
File size: 3,635 Bytes
3c55139 22ff2b2 3c55139 3b97045 3c55139 3b97045 3c55139 3b97045 3c55139 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 |
import re
from dataclasses import dataclass
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import AutoTokenizer, Qwen2ForCausalLM
from tok.mm_autoencoder import MMAutoEncoder
@dataclass
class T2IConfig:
    # Configuration for Tar text-to-image inference: checkpoint locations,
    # device/dtype placement, and sampling hyperparameters.

    # HF Hub repo id of the Qwen2-based autoregressive LM that emits visual tokens.
    model_path: str = "csuhan/Tar-1.5B"
    # visual tokenizer config
    # NOTE(review): these are bare filenames; main() overrides them with
    # hf_hub_download()-resolved local paths before use.
    ar_path: str = 'ar_dtok_lp_256px.pth'      # AR detokenizer weights
    encoder_path: str = 'ta_tok.pth'           # TA-Tok encoder weights
    decoder_path: str = 'vq_ds16_t2i.pt'       # VQ decoder weights (LlamaGen)
    device: str = "cuda:0"
    dtype: torch.dtype = torch.bfloat16
    # generation parameters
    # scale and seq_len must correspond: larger scale -> shorter token sequence.
    scale: int = 0  # choose from [0, 1, 2]
    seq_len: int = 729  # choose from [729, 169, 81]
    temperature: float = 1.0
    top_p: float = 0.95
    top_k: int = 1200
    cfg_scale: float = 4.0  # classifier-free guidance strength for the visual decoder
class TextToImageInference:
    """Text-to-image pipeline: an autoregressive LM emits visual tokens
    (``<I123>`` markers), which a visual tokenizer decodes into pixels."""

    def __init__(self, config: T2IConfig):
        self.config = config
        self.device = torch.device(config.device)
        self._load_models()

    def _load_models(self):
        """Load the language model, its tokenizer, and the visual tokenizer."""
        self.model = Qwen2ForCausalLM.from_pretrained(
            self.config.model_path, torch_dtype=self.config.dtype
        ).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_path)

        # Initialize visual tokenizer (encoder + AR detokenizer + VQ decoder).
        vt_kwargs = dict(
            ar_path=self.config.ar_path,
            encoder_path=self.config.encoder_path,
            decoder_path=self.config.decoder_path,
            encoder_args={'input_type': 'rec'},
            decoder_args={},
        )
        visual_tokenizer = MMAutoEncoder(**vt_kwargs).eval()
        self.visual_tokenizer = visual_tokenizer.to(
            dtype=self.config.dtype, device=self.device
        )
        # Sequence length and pooling scale must agree with generation settings.
        self.visual_tokenizer.ar_model.cls_token_num = self.config.seq_len
        self.visual_tokenizer.encoder.pool_scale = self.config.scale + 1

    def generate_image(self, prompt: str) -> Image.Image:
        """Generate one image for *prompt* and return it as a PIL image."""
        cfg = self.config

        # Build the chat-formatted prompt and append the image-start marker
        # plus the scale tag so the LM begins emitting visual tokens.
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt},
        ]
        input_text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True)
        input_text += f"<im_start><S{cfg.scale}>"

        # Sample visual tokens from the language model.
        encoded = self.tokenizer(input_text, return_tensors="pt")
        output_ids = self.model.generate(
            encoded.input_ids.to(self.device),
            max_new_tokens=cfg.seq_len,
            do_sample=True,
            temperature=cfg.temperature,
            top_p=cfg.top_p,
            top_k=cfg.top_k)

        # Extract the numeric codes from <I...> markers, then truncate/pad
        # with zeros to exactly seq_len entries.
        decoded_text = self.tokenizer.batch_decode(output_ids)[0]
        codes = [int(m) for m in re.findall(r'<I(\d+)>', decoded_text)]
        codes = codes[:cfg.seq_len]
        codes += [0] * (cfg.seq_len - len(codes))
        code_tensor = torch.tensor(codes).unsqueeze(0).to(self.device)

        # Decode codes to pixels with classifier-free guidance.
        # NOTE(review): assumes decode_from_encoder_indices returns a CPU
        # uint8 HWC batch suitable for Image.fromarray — confirm upstream.
        pixels = self.visual_tokenizer.decode_from_encoder_indices(
            code_tensor,
            {'cfg_scale': cfg.cfg_scale}
        )
        return Image.fromarray(pixels[0].numpy())
def main():
    """Resolve checkpoints from the Hub, generate one sample image, save it."""
    config = T2IConfig()
    # Replace the bare-filename defaults with locally cached Hub downloads.
    config.ar_path = hf_hub_download("csuhan/TA-Tok", "ar_dtok_lp_1024px.pth")
    config.encoder_path = hf_hub_download("csuhan/TA-Tok", "ta_tok.pth")
    config.decoder_path = hf_hub_download("peizesun/llamagen_t2i", "vq_ds16_t2i.pt")

    engine = TextToImageInference(config)
    image = engine.generate_image("A photo of a macaw")
    image.save("generated_image.png")


if __name__ == "__main__":
    main()