Update modeling_mmMamba.py
Browse files - modeling_mmMamba.py +4 -2
modeling_mmMamba.py
CHANGED
@@ -421,7 +421,7 @@ class MHA_LM(nn.Module):
|
|
421 |
):
|
422 |
if self.rotary_emb_dim > 0:
|
423 |
q, kv = self.rotary_emb(
|
424 |
-
q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
|
425 |
)
|
426 |
if inference_params is None:
|
427 |
k, v = kv.unbind(dim=-3)
|
@@ -550,7 +550,9 @@ class Mamba2_LM(nn.Module):
|
|
550 |
conv_state, ssm_state = None, None
|
551 |
if inference_params is not None:
|
552 |
conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
|
553 |
-
|
|
|
|
|
554 |
if use_cache and inference_params.seqlen_offset==0:
|
555 |
vkq, new_conv_states = causal_conv1d_fn(
|
556 |
vkq.transpose(1, 2),
|
|
|
421 |
):
|
422 |
if self.rotary_emb_dim > 0:
|
423 |
q, kv = self.rotary_emb(
|
424 |
+
q, kv, seqlen_offset=seqlen_offset[:bsz,...], max_seqlen=rotary_max_seqlen
|
425 |
)
|
426 |
if inference_params is None:
|
427 |
k, v = kv.unbind(dim=-3)
|
|
|
550 |
conv_state, ssm_state = None, None
|
551 |
if inference_params is not None:
|
552 |
conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
|
553 |
+
conv_state = conv_state[:batch, ...]
|
554 |
+
ssm_state = ssm_state[:batch, ...]
|
555 |
+
|
556 |
if use_cache and inference_params.seqlen_offset==0:
|
557 |
vkq, new_conv_states = causal_conv1d_fn(
|
558 |
vkq.transpose(1, 2),
|