Update raven_modeling_minimal.py
raven_modeling_minimal.py  (+26 -10)  CHANGED
@@ -11,7 +11,7 @@ from .raven_config_minimal import RavenConfig
 from transformers.cache_utils import Cache, DynamicCache

 ###################### Huggingface Glue code I ##################################################################
-from transformers import PreTrainedModel
+from transformers import PreTrainedModel, GenerationMixin
 from transformers.utils import ModelOutput
 from transformers.generation.utils import GenerateDecoderOnlyOutput

@@ -32,7 +32,8 @@ class RavenPreTrainedModel(PreTrainedModel):
     _supports_static_cache = False

     def _init_weights(self, module):
-        print("Random Initialization not implemented.")
+        if not torch.rand((1,)).is_meta:
+            print("Random Initialization not implemented.")


 @dataclass
@@ -309,7 +310,7 @@ class SandwichBlock(torch.nn.Module):
         return x, attn_map


-class RavenForCausalLM(RavenPreTrainedModel):
+class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
     def __init__(
         self,
         config: RavenConfig,
@@ -367,7 +368,7 @@ class RavenForCausalLM(RavenPreTrainedModel):
             "return_latents": True,
             "return_attention": False,
             "return_head": False,
-            "return_stats":
+            "return_stats": False,
         },
         use_cache: bool = False,
         cache_position: Optional[torch.Tensor] = None,
@@ -395,7 +396,7 @@ class RavenForCausalLM(RavenPreTrainedModel):
         # Non-recurrent prelude
         for block_idx, block in enumerate(self.transformer.prelude):
             input_embeds, attn_map = block(
-                input_embeds, freqs_cis, block_idx, attention_mask, past_key_values, return_attn
+                input_embeds, freqs_cis, block_idx, attention_mask, past_key_values, return_attn=return_attn
             )
             attn_maps[block_idx] = attn_map

@@ -409,12 +410,13 @@ class RavenForCausalLM(RavenPreTrainedModel):
             past_key_values,
             num_steps,
             attn_maps,
+            return_attn=return_attn,
         )
         latent_states = x.clone().detach()

         # Coda layers
         for block_idx, block in enumerate(self.transformer.coda, start=1):
-            x, attn_map = block(x, freqs_cis, -block_idx, attention_mask, past_key_values, return_attn)
+            x, attn_map = block(x, freqs_cis, -block_idx, attention_mask, past_key_values, return_attn=return_attn)
             attn_maps[-block_idx] = attn_map
         x = self.transformer.ln_f(x)

@@ -451,6 +453,7 @@ class RavenForCausalLM(RavenPreTrainedModel):
         past_key_values: Optional[Cache] = None,
         num_steps: Optional[torch.Tensor] = None,
         attn_maps: dict = {},
+        return_attn: bool = False,
     ):
         x = xk = self.initialize_state(input_embeds) if input_states is None else input_states.clone()
         if num_steps is None:
@@ -468,13 +471,13 @@ class RavenForCausalLM(RavenPreTrainedModel):
             for step in range(num_steps_no_grad):
                 xk = x
                 x, block_idx, attn_maps = self.core_block_forward(
-                    xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, attn_maps
+                    xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, attn_maps, return_attn
                 )

         for step in range(num_steps_with_grad):
             xk = x
             x, block_idx, attn_maps = self.core_block_forward(
-                xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, attn_maps
+                xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, attn_maps, return_attn
             )
         return self.transformer.ln_f(x), num_steps_no_grad, num_steps_with_grad, xk.detach(), block_idx, attn_maps

@@ -487,10 +490,11 @@ class RavenForCausalLM(RavenPreTrainedModel):
         past_key_values,
         block_idx: Union[torch.Tensor, int],
         attn_maps: dict = {},
+        return_attn: bool = False,
     ):
         x = self.transformer.adapter(torch.cat([x, input_embeds], dim=-1))
         for idx, block in enumerate(self.transformer.core_block, start=1):
-            x, attn_map = block(x, freqs_cis, block_idx + idx, mask, past_key_values, return_attn=
+            x, attn_map = block(x, freqs_cis, block_idx + idx, mask, past_key_values, return_attn=return_attn)
             attn_maps[block_idx + idx] = attn_map
         return x, block_idx + idx, attn_maps

@@ -623,7 +627,7 @@ class RavenForCausalLM(RavenPreTrainedModel):
         model_inputs["cache_position"] = cache_position
         current_input_length = input_ids.shape[1]
         if past_key_values is not None:
-            if type(past_key_values)
+            if type(past_key_values) != HuginnDynamicCache:
                 # Need to use custom cache, detect and replace HF dynamic cache if generate injects it
                 assert past_key_values.get_seq_length() == 0
                 past_key_values = HuginnDynamicCache()
@@ -643,6 +647,18 @@ class RavenForCausalLM(RavenPreTrainedModel):
             model_inputs[key] = value
         return model_inputs

+    @torch.no_grad()
+    def generate(self, *args, **kwargs):
+        """Dispatcher - use HF generate in all normal cases."""
+        if any(
+            k in kwargs
+            for k in ("continuous_compute", "latent_dampening", "criterion", "exit_threshold", "cache_kwargs")
+        ):
+            print("Dispatching to custom generate function call")
+            return self.generate_with_adaptive_compute(*args, **kwargs)
+        else:
+            return super().generate(*args, **kwargs)
+
     @torch.no_grad()
     def generate_minimal(
         self,
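For reference, here is a minimal usage sketch of the dispatch behavior added in the last hunk. It is not part of the commit: the checkpoint name, the prompt, and the criterion/exit_threshold values are assumptions chosen for illustration; only the keyword names checked by the dispatcher and the generate_with_adaptive_compute target appear in the diff above.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumed checkpoint; the commit itself only touches the modeling code.
model = AutoModelForCausalLM.from_pretrained(
    "tomg-group-umd/huginn-0125", torch_dtype=torch.bfloat16, trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained("tomg-group-umd/huginn-0125", trust_remote_code=True)
input_ids = tokenizer("The capital of Westphalia is", return_tensors="pt").input_ids

# Without any of the sentinel kwargs, RavenForCausalLM.generate falls through to
# GenerationMixin.generate, which is why the class now also inherits from GenerationMixin.
out = model.generate(input_ids, max_new_tokens=32)

# Any of the sentinel kwargs (continuous_compute, latent_dampening, criterion,
# exit_threshold, cache_kwargs) reroutes the call to generate_with_adaptive_compute.
# The values below are illustrative, not taken from the diff.
out = model.generate(input_ids, max_new_tokens=32, criterion="entropy-diff", exit_threshold=1e-3)

Dispatching on these sentinel kwargs keeps the standard Hugging Face generation path untouched for ordinary calls while exposing the adaptive-compute decoding loop through the same generate entry point.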