Jiaming Han committed on
Commit 3b97045 · 1 Parent(s): 22ff2b2
Files changed (1)
  1. t2i_inference.py +3 -2
t2i_inference.py CHANGED
@@ -18,6 +18,7 @@ class T2IConfig:
     decoder_path: str = 'vq_ds16_t2i.pt'
 
     device: str = "cuda:0"
+    dtype: torch.dtype = torch.bfloat16
     # generation parameters
     scale: int = 0 # choose from [0, 1, 2]
     seq_len: int = 729 # choose from [729, 169, 81]
@@ -33,7 +34,7 @@ class TextToImageInference:
         self._load_models()
 
     def _load_models(self):
-        self.model = Qwen2ForCausalLM.from_pretrained(self.config.model_path).to(self.device)
+        self.model = Qwen2ForCausalLM.from_pretrained(self.config.model_path, torch_dtype=self.config.dtype).to(self.device)
         self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_path)
 
         # Initialize visual tokenizer
@@ -44,7 +45,7 @@ class TextToImageInference:
             encoder_args={'input_type': 'rec'},
             decoder_args={},
         )
-        self.visual_tokenizer = MMAutoEncoder(**config).eval().to(self.device)
+        self.visual_tokenizer = MMAutoEncoder(**config).eval().to(dtype=self.config.dtype, device=self.device)
         self.visual_tokenizer.ar_model.cls_token_num = self.config.seq_len
         self.visual_tokenizer.encoder.pool_scale = self.config.scale + 1
 
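For context, a minimal usage sketch of the new dtype field follows. It assumes T2IConfig is a dataclass whose remaining fields (e.g. model_path) have usable defaults and that TextToImageInference accepts the config in its constructor; neither signature appears in this diff, so treat the call as illustrative rather than as the repository's documented API.

import torch

from t2i_inference import T2IConfig, TextToImageInference

# Prefer bfloat16 where the GPU supports it; otherwise fall back to float16.
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

config = T2IConfig(
    device="cuda:0",
    dtype=dtype,   # field introduced by this commit
    scale=0,       # choose from [0, 1, 2]
    seq_len=729,   # choose from [729, 169, 81]
)

# Assumed constructor: the class stores the config and calls self._load_models(),
# which now loads both the language model and the visual tokenizer in config.dtype.
inference = TextToImageInference(config)

Loading the Qwen2 language model and the MMAutoEncoder in bfloat16 roughly halves their memory footprint relative to float32 and keeps the two modules in a matching dtype, which is why the commit threads a single config.dtype through both loads in _load_models.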