Guanzheng committed on
Commit efe2fd9 · verified · 1 Parent(s): 8aa754a

Update modeling_phi2_clex.py

Files changed (1):
  1. modeling_phi2_clex.py +2 -5
modeling_phi2_clex.py CHANGED

@@ -59,10 +59,7 @@ logger = logging.get_logger(__name__)
 _CHECKPOINT_FOR_DOC = "microsoft/phi-2"
 _CONFIG_FOR_DOC = "CLEXPhiConfig"
 
-PHI_PRETRAINED_MODEL_ARCHIVE_LIST = [
-    "microsoft/phi-2",
-    # See all Phi models at https://huggingface.co/models?filter=phi
-]
+
 
 
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
@@ -663,7 +660,7 @@ PHI_ATTENTION_CLASSES = {
 class PhiDecoderLayer(nn.Module):
     def __init__(self, config: CLEXPhiConfig, layer_idx: int):
         super().__init__()
-        self.self_attn = PhiFlashAttention2(config, layer_idx=layer_idx)
+        self.self_attn = PHI_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
         self.mlp = PhiMLP(config)
         self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.resid_dropout = nn.Dropout(config.resid_pdrop)
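
In substance, the commit drops the unused PHI_PRETRAINED_MODEL_ARCHIVE_LIST constant and stops hard-coding PhiFlashAttention2 in PhiDecoderLayer, instead picking the attention backend out of PHI_ATTENTION_CLASSES keyed by config._attn_implementation. Below is a minimal, self-contained sketch of that dispatch pattern. The stand-in classes and Config here are illustrative only, not this fork's real PhiAttention / PhiFlashAttention2 classes, and the exact keys in this file's PHI_ATTENTION_CLASSES are assumed to follow the upstream transformers convention.

# Sketch of the attention-backend dispatch this commit adopts.
# EagerAttention / FlashAttention2 are hypothetical stand-ins for the
# real Phi attention classes in the file.

class EagerAttention:  # stand-in for PhiAttention
    def __init__(self, config, layer_idx=None):
        self.config = config
        self.layer_idx = layer_idx

class FlashAttention2:  # stand-in for PhiFlashAttention2
    def __init__(self, config, layer_idx=None):
        self.config = config
        self.layer_idx = layer_idx

# Registry mapping the config's _attn_implementation string to a class,
# mirroring the PHI_ATTENTION_CLASSES dict referenced in the diff.
PHI_ATTENTION_CLASSES = {
    "eager": EagerAttention,
    "flash_attention_2": FlashAttention2,
}

class Config:  # stand-in for CLEXPhiConfig; transformers sets this attribute
    _attn_implementation = "eager"

config = Config()
# Before this commit the layer constructed FlashAttention2 unconditionally;
# now it looks the backend up from the config:
attn = PHI_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=0)
print(type(attn).__name__)  # -> EagerAttention

Assuming the fork follows the standard transformers loading path, the practical effect is that a user without flash-attn installed can still construct the model (e.g. by passing attn_implementation="eager" to from_pretrained) instead of failing when PhiFlashAttention2 is instantiated unconditionally.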