jatingocodeo committed · Commit 1a2e215 · verified · 1 Parent(s): c88c76b

Update app.py

Files changed (1): app.py (+66, -0)
app.py CHANGED
@@ -53,6 +53,72 @@ class SmolLM2Config(PretrainedConfig):
 from transformers import AutoConfig
 AutoConfig.register("smollm2", SmolLM2Config)
 
+class RMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-5):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.eps = eps
+
+    def forward(self, x):
+        # RMS normalization: scale by 1/sqrt(mean(x^2) + eps), no mean-centering
+        variance = x.pow(2).mean(-1, keepdim=True)
+        x = x * torch.rsqrt(variance + self.eps)
+        return self.weight * x
+
+class LlamaDecoderLayer(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = config.hidden_size // config.num_attention_heads
+
+        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
+        self.k_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
+        self.v_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
+        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
+
+        self.mlp = nn.Sequential(
+            nn.Linear(config.hidden_size, config.intermediate_size, bias=False),
+            nn.SiLU(),
+            nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
+        )
+
+        self.input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps)
+
+    def forward(self, hidden_states, attention_mask=None):
+        # Self-attention block (pre-norm residual)
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # Project and split into (batch, heads, seq, head_dim)
+        batch_size, seq_length, _ = hidden_states.size()
+        q = self.q_proj(hidden_states).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
+        k = self.k_proj(hidden_states).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
+        v = self.v_proj(hidden_states).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
+
+        # Scaled dot-product attention scores
+        scale = 1.0 / math.sqrt(self.head_dim)
+        scores = torch.matmul(q, k.transpose(-2, -1)) * scale
+
+        if attention_mask is not None:
+            scores = scores + attention_mask  # additive mask: large negative where attention is disallowed
+
+        attn_weights = F.softmax(scores, dim=-1)
+        hidden_states = torch.matmul(attn_weights, v)
+
+        # Merge heads back to (batch, seq, hidden_size) and project out
+        hidden_states = hidden_states.transpose(1, 2).contiguous().view(batch_size, seq_length, -1)
+        hidden_states = self.o_proj(hidden_states)
+        hidden_states = residual + hidden_states
+
+        # MLP block (pre-norm residual)
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        return hidden_states
+
 class SmolLM2ForCausalLM(PreTrainedModel):
     config_class = SmolLM2Config
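
For reference, a minimal smoke test for the two new classes (not part of this commit; a hypothetical sketch). It assumes app.py already imports math, torch, torch.nn as nn, and torch.nn.functional as F, since the added code uses all four but the hunk imports none of them. The SimpleNamespace below is only a stand-in for a real SmolLM2Config, carrying just the attributes the layer reads; the field values are made up.

import torch
from types import SimpleNamespace

# Hypothetical stand-in for SmolLM2Config (attribute names taken from the diff).
config = SimpleNamespace(
    hidden_size=64,
    num_attention_heads=4,
    intermediate_size=256,
    rms_norm_eps=1e-5,
)
layer = LlamaDecoderLayer(config)

batch_size, seq_length = 2, 8
x = torch.randn(batch_size, seq_length, config.hidden_size)

# The forward adds the mask directly to the raw scores, so it must be an
# *additive* mask broadcastable to (batch, heads, seq, seq): 0 where
# attention is allowed, -inf (or a large negative value) where it is not.
causal_mask = torch.triu(
    torch.full((seq_length, seq_length), float("-inf")), diagonal=1
)

out = layer(x, attention_mask=causal_mask)
assert out.shape == (batch_size, seq_length, config.hidden_size)

Two things worth noting about the added layer: the MLP is a plain linear-SiLU-linear stack rather than LLaMA's gated (SwiGLU) MLP, and no positional encoding is applied to q and k, so any rotary embedding used elsewhere in app.py would have to be wired in here.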