yoonusajwardapiit committed
Commit 49b2bf5
1 Parent(s): 2047d88

Upload app.py

Files changed (1)
  1. app.py +80 -12
app.py CHANGED
@@ -1,33 +1,101 @@
-
 import gradio as gr
 import torch
 import torch.nn as nn

-# Define your custom model class
+# Define your custom model class with detailed layer structures
+class Head(nn.Module):
+    def __init__(self, head_size):
+        super().__init__()
+        self.key = nn.Linear(64, head_size, bias=False)
+        self.query = nn.Linear(64, head_size, bias=False)
+        self.value = nn.Linear(64, head_size, bias=False)
+        self.register_buffer('tril', torch.tril(torch.ones(32, 32)))
+        self.dropout = nn.Dropout(0.1)
+
+    def forward(self, x):
+        B, T, C = x.shape
+        k = self.key(x)
+        q = self.query(x)
+        wei = q @ k.transpose(-2, -1) * C**-0.5
+        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
+        wei = nn.functional.softmax(wei, dim=-1)
+        wei = self.dropout(wei)
+        v = self.value(x)
+        return wei @ v
+
+class MultiHeadAttention(nn.Module):
+    def __init__(self, num_heads, head_size):
+        super().__init__()
+        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
+        self.proj = nn.Linear(64, 64)
+        self.dropout = nn.Dropout(0.1)
+
+    def forward(self, x):
+        out = torch.cat([h(x) for h in self.heads], dim=-1)
+        return self.dropout(self.proj(out))
+
+class FeedForward(nn.Module):
+    def __init__(self, n_embd):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Linear(n_embd, 4 * n_embd),
+            nn.ReLU(),
+            nn.Linear(4 * n_embd, n_embd),
+            nn.Dropout(0.1),
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+class Block(nn.Module):
+    def __init__(self, n_embd, n_head):
+        super().__init__()
+        head_size = n_embd // n_head
+        self.sa = MultiHeadAttention(n_head, head_size)
+        self.ffwd = FeedForward(n_embd)
+        self.ln1 = nn.LayerNorm(n_embd)
+        self.ln2 = nn.LayerNorm(n_embd)
+
+    def forward(self, x):
+        x = x + self.sa(self.ln1(x))
+        x = x + self.ffwd(self.ln2(x))
+        return x
+
 class BigramLanguageModel(nn.Module):
     def __init__(self):
         super().__init__()
-        # Example layers (adjust as needed for your model)
         self.token_embedding_table = nn.Embedding(61, 64)
         self.position_embedding_table = nn.Embedding(32, 64)
-        self.blocks = nn.Sequential(*[nn.Linear(64, 64) for _ in range(4)])
+        self.blocks = nn.Sequential(*[Block(64, n_head=4) for _ in range(4)])
         self.ln_f = nn.LayerNorm(64)
         self.lm_head = nn.Linear(64, 61)

-    def forward(self, idx):
-        # Implement the forward pass
-        pass
+    def forward(self, idx, targets=None):
+        B, T = idx.shape
+        tok_emb = self.token_embedding_table(idx)
+        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))
+        x = tok_emb + pos_emb
+        x = self.blocks(x)
+        x = self.ln_f(x)
+        logits = self.lm_head(x)
+        return logits, None

-    def generate(self, idx, max_new_tokens=250):
-        # Implement the generate method
-        pass
+    def generate(self, idx, max_new_tokens):
+        for _ in range(max_new_tokens):
+            idx_cond = idx[:, -32:]
+            logits, _ = self(idx_cond)
+            logits = logits[:, -1, :]
+            probs = nn.functional.softmax(logits, dim=-1)
+            idx_next = torch.multinomial(probs, num_samples=1)
+            idx = torch.cat((idx, idx_next), dim=1)
+        return idx

-# Load your model
+# Load the model with strict=False to handle missing or unexpected keys
 def load_model():
     model = BigramLanguageModel()
     model_url = "https://huggingface.co/yoonusajwardapiit/triptuner/resolve/main/pytorch_model.bin"
     model_weights = torch.hub.load_state_dict_from_url(model_url, map_location=torch.device('cpu'), weights_only=True)
-    model.load_state_dict(model_weights)
+    model.load_state_dict(model_weights, strict=False)
     model.eval()
     return model
101