AlgoAnalytics commited on
Commit
2811206
·
verified ·
1 Parent(s): 92ca8d4

Create modeling_phi3.py

Browse files
Files changed (1) hide show
  1. modeling_phi3.py +26 -0
modeling_phi3.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from transformers.modeling_utils import PreTrainedModel
4
+ from transformers.modeling_outputs import BaseModelOutput
5
+
6
+ class Phi3ForCausalLM(PreTrainedModel):
7
+ config_class = Phi3Config
8
+ base_model_prefix = "phi3"
9
+
10
+ def __init__(self, config):
11
+ super().__init__(config)
12
+ self.hidden_size = config.hidden_size
13
+ self.num_hidden_layers = config.num_hidden_layers
14
+ self.num_attention_heads = config.num_attention_heads
15
+
16
+ self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
17
+ self.layers = nn.ModuleList([nn.TransformerEncoderLayer(config.hidden_size, config.num_attention_heads) for _ in range(config.num_hidden_layers)])
18
+ self.output_layer = nn.Linear(config.hidden_size, config.vocab_size)
19
+
20
+ def forward(self, input_ids):
21
+ embeddings = self.embedding(input_ids)
22
+ hidden_states = embeddings
23
+ for layer in self.layers:
24
+ hidden_states = layer(hidden_states)
25
+ logits = self.output_layer(hidden_states)
26
+ return BaseModelOutput(last_hidden_state=logits)