Update raven_modeling_minimal.py
raven_modeling_minimal.py CHANGED
```diff
@@ -302,12 +302,13 @@ class RavenForCausalLM(RavenPreTrainedModel):
         cache_position: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> CausalLMOutputRecurrentLatents:
+
         if position_ids is None and cache_position is None:
             freqs_cis = self.freqs_cis[:, : input_ids.shape[1]]
         elif position_ids is not None:
-            freqs_cis = self.freqs_cis.index_select(1, position_ids)
+            freqs_cis = self.freqs_cis.index_select(1, position_ids.squeeze())
         elif cache_position is not None:  # support HF format
-            freqs_cis = self.freqs_cis[:, cache_position
+            freqs_cis = self.freqs_cis[:, cache_position]
 
         if input_embeds is None:
             input_embeds = self.transformer.wte(input_ids)
```
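For readers skimming the first hunk: `torch.Tensor.index_select` only accepts a 1-D index tensor, while Hugging Face generation code typically passes `position_ids` with shape `[batch, seq_len]`, which is what the added `.squeeze()` accounts for. A minimal sketch of both branches (the RoPE table shape below is a made-up stand-in, not the model's actual configuration):

```python
import torch

# Hypothetical stand-in for the precomputed table self.freqs_cis;
# only dim 1 (the position axis) matters for this illustration.
freqs_cis = torch.randn(1, 4096, 32, 2)

# HF-style position_ids arrive as [batch, seq_len]:
position_ids = torch.arange(5)[None, :]

# index_select requires a 1-D index, hence the squeeze:
selected = freqs_cis.index_select(1, position_ids.squeeze())
print(selected.shape)  # torch.Size([1, 5, 32, 2])

# The cache_position branch uses plain indexing with an already 1-D
# tensor and picks out the same rows:
cache_position = torch.arange(5)
assert torch.equal(freqs_cis[:, cache_position], selected)
```

Note that the squeeze trick implicitly assumes batch size 1: for `position_ids` of shape `[2, 5]`, `.squeeze()` leaves the tensor 2-D and `index_select` raises an error.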
```diff
@@ -445,10 +446,11 @@ class RavenForCausalLM(RavenPreTrainedModel):
             input_ids = input_ids[:, cache_position]  # type: ignore
         model_inputs["input_ids"] = input_ids.clone(memory_format=torch.contiguous_format)
 
-
-
-
-
+        if cache_position is None:
+            position_ids = torch.arange(current_input_length)[None, :].to(input_ids.device)
+            model_inputs["position_ids"] = position_ids[:, -current_input_length:].clone(
+                memory_format=torch.contiguous_format
+            )  # some form of position_ids is a critical argument for the model to correctly apply rope!
 
         # forward all other entries
         for key, value in kwargs.items():
```
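As for the comment on the block added to `prepare_inputs_for_generation`: during cached decoding the new token sits at an absolute position past the prompt, so the forward pass must select the matching row of the RoPE table rather than defaulting to position 0. A rough sketch of why that matters, using the same made-up table shape:

```python
import torch

# Deterministic dummy RoPE table: every entry in row i holds the value i.
positions = torch.arange(16, dtype=torch.float32)
freqs_cis = positions.view(1, 16, 1, 1).expand(1, 16, 32, 2)

# Prefill: a 5-token prompt uses rows 0..4.
prompt_len = 5
prefill_rope = freqs_cis[:, :prompt_len]
assert prefill_rope.shape == (1, prompt_len, 32, 2)

# Cached decode of the next token: it must be rotated with row 5.
cache_position = torch.tensor([prompt_len])
step_rope = freqs_cis[:, cache_position]

# Without position_ids or cache_position, the first branch in the forward
# pass falls back to freqs_cis[:, :input_ids.shape[1]], i.e. row 0 for a
# single-token step, which is a different rotation entirely.
assert not torch.equal(step_rope, freqs_cis[:, :1])
```

Hence some form of `position_ids` (or `cache_position`) being forwarded from `prepare_inputs_for_generation` is what keeps RoPE aligned across decoding steps.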