JonasGeiping committed
Commit fbd6377 · verified · 1 Parent(s): f39aa4c

Update raven_modeling_minimal.py

Files changed (1)
  1. raven_modeling_minimal.py +8 -6
raven_modeling_minimal.py CHANGED
@@ -302,12 +302,13 @@ class RavenForCausalLM(RavenPreTrainedModel):
         cache_position: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> CausalLMOutputRecurrentLatents:
+
         if position_ids is None and cache_position is None:
             freqs_cis = self.freqs_cis[:, : input_ids.shape[1]]
         elif position_ids is not None:
-            freqs_cis = self.freqs_cis.index_select(1, position_ids)
+            freqs_cis = self.freqs_cis.index_select(1, position_ids.squeeze())
         elif cache_position is not None:  # support HF format
-            freqs_cis = self.freqs_cis[:, cache_position : cache_position + 1]
+            freqs_cis = self.freqs_cis[:, cache_position]
 
         if input_embeds is None:
             input_embeds = self.transformer.wte(input_ids)
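
The first hunk fixes the RoPE frequency lookup on both cached-decoding paths: torch.index_select requires a 1-D index, while HF-style position_ids arrive as a 2-D [batch, seq_len] tensor, and cache_position arrives as a 1-D tensor rather than a Python int. A minimal sketch of both failure modes, assuming a freqs_cis of shape [1, max_seq_len, head_dim] and batch size 1 (shapes and values are illustrative, not taken from the repository):

import torch

max_seq_len, head_dim = 16, 8
freqs_cis = torch.randn(1, max_seq_len, head_dim)

# position_ids from generation code is 2-D [batch, seq_len], but
# torch.index_select only accepts a 1-D index, so the old call raised:
position_ids = torch.arange(4)[None, :]  # shape [1, 4]
try:
    freqs_cis.index_select(1, position_ids)
except (RuntimeError, IndexError) as err:
    print("2-D index rejected:", err)

# .squeeze() drops the batch dimension (valid here because batch size is 1):
print(freqs_cis.index_select(1, position_ids.squeeze()).shape)  # [1, 4, 8]

# cache_position from generate() is a 1-D tensor; the old slice
# `cache_position : cache_position + 1` only converts a one-element tensor
# to an int and always returns a single position. Advanced indexing also
# covers multi-token prefill steps:
cache_position = torch.tensor([4, 5, 6])
print(freqs_cis[:, cache_position].shape)  # [1, 3, 8]

Note that .squeeze() assumes a single sequence in the batch; a batched position_ids would need a different lookup.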
@@ -445,10 +446,11 @@ class RavenForCausalLM(RavenPreTrainedModel):
         input_ids = input_ids[:, cache_position]  # type: ignore
         model_inputs["input_ids"] = input_ids.clone(memory_format=torch.contiguous_format)
 
-        position_ids = torch.arange(current_input_length)[None, :].to(input_ids.device)
-        model_inputs["position_ids"] = position_ids[:, -current_input_length:].clone(
-            memory_format=torch.contiguous_format
-        )  # positions_ids is a critical argument for the model to correctly apply rope!
+        if cache_position is None:
+            position_ids = torch.arange(current_input_length)[None, :].to(input_ids.device)
+            model_inputs["position_ids"] = position_ids[:, -current_input_length:].clone(
+                memory_format=torch.contiguous_format
+            )  # some form of position_ids is a critical argument for the model to correctly apply rope!
 
         # forward all other entries
         for key, value in kwargs.items():
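
The second hunk guards the manual position_ids construction in prepare_inputs_for_generation. Because forward() above checks position_ids before cache_position, the old code shadowed the absolute decode position with indices that restart at 0 on every step. A hedged sketch of that interaction, with illustrative values:

import torch

current_input_length = 1             # one new token per decode step
cache_position = torch.tensor([7])   # absolute position supplied by generate()

# Old behavior: position_ids was rebuilt from zero on every call, and the
# position_ids branch in forward() took precedence over cache_position,
# so RoPE was applied at position 0 instead of position 7:
position_ids = torch.arange(current_input_length)[None, :]
print(position_ids[:, -current_input_length:])  # tensor([[0]])

# New behavior: with cache_position set, position_ids is left unset and
# forward() falls through to `freqs_cis[:, cache_position]`, i.e. position 7.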