import torch
from torch import nn
from transformers import Phi3Config
from transformers.modeling_outputs import BaseModelOutput
from transformers.modeling_utils import PreTrainedModel


class Phi3ForCausalLM(PreTrainedModel):
    """Simplified causal LM: token embedding -> encoder-layer stack -> vocab projection.

    Built from stock ``nn.TransformerEncoderLayer`` blocks, made autoregressive
    by a causal self-attention mask applied in :meth:`forward`.
    """

    # NOTE(review): Phi3Config was referenced but never imported; this uses the
    # upstream transformers class. Swap for a local config module if the
    # project defines its own.
    config_class = Phi3Config
    base_model_prefix = "phi3"

    def __init__(self, config):
        """Build the embedding, encoder layers and output projection.

        Args:
            config: model config; reads ``hidden_size``, ``num_hidden_layers``,
                ``num_attention_heads`` and ``vocab_size``.
        """
        super().__init__(config)
        self.hidden_size = config.hidden_size
        self.num_hidden_layers = config.num_hidden_layers
        self.num_attention_heads = config.num_attention_heads

        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        # batch_first=True because the embeddings are (batch, seq, hidden); with
        # the default (False) dim 0 is taken as the sequence axis and attention
        # would mix tokens across different batch items. The flag does not
        # change parameter names, so existing state dicts still load.
        # NOTE(review): dim_feedforward/dropout are left at the PyTorch
        # defaults rather than taken from the config — confirm this is intended.
        self.layers = nn.ModuleList(
            nn.TransformerEncoderLayer(
                config.hidden_size,
                config.num_attention_heads,
                batch_first=True,
            )
            for _ in range(config.num_hidden_layers)
        )
        self.output_layer = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(self, input_ids):
        """Compute next-token logits for ``input_ids``.

        Args:
            input_ids: integer token ids, shape (batch, seq_len) or unbatched
                (seq_len,).

        Returns:
            BaseModelOutput whose ``last_hidden_state`` holds the vocabulary
            logits, shape (..., seq_len, vocab_size).
        """
        hidden_states = self.embedding(input_ids)

        seq_len = input_ids.size(-1)
        # Boolean causal mask: True above the diagonal marks future positions a
        # query may NOT attend to. Without it the encoder layers are
        # bidirectional and each position can see the tokens it must predict.
        # The diagonal stays unmasked, so no attention row is fully masked.
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=input_ids.device),
            diagonal=1,
        )

        for layer in self.layers:
            hidden_states = layer(hidden_states, src_mask=causal_mask)

        logits = self.output_layer(hidden_states)
        # Kept as BaseModelOutput for backward compatibility with existing
        # callers: ``last_hidden_state`` carries the logits, not hidden states.
        return BaseModelOutput(last_hidden_state=logits)