chong.zhang commited on
Commit
8cf7229
·
1 Parent(s): a3a0c9b
inspiremusic/transformer/qwen_encoder.py CHANGED
@@ -163,7 +163,7 @@ class QwenInputOnlyEncoder(nn.Module):
163
  else:
164
  self.dtype = torch.float32
165
  from transformers import Qwen2ForCausalLM
166
- model = Qwen2ForCausalLM.from_pretrained(pretrain_path, device_map="auto", attn_implementation="flash_attention_2", torch_dtype=self.dtype)
167
  self.embed = model.model.embed_tokens
168
  for p in self.embed.parameters():
169
  p.requires_grad = False
 
163
  else:
164
  self.dtype = torch.float32
165
  from transformers import Qwen2ForCausalLM
166
+ model = Qwen2ForCausalLM.from_pretrained(pretrain_path, device_map="cpu")
167
  self.embed = model.model.embed_tokens
168
  for p in self.embed.parameters():
169
  p.requires_grad = False