qhduan committed
Commit f537b29
Parent: f309276

Upload 2 files

Files changed (2)
  1. config.json +2 -2
  2. modeling_aquila.py +7 -5
config.json CHANGED
@@ -1,12 +1,12 @@
 {
-  "_name_or_path": "../../models/aquila-7b-llama/",
+  "_name_or_path": "qhduan/aquila-7b",
   "architectures": [
     "AquilaForCausalLM"
   ],
   "auto_map": {
     "AutoConfig": "modeling_aquila.LlamaConfig",
     "AutoModel": "modeling_aquila.LlamaModel",
-    "AutoModelForCausalLM": "modeling_aquila.AquilaForCausalLM"
+    "AutoModelForCausalLM": "modeling_aquila.LlamaForCausalLM"
   },
   "bos_token_id": 1,
   "eos_token_id": 2,
modeling_aquila.py CHANGED
@@ -250,12 +250,14 @@ class LlamaAttention(nn.Module):
         key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
         value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         self.freqs_cis = self.freqs_cis.to(hidden_states.device)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, freqs_cis=self.freqs_cis[:query_states.shape[1]])
-        query_states = query_states.transpose(1, 2)
-        key_states = key_states.transpose(1, 2)
-        kv_seq_len = key_states.shape[-2]
+        kv_seq_len = key_states.shape[-3]
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]
+        query_states, key_states = apply_rotary_pos_emb(
+            query_states, key_states, freqs_cis=self.freqs_cis[kv_seq_len-query_states.shape[1]:kv_seq_len]
+        )
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)

         # query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         # key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
@@ -695,7 +697,7 @@ class LlamaModel(LlamaPreTrainedModel):
         )


-class AquilaForCausalLM(LlamaPreTrainedModel):
+class LlamaForCausalLM(LlamaPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.model = LlamaModel(config)
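The attention change fixes rotary position embedding during incremental decoding. The old code always sliced freqs_cis[:q_len], i.e. positions 0..q_len-1, so once a KV cache was in use (q_len is 1 per step) every new token was rotated as if it sat at position 0. The new code first computes the absolute length: key_states.shape[-3] is the sequence dimension before the transpose (key_states is (bsz, q_len, num_heads, head_dim) at that point), plus the cached length from past_key_value, and then slices freqs_cis[kv_seq_len - q_len : kv_seq_len] so the current tokens get the frequencies of their true positions. A standalone sketch of just this index arithmetic, using toy shapes and a stand-in table rather than the repo's apply_rotary_pos_emb:

import torch

max_pos, half_dim = 2048, 64
# stand-in for self.freqs_cis: one row of rotary frequencies per absolute position
freqs_cis = torch.randn(max_pos, half_dim, dtype=torch.complex64)

def rotary_rows(q_len: int, past_len: int) -> torch.Tensor:
    """Rows of freqs_cis that the current q_len tokens should be rotated with."""
    kv_seq_len = q_len + past_len                 # mirrors shape[-3] plus the cache length
    return freqs_cis[kv_seq_len - q_len:kv_seq_len]

# prefill: 10 prompt tokens, empty cache -> positions 0..9 (old and new slicing agree here)
assert torch.equal(rotary_rows(q_len=10, past_len=0), freqs_cis[:10])

# decode step: 1 new token with 10 tokens cached -> position 10
assert torch.equal(rotary_rows(q_len=1, past_len=10), freqs_cis[10:11])
# the old slice, freqs_cis[:1], would have reused position 0 for every generated token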