HongyuanTao committed on
Commit
1198b4c
·
verified ·
1 Parent(s): 975bc02

Update modeling_mmMamba.py

Browse files
Files changed (1)
  1. modeling_mmMamba.py +4 -2
modeling_mmMamba.py CHANGED
@@ -421,7 +421,7 @@ class MHA_LM(nn.Module):
421
  ):
422
  if self.rotary_emb_dim > 0:
423
  q, kv = self.rotary_emb(
424
- q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
425
  )
426
  if inference_params is None:
427
  k, v = kv.unbind(dim=-3)
@@ -550,7 +550,9 @@ class Mamba2_LM(nn.Module):
550
  conv_state, ssm_state = None, None
551
  if inference_params is not None:
552
  conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
553
-
 
 
554
  if use_cache and inference_params.seqlen_offset==0:
555
  vkq, new_conv_states = causal_conv1d_fn(
556
  vkq.transpose(1, 2),
 
421
  ):
422
  if self.rotary_emb_dim > 0:
423
  q, kv = self.rotary_emb(
424
+ q, kv, seqlen_offset=seqlen_offset[:bsz,...], max_seqlen=rotary_max_seqlen
425
  )
426
  if inference_params is None:
427
  k, v = kv.unbind(dim=-3)
 
550
  conv_state, ssm_state = None, None
551
  if inference_params is not None:
552
  conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
553
+ conv_state = conv_state[:batch, ...]
554
+ ssm_state = ssm_state[:batch, ...]
555
+
556
  if use_cache and inference_params.seqlen_offset==0:
557
  vkq, new_conv_states = causal_conv1d_fn(
558
  vkq.transpose(1, 2),