JonasGeiping committed
Commit 6dc53cb · verified · 1 Parent(s): fbd6377

Update raven_modeling_minimal.py

Files changed (1):
  1. raven_modeling_minimal.py +17 -13
raven_modeling_minimal.py CHANGED
@@ -1,4 +1,4 @@
-"""Minimal modeling.py file for HF compatibility and funny zero-shot experiments. Use only for inference."""
+"""Minimal modeling.py file for HF compatibility and funny zero-shot experiments. Usability for finetuning not guaranteed"""
 
 import torch
 import math
@@ -289,7 +289,7 @@ class RavenForCausalLM(RavenPreTrainedModel):
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
         labels: Optional[torch.Tensor] = None,
-        num_steps_pair: Optional[torch.Tensor] = None,
+        num_steps: Optional[torch.Tensor] = None,
         past_key_values: Optional[Cache] = None,
         output_details: dict = {
             "return_logits": True,
@@ -302,12 +302,12 @@ class RavenForCausalLM(RavenPreTrainedModel):
         cache_position: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> CausalLMOutputRecurrentLatents:
-
+        # Support multiple position formats:
         if position_ids is None and cache_position is None:
             freqs_cis = self.freqs_cis[:, : input_ids.shape[1]]
         elif position_ids is not None:
             freqs_cis = self.freqs_cis.index_select(1, position_ids.squeeze())
-        elif cache_position is not None:  # support HF format
+        elif cache_position is not None:
             freqs_cis = self.freqs_cis[:, cache_position]
 
         if input_embeds is None:
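Note: the three branches above cover the position formats callers can use to select rotary embeddings. A minimal sketch of the corresponding call patterns, assuming a loaded model and pre-tokenized input_ids (the variable names are illustrative, not part of the commit):

import torch

# 1) Neither position_ids nor cache_position: positions 0..seq_len-1 are assumed.
out = model(input_ids=input_ids)

# 2) Explicit position_ids, e.g. for shifted or re-indexed sequences.
out = model(input_ids=input_ids, position_ids=torch.arange(input_ids.shape[1]).unsqueeze(0))

# 3) cache_position, the format Hugging Face generate() supplies during cached decoding.
out = model(input_ids=input_ids, cache_position=torch.arange(input_ids.shape[1]))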
@@ -331,7 +331,7 @@ class RavenForCausalLM(RavenPreTrainedModel):
             block_idx,
             attention_mask,
             past_key_values,
-            num_steps_pair,
+            num_steps,
         )
         latent_states = x.clone().detach()
 
@@ -371,16 +371,16 @@ class RavenForCausalLM(RavenPreTrainedModel):
         block_idx,
         mask,
         past_key_values: Optional[Cache] = None,
-        num_steps_pair: Optional[torch.Tensor] = None,
+        num_steps: Optional[torch.Tensor] = None,
     ):
         x = xk = self.initialize_state(input_embeds) if input_states is None else input_states.clone()
 
-        if num_steps_pair is None:
+        if num_steps is None:
             num_steps_no_grad, num_steps_with_grad = self.randomized_iteration_sampler()  # type: ignore
-        elif len(num_steps_pair) > 1:
-            num_steps_no_grad, num_steps_with_grad = num_steps_pair
+        elif hasattr(num_steps, "__len__") and len(num_steps) > 1:
+            num_steps_no_grad, num_steps_with_grad = num_steps
         else:
-            num_steps_no_grad, num_steps_with_grad = num_steps_pair, torch.tensor(0)
+            num_steps_no_grad, num_steps_with_grad = num_steps, torch.tensor(0)
 
         with torch.no_grad():
             # ultra annoying in ddp due to
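Note: the added hasattr guard is what makes plain integers legal here; the removed len(num_steps_pair) call would raise a TypeError for an int. A sketch of the three accepted formats with illustrative values (a zero-dimensional tensor would still fail the len() call, so scalars are best passed as Python ints):

import torch

# None: the (no_grad, with_grad) split is drawn from randomized_iteration_sampler().
out = model(input_ids=input_ids)

# A pair: explicit (num_steps_no_grad, num_steps_with_grad).
out = model(input_ids=input_ids, num_steps=torch.tensor([16, 4]))

# A plain int has no __len__, so it takes the scalar branch:
# all recurrence steps run without gradient tracking.
out = model(input_ids=input_ids, num_steps=32)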
@@ -421,13 +421,13 @@ class RavenForCausalLM(RavenPreTrainedModel):
 
         return n.to(dtype=torch.long), k.to(dtype=torch.long)
 
-    def initialize_state(self, input_embeds):
+    def initialize_state(self, input_embeds, deterministic: bool = False):
         x = torch.randn_like(input_embeds)
         std = self.config.init_values["std"]
         torch.nn.init.trunc_normal_(x, mean=0.0, std=std, a=-3 * std, b=3 * std)
         if self.emb_scale != 1:
             x = x * self.emb_scale
-        return x
+        return x if not deterministic else x.zero_()
 
     def prepare_inputs_for_generation(
         self,
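Note: with deterministic=True the freshly sampled state is zeroed in place, so the random draw still happens and the RNG state advances identically in both modes. A small usage sketch (the batch, sequence, and hidden dimensions below are made up for illustration):

import torch

input_embeds = torch.randn(1, 8, 5280)  # hypothetical (batch, seq, hidden) shape
x_random = model.initialize_state(input_embeds)                     # truncated-normal noise, scaled by emb_scale
x_zeros = model.initialize_state(input_embeds, deterministic=True)  # same call, returned zeroed
assert (x_zeros == 0).all()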
@@ -442,7 +442,11 @@ class RavenForCausalLM(RavenPreTrainedModel):
         model_inputs["cache_position"] = cache_position
         current_input_length = input_ids.shape[1]
         if past_key_values is not None:
-            model_inputs["past_key_values"] = past_key_values
+            if type(past_key_values) == DynamicCache:
+                # Need to use custom cache, detect and replace HF dynamic cache if generate injects it
+                assert past_key_values.get_seq_length() == 0
+                past_key_values = HuginnDynamicCache()
+            model_inputs["past_key_values"] = past_key_values if kwargs["use_cache"] else None
         input_ids = input_ids[:, cache_position]  # type: ignore
         model_inputs["input_ids"] = input_ids.clone(memory_format=torch.contiguous_format)
 
 
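Note: Hugging Face generate() can inject a stock DynamicCache, so this branch swaps in the model's own HuginnDynamicCache while the cache is still empty (hence the get_seq_length() == 0 assertion). A hedged sketch of both paths, with illustrative generate() arguments:

from transformers import DynamicCache  # the HF cache class being detected

# Default path: generate() may build a DynamicCache internally;
# prepare_inputs_for_generation replaces it before any tokens are cached.
outputs = model.generate(input_ids, max_new_tokens=32, use_cache=True)

# Explicit path: passing the custom cache up front means no swap is needed.
cache = HuginnDynamicCache()  # defined in raven_modeling_minimal.py
outputs = model.generate(input_ids, max_new_tokens=32, past_key_values=cache, use_cache=True)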