File size: 3,635 Bytes
3c55139
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22ff2b2
3c55139
 
 
 
3b97045
3c55139
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b97045
3c55139
 
 
 
 
 
 
 
 
 
3b97045
3c55139
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import re
from dataclasses import dataclass

import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import AutoTokenizer, Qwen2ForCausalLM

from tok.mm_autoencoder import MMAutoEncoder


@dataclass
class T2IConfig:
    model_path: str = "csuhan/Tar-1.5B"
    # visual tokenizer config
    ar_path: str = 'ar_dtok_lp_256px.pth'
    encoder_path: str = 'ta_tok.pth'
    decoder_path: str = 'vq_ds16_t2i.pt'
    
    device: str = "cuda:0"
    dtype: torch.dtype = torch.bfloat16
    # generation parameters
    scale: int = 0  # choose from [0, 1, 2]
    seq_len: int = 729  # choose from [729, 169, 81]
    temperature: float = 1.0
    top_p: float = 0.95
    top_k: int = 1200
    cfg_scale: float = 4.0

class TextToImageInference:
    def __init__(self, config: T2IConfig):
        self.config = config
        self.device = torch.device(config.device)
        self._load_models()
        
    def _load_models(self):
        self.model = Qwen2ForCausalLM.from_pretrained(self.config.model_path, torch_dtype=self.config.dtype).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_path)
        
        # Initialize visual tokenizer
        config = dict(
            ar_path=self.config.ar_path,
            encoder_path=self.config.encoder_path,
            decoder_path=self.config.decoder_path,
            encoder_args={'input_type': 'rec'},
            decoder_args={},
        )
        self.visual_tokenizer = MMAutoEncoder(**config).eval().to(dtype=self.config.dtype, device=self.device)
        self.visual_tokenizer.ar_model.cls_token_num = self.config.seq_len
        self.visual_tokenizer.encoder.pool_scale = self.config.scale + 1

    def generate_image(self, prompt: str) -> Image.Image:
        # Prepare prompt
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt}
        ]
        
        input_text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True)
        input_text += f"<im_start><S{self.config.scale}>"
        
        # Generate tokens
        inputs = self.tokenizer(input_text, return_tensors="pt")
        gen_ids = self.model.generate(
            inputs.input_ids.to(self.device),
            max_new_tokens=self.config.seq_len,
            do_sample=True,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
            top_k=self.config.top_k)
        
        # Process generated tokens
        gen_text = self.tokenizer.batch_decode(gen_ids)[0]
        gen_code = [int(x) for x in re.findall(r'<I(\d+)>', gen_text)]
        gen_code = gen_code[:self.config.seq_len] + [0] * max(0, self.config.seq_len - len(gen_code))
        gen_code = torch.tensor(gen_code).unsqueeze(0).to(self.device)
        
        gen_tensor = self.visual_tokenizer.decode_from_encoder_indices(
            gen_code, 
            {'cfg_scale': self.config.cfg_scale}
        )
        gen_image = Image.fromarray(gen_tensor[0].numpy())
        return gen_image

def main():
    config = T2IConfig()
    config.ar_path = hf_hub_download("csuhan/TA-Tok", "ar_dtok_lp_1024px.pth")
    config.encoder_path = hf_hub_download("csuhan/TA-Tok", "ta_tok.pth")
    config.decoder_path = hf_hub_download("peizesun/llamagen_t2i", "vq_ds16_t2i.pt")
    inference = TextToImageInference(config)
    
    prompt = "A photo of a macaw"
    image = inference.generate_image(prompt)
    image.save("generated_image.png")

if __name__ == "__main__":
    main()