JonasGeiping committed
Commit a7ca3c6 · verified · 1 Parent(s): 1bc54f1

Update raven_modeling_minimal.py

Files changed (1)
  1. raven_modeling_minimal.py +26 -10
raven_modeling_minimal.py CHANGED
@@ -11,7 +11,7 @@ from .raven_config_minimal import RavenConfig
 from transformers.cache_utils import Cache, DynamicCache
 
 ###################### Huggingface Glue code I ##################################################################
-from transformers import PreTrainedModel
+from transformers import PreTrainedModel, GenerationMixin
 from transformers.utils import ModelOutput
 from transformers.generation.utils import GenerateDecoderOnlyOutput
 
@@ -32,7 +32,8 @@ class RavenPreTrainedModel(PreTrainedModel):
     _supports_static_cache = False
 
     def _init_weights(self, module):
-        print("Random Initialization not implemented.")
+        if not torch.rand((1,)).is_meta:
+            print("Random Initialization not implemented.")
 
 
 @dataclass
@@ -309,7 +310,7 @@ class SandwichBlock(torch.nn.Module):
         return x, attn_map
 
 
-class RavenForCausalLM(RavenPreTrainedModel):
+class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
     def __init__(
         self,
         config: RavenConfig,
@@ -367,7 +368,7 @@ class RavenForCausalLM(RavenPreTrainedModel):
             "return_latents": True,
             "return_attention": False,
             "return_head": False,
-            "return_stats": True,
+            "return_stats": False,
         },
         use_cache: bool = False,
         cache_position: Optional[torch.Tensor] = None,
@@ -395,7 +396,7 @@ class RavenForCausalLM(RavenPreTrainedModel):
         # Non-recurrent prelude
         for block_idx, block in enumerate(self.transformer.prelude):
             input_embeds, attn_map = block(
-                input_embeds, freqs_cis, block_idx, attention_mask, past_key_values, return_attn
+                input_embeds, freqs_cis, block_idx, attention_mask, past_key_values, return_attn=return_attn
             )
             attn_maps[block_idx] = attn_map
 
@@ -409,12 +410,13 @@ class RavenForCausalLM(RavenPreTrainedModel):
             past_key_values,
             num_steps,
             attn_maps,
+            return_attn=return_attn,
         )
         latent_states = x.clone().detach()
 
         # Coda layers
         for block_idx, block in enumerate(self.transformer.coda, start=1):
-            x, attn_map = block(x, freqs_cis, -block_idx, attention_mask, past_key_values, return_attn)
+            x, attn_map = block(x, freqs_cis, -block_idx, attention_mask, past_key_values, return_attn=return_attn)
             attn_maps[-block_idx] = attn_map
         x = self.transformer.ln_f(x)
 
@@ -451,6 +453,7 @@ class RavenForCausalLM(RavenPreTrainedModel):
         past_key_values: Optional[Cache] = None,
         num_steps: Optional[torch.Tensor] = None,
         attn_maps: dict = {},
+        return_attn: bool = False,
     ):
         x = xk = self.initialize_state(input_embeds) if input_states is None else input_states.clone()
         if num_steps is None:
@@ -468,13 +471,13 @@ class RavenForCausalLM(RavenPreTrainedModel):
         for step in range(num_steps_no_grad):
             xk = x
             x, block_idx, attn_maps = self.core_block_forward(
-                xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, attn_maps
+                xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, attn_maps, return_attn
             )
 
         for step in range(num_steps_with_grad):
             xk = x
             x, block_idx, attn_maps = self.core_block_forward(
-                xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, attn_maps
+                xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, attn_maps, return_attn
             )
         return self.transformer.ln_f(x), num_steps_no_grad, num_steps_with_grad, xk.detach(), block_idx, attn_maps
 
@@ -487,10 +490,11 @@ class RavenForCausalLM(RavenPreTrainedModel):
         past_key_values,
         block_idx: Union[torch.Tensor, int],
         attn_maps: dict = {},
+        return_attn: bool = False,
     ):
         x = self.transformer.adapter(torch.cat([x, input_embeds], dim=-1))
         for idx, block in enumerate(self.transformer.core_block, start=1):
-            x, attn_map = block(x, freqs_cis, block_idx + idx, mask, past_key_values, return_attn=len(attn_maps) > 0)
+            x, attn_map = block(x, freqs_cis, block_idx + idx, mask, past_key_values, return_attn=return_attn)
             attn_maps[block_idx + idx] = attn_map
         return x, block_idx + idx, attn_maps
 
@@ -623,7 +627,7 @@ class RavenForCausalLM(RavenPreTrainedModel):
         model_inputs["cache_position"] = cache_position
         current_input_length = input_ids.shape[1]
         if past_key_values is not None:
-            if type(past_key_values) == DynamicCache:
+            if type(past_key_values) != HuginnDynamicCache:
                 # Need to use custom cache, detect and replace HF dynamic cache if generate injects it
                 assert past_key_values.get_seq_length() == 0
                 past_key_values = HuginnDynamicCache()
@@ -643,6 +647,18 @@ class RavenForCausalLM(RavenPreTrainedModel):
                 model_inputs[key] = value
         return model_inputs
 
+    @torch.no_grad()
+    def generate(self, *args, **kwargs):
+        """Dispatcher - use HF generate in all normal cases."""
+        if any(
+            k in kwargs
+            for k in ("continuous_compute", "latent_dampening", "criterion", "exit_threshold", "cache_kwargs")
+        ):
+            print("Dispatching to custom generate function call")
+            return self.generate_with_adaptive_compute(*args, **kwargs)
+        else:
+            return super().generate(*args, **kwargs)
+
     @torch.no_grad()
     def generate_minimal(
         self,
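
For context, a minimal usage sketch of the new generate dispatcher (not part of the commit). The checkpoint id, prompt, and the criterion/exit_threshold values are assumptions for illustration; which generation arguments generate_with_adaptive_compute accepts is defined elsewhere in the file and is not shown in this diff.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tomg-group-umd/huginn-0125"  # assumed Hub id for this checkpoint
model = AutoModelForCausalLM.from_pretrained(
    model_id, trust_remote_code=True, torch_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
input_ids = tokenizer("The capital of Westphalia is", return_tensors="pt").input_ids

# No special kwargs: the dispatcher falls through to GenerationMixin.generate,
# which is now available because RavenForCausalLM inherits from GenerationMixin.
out = model.generate(input_ids, max_new_tokens=32)

# Any kwarg named in the dispatcher ("criterion", "exit_threshold",
# "continuous_compute", ...) reroutes the call to generate_with_adaptive_compute.
out_adaptive = model.generate(input_ids, criterion="entropy-diff", exit_threshold="auto")

Both paths return generated token ids; the second one goes through the adaptive-compute path implemented by generate_with_adaptive_compute.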