fix(phi-1_5): Checks length of `attention_mask` if it is passed as direct tensor.
modeling_mixformer_sequential.py (CHANGED)
```diff
@@ -35,7 +35,7 @@ from __future__ import annotations
 
 import math
 import copy
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple, Union
 from dataclasses import dataclass, field
 
 import torch
@@ -541,8 +541,8 @@ class MHA(nn.Module):
         kv = update_kv_cache(qkv[:, :, 1:], past_key_values, self.layer_idx)
 
         if attention_mask is not None:
-            attention_mask
-            attention_mask = attention_mask.to(qkv.device)
+            attention_mask = attention_mask[0] if isinstance(attention_mask, tuple) else attention_mask
+            attention_mask = attention_mask.bool().to(qkv.device)
 
         attention_kwargs = {"attention_mask": attention_mask}
 
```
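For reference, the guard added above can be sketched in isolation. This is a minimal sketch, not code from the model file: `normalize_attention_mask` is a hypothetical helper name, and the snippet only assumes what the diff shows, namely that callers may hand `MHA` the mask either as a direct tensor or wrapped in a tuple, and that the layer wants a boolean mask on the same device as the QKV tensor. (The `Union` import added in the first hunk presumably widens the corresponding type annotation elsewhere in the file.)

```python
import torch

def normalize_attention_mask(attention_mask, device):
    """Hypothetical standalone version of the guard added in this commit."""
    if attention_mask is None:
        return None
    # Some call sites wrap the mask in a tuple; unwrap to the tensor itself.
    attention_mask = attention_mask[0] if isinstance(attention_mask, tuple) else attention_mask
    # Cast to bool and move to the device the attention inputs live on.
    return attention_mask.bool().to(device)

# Both call styles now yield the same boolean mask:
mask = torch.ones(1, 8, dtype=torch.long)
direct = normalize_attention_mask(mask, "cpu")
wrapped = normalize_attention_mask((mask,), "cpu")
assert torch.equal(direct, wrapped) and direct.dtype == torch.bool
```

Before this change, the mask was moved to the device without the tuple unwrap or the `bool` cast, so a tuple-wrapped mask would fail on the `.to()` call and a direct long tensor kept its integer dtype; the removed bare `attention_mask` statement was a no-op.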