Update raven_modeling_minimal.py

raven_modeling_minimal.py CHANGED (+17 -13)
@@ -1,4 +1,4 @@
-"""Minimal modeling.py file for HF compatibility and funny zero-shot experiments."""
+"""Minimal modeling.py file for HF compatibility and funny zero-shot experiments. Usability for finetuning not guaranteed"""
 
 import torch
 import math
@@ -289,7 +289,7 @@ class RavenForCausalLM(RavenPreTrainedModel):
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
         labels: Optional[torch.Tensor] = None,
-
+        num_steps: Optional[torch.Tensor] = None,
         past_key_values: Optional[Cache] = None,
         output_details: dict = {
             "return_logits": True,
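The new `num_steps` argument exposes the recurrence depth of the model at call time. A minimal usage sketch, assuming `model` is a loaded `RavenForCausalLM` and `input_ids` an already-tokenized batch; the call shapes follow from the dispatch logic in the `iterate_forward` hunk further down and are illustrations, not part of the commit:

```python
out = model(input_ids)                     # depth drawn from randomized_iteration_sampler
out = model(input_ids, num_steps=32)       # 32 recurrent steps, all under no_grad
out = model(input_ids, num_steps=(16, 4))  # 16 no-grad steps, then 4 steps with grad
```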
@@ -302,12 +302,12 @@ class RavenForCausalLM(RavenPreTrainedModel):
         cache_position: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> CausalLMOutputRecurrentLatents:
-
+        # Support multiple position formats:
         if position_ids is None and cache_position is None:
             freqs_cis = self.freqs_cis[:, : input_ids.shape[1]]
         elif position_ids is not None:
             freqs_cis = self.freqs_cis.index_select(1, position_ids.squeeze())
-        elif cache_position is not None:
+        elif cache_position is not None:
             freqs_cis = self.freqs_cis[:, cache_position]
 
         if input_embeds is None:
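The added comment marks the three position formats under which the precomputed rotary table `freqs_cis` is indexed. A self-contained illustration with assumed shapes (the table size and the sequence positions are made up for the example):

```python
import torch

freqs_cis = torch.randn(1, 4096, 32)  # stand-in for the model's precomputed rotary table

prefill = freqs_cis[:, :128]                             # no positions given: first seq_len entries
explicit = freqs_cis.index_select(1, torch.arange(128))  # position_ids given, e.g. packed batches
decode = freqs_cis[:, torch.tensor([128])]               # cache_position: only the new token's slot
```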
@@ -331,7 +331,7 @@ class RavenForCausalLM(RavenPreTrainedModel):
             block_idx,
             attention_mask,
             past_key_values,
-
+            num_steps,
         )
         latent_states = x.clone().detach()
 
@@ -371,16 +371,16 @@ class RavenForCausalLM(RavenPreTrainedModel):
         block_idx,
         mask,
         past_key_values: Optional[Cache] = None,
-
+        num_steps: Optional[torch.Tensor] = None,
     ):
         x = xk = self.initialize_state(input_embeds) if input_states is None else input_states.clone()
 
-        if
+        if num_steps is None:
             num_steps_no_grad, num_steps_with_grad = self.randomized_iteration_sampler()  # type: ignore
-        elif len(
-            num_steps_no_grad, num_steps_with_grad =
+        elif hasattr(num_steps, "__len__") and len(num_steps) > 1:
+            num_steps_no_grad, num_steps_with_grad = num_steps
         else:
-            num_steps_no_grad, num_steps_with_grad =
+            num_steps_no_grad, num_steps_with_grad = num_steps, torch.tensor(0)
 
         with torch.no_grad():
             # ultra annoying in ddp due to
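The branches above accept three forms for `num_steps`: `None` defers to the training-time sampler, a length-2 sequence is taken as an explicit `(no_grad, with_grad)` split, and anything else is treated as a plain step count run entirely under `no_grad`. The `hasattr` guard is what lets Python ints reach the scalar branch. A standalone sketch of the same dispatch (illustrative only; the sampler case is stubbed):

```python
import torch

def resolve_num_steps(num_steps):
    if num_steps is None:
        raise NotImplementedError("the model samples (no_grad, with_grad) itself")
    elif hasattr(num_steps, "__len__") and len(num_steps) > 1:
        no_grad, with_grad = num_steps  # explicit pair
    else:
        no_grad, with_grad = num_steps, torch.tensor(0)  # scalar: all steps without grad
    return no_grad, with_grad

print(resolve_num_steps((16, 4)))               # (16, 4)
print(resolve_num_steps(torch.tensor([8, 2])))  # (tensor(8), tensor(2))
print(resolve_num_steps(32))                    # (32, tensor(0))
```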
@@ -421,13 +421,13 @@ class RavenForCausalLM(RavenPreTrainedModel):
 
         return n.to(dtype=torch.long), k.to(dtype=torch.long)
 
-    def initialize_state(self, input_embeds):
+    def initialize_state(self, input_embeds, deterministic: bool = False):
         x = torch.randn_like(input_embeds)
         std = self.config.init_values["std"]
         torch.nn.init.trunc_normal_(x, mean=0.0, std=std, a=-3 * std, b=3 * std)
         if self.emb_scale != 1:
             x = x * self.emb_scale
-        return x
+        return x if not deterministic else x.zero_()
 
     def prepare_inputs_for_generation(
         self,
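`initialize_state` draws the starting latent state from a truncated normal; with the new `deterministic` flag the freshly drawn state is zeroed in place instead, which makes repeated forward passes on the same input comparable. Hypothetical usage:

```python
x0 = model.initialize_state(input_embeds)                      # random init (default)
x0 = model.initialize_state(input_embeds, deterministic=True)  # all-zero state, reproducible
```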
@@ -442,7 +442,11 @@ class RavenForCausalLM(RavenPreTrainedModel):
         model_inputs["cache_position"] = cache_position
         current_input_length = input_ids.shape[1]
         if past_key_values is not None:
-
+            if type(past_key_values) == DynamicCache:
+                # Need to use custom cache, detect and replace HF dynamic cache if generate injects it
+                assert past_key_values.get_seq_length() == 0
+                past_key_values = HuginnDynamicCache()
+            model_inputs["past_key_values"] = past_key_values if kwargs["use_cache"] else None
             input_ids = input_ids[:, cache_position]  # type: ignore
         model_inputs["input_ids"] = input_ids.clone(memory_format=torch.contiguous_format)
 
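Hugging Face's `generate()` injects a plain `DynamicCache` when no cache is supplied; the new branch detects that case, asserts the cache is still empty, and swaps in the model's own `HuginnDynamicCache` before any key/value pairs are stored. Nothing changes on the caller's side; a hedged sketch:

```python
# The injected DynamicCache is replaced inside prepare_inputs_for_generation,
# so the custom cache does not need to be constructed manually.
out = model.generate(input_ids, max_new_tokens=64, use_cache=True)
```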