AkashDataScience committed
Commit 3821e91 · 1 Parent(s): 83a075b

First commit

Files changed (5)
  1. app.py +58 -0
  2. input.txt +0 -0
  3. model.py +149 -0
  4. nanogpt.pth +3 -0
  5. requirements.txt +81 -0
app.py ADDED
@@ -0,0 +1,58 @@
+ import torch
+ import gradio as gr
+ from model import BigramLanguageModel
+
+ cuda = torch.cuda.is_available()
+ device = 'cuda' if cuda else 'cpu'
+
+ model = BigramLanguageModel().to(device)
+ model.load_state_dict(torch.load("nanogpt.pth", map_location=torch.device(device)), strict=False)
+
+ # read the text file
+ with open('input.txt', 'r', encoding='utf-8') as f:
+     text = f.read()
+
+ # collect all the unique characters that occur in this text
+ chars = sorted(list(set(text)))
+ vocab_size = len(chars)
+
+ # create a mapping between characters and integer token ids
+ stoi = { ch:i for i,ch in enumerate(chars) }
+ itos = { i:ch for i,ch in enumerate(chars) }
+ encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
+ decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string
+
+ def inference(input_text, max_new_tokens=500):
+     context = torch.tensor([encode(input_text)], dtype=torch.long, device=device) # (1, T): batch of one
+
+     output_text = decode(model.generate(context, max_new_tokens=max_new_tokens)[0].tolist())
+
+     return output_text
+
+ title = "NanoGPT trained on Shakespeare Plays dataset"
+ description = "A simple Gradio interface to generate text from a GPT model trained on Shakespeare plays"
+ examples = [["Shape", 500],
+             ["Answer", 500],
+             ["Ideology", 500],
+             ["Absorb", 500],
+             ["Triangle", 500],
+             ["Listen", 500],
+             ["Census", 500],
+             ["Balance", 500],
+             ["Representative", 500],
+             ["Cinema", 500],
+            ]
+ demo = gr.Interface(
+     inference,
+     inputs = [
+         gr.Textbox(label="Enter any word", type="text"),
+         gr.Slider(minimum=100, maximum=10000, step=100, value=500, label="Max characters to generate")
+     ],
+     outputs = [
+         gr.Textbox(label="Output", type="text")
+     ],
+     title = title,
+     description = description,
+     examples = examples,
+ )
+ demo.launch()
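For context, the encode/decode lambdas in app.py form a character-level tokenizer whose vocabulary is simply every distinct character in input.txt. A minimal sketch of the round trip, assuming input.txt from this commit is in the working directory (the prompt string below is only an illustration):

# Sketch of the character-level tokenizer that app.py builds from input.txt.
# Assumes input.txt from this commit is present; the prompt string is arbitrary.
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

chars = sorted(list(set(text)))                  # vocabulary = every distinct character in the corpus
stoi = {ch: i for i, ch in enumerate(chars)}     # character -> integer id
itos = {i: ch for i, ch in enumerate(chars)}     # integer id -> character
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join(itos[i] for i in l)

ids = encode("To be, or not to be")              # a short prompt becomes a list of token ids
assert decode(ids) == "To be, or not to be"      # encode/decode round-trip is lossless

Note that encode raises a KeyError for any character that never appears in the training corpus, which is why the demo prompts are plain English words.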
input.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.py ADDED
@@ -0,0 +1,149 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ n_embd = 384
+ block_size = 256
+ dropout = 0.1
+ n_head = 8
+ n_layer = 6
+
+ # read text file
+ with open('input.txt', 'r', encoding='utf-8') as f:
+     text = f.read()
+
+ # collect all the unique characters that occur in this text
+ chars = sorted(list(set(text)))
+ vocab_size = len(chars)
+
+ cuda = torch.cuda.is_available()
+ device = 'cuda' if cuda else 'cpu'
+
+ class Head(nn.Module):
+     """ one head of self-attention """
+
+     def __init__(self, head_size):
+         super().__init__()
+         self.key = nn.Linear(n_embd, head_size, bias=False)
+         self.query = nn.Linear(n_embd, head_size, bias=False)
+         self.value = nn.Linear(n_embd, head_size, bias=False)
+         self.register_buffer('trill', torch.tril(torch.ones(block_size, block_size)))
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x):
+         B,T,C = x.shape
+         k = self.key(x)
+         q = self.query(x)
+         # compute attention scores ("affinities")
+         wei = q @ k.transpose(-2, -1) * C**-0.5
+         wei = wei.masked_fill(self.trill[:T, :T] == 0, float('-inf'))
+         wei = F.softmax(wei, dim=-1)
+         wei = self.dropout(wei)
+         # perform the weighted aggregation of the values
+         v = self.value(x)
+         out = wei @ v
+         return out
+
+ class MultiHeadAttention(nn.Module):
+     """ multiple heads of self-attention in parallel """
+
+     def __init__(self, num_heads, head_size):
+         super().__init__()
+         self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
+         self.proj = nn.Linear(n_embd, n_embd)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x):
+         out = torch.cat([h(x) for h in self.heads], dim=-1)
+         out = self.dropout(self.proj(out))
+         return out
+
+ class FeedForward(nn.Module):
+     """ a simple linear layer followed by a non-linearity """
+
+     def __init__(self, n_embd):
+         super().__init__()
+         self.net = nn.Sequential(
+             nn.Linear(n_embd, 4 * n_embd),
+             nn.ReLU(),
+             nn.Linear(4 * n_embd, n_embd),
+             nn.Dropout(dropout)
+         )
+
+     def forward(self, x):
+         return self.net(x)
+
+ class Block(nn.Module):
+     """ Transformer block: communication followed by computation """
+
+     def __init__(self, n_embd, n_head) -> None:
+         super().__init__()
+         head_size = n_embd // n_head
+         self.sa = MultiHeadAttention(n_head, head_size)
+         self.ffwd = FeedForward(n_embd)
+         self.ln1 = nn.LayerNorm(n_embd)
+         self.ln2 = nn.LayerNorm(n_embd)
+
+     def forward(self, x):
+         x = x + self.sa(self.ln1(x))
+         x = x + self.ffwd(self.ln2(x))
+         return x
+
+
+ class BigramLanguageModel(nn.Module):
+
+     def __init__(self):
+         super().__init__()
+         # each token directly reads off the logits for the next token from a lookup table
+         self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
+         self.position_embedding_table = nn.Embedding(block_size, n_embd)
+         # self.sa_head = MultiHeadAttention(4, n_embd//4)
+         # self.ffwd = FeedForward(n_embd)
+         # self.blocks = nn.Sequential(
+         #     Block(n_embd, n_head=4),
+         #     Block(n_embd, n_head=4),
+         #     Block(n_embd, n_head=4),
+         #     nn.LayerNorm(n_embd)
+         # )
+         self.blocks = nn.Sequential(*[Block(n_embd, n_head) for _ in range(n_layer)])
+         self.ln_f = nn.LayerNorm(n_embd)
+         self.lm_head = nn.Linear(n_embd, vocab_size)
+
+     def forward(self, idx, targets=None):
+         B, T = idx.shape
+         # idx and targets are both (B,T) tensor of integers
+         tok_emb = self.token_embedding_table(idx) # (B,T,C)
+         pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
+         x = tok_emb + pos_emb # (B,T,C)
+         # x = self.sa_head(x)
+         # x = self.ffwd(x)
+         x = self.blocks(x)
+         x = self.ln_f(x)
+         logits = self.lm_head(x) # (B,T,vocab_size)
+
+         if targets is None:
+             loss = None
+         else:
+             B, T, C = logits.shape
+             logits = logits.view(B*T, C)
+             targets = targets.view(B*T)
+             loss = F.cross_entropy(logits, targets)
+
+         return logits, loss
+
+     def generate(self, idx, max_new_tokens):
+         # idx is (B, T) array of indices in the current context
+         for _ in range(max_new_tokens):
+             # crop the context to the last block_size tokens
+             idx_cond = idx[:, -block_size:]
+             # get the predictions
+             logits, loss = self(idx_cond)
+             # focus only on the last time step
+             logits = logits[:, -1, :] # becomes (B, C)
+             # apply softmax to get probabilities
+             probs = F.softmax(logits, dim=-1) # (B, C)
+             # sample from the distribution
+             idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
+             # append sampled index to the running sequence
+             idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
+         return idx
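Despite the BigramLanguageModel name (kept from the nanoGPT tutorial code it follows), this is a 6-layer, 8-head decoder-only transformer with 384-dim embeddings and a 256-token context. A minimal shape-check sketch, assuming input.txt is present since model.py builds its vocabulary at import time (weights here are randomly initialized, not the trained checkpoint):

# Minimal shape-check sketch for the model defined in model.py.
import torch
from model import BigramLanguageModel, vocab_size, device

m = BigramLanguageModel().to(device)
idx = torch.zeros((1, 8), dtype=torch.long, device=device)  # dummy context: batch of 1, 8 token ids
logits, loss = m(idx)                                        # loss is None when no targets are passed
assert logits.shape == (1, 8, vocab_size)                    # per-position logits over the vocabulary
out = m.generate(idx, max_new_tokens=4)                      # autoregressive sampling, one token at a time
assert out.shape == (1, 12)                                  # the original 8 tokens plus 4 new ones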
nanogpt.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cf87f20f3e71ff9fbd3781c4f4969cbeebe0f73ec85838b7edac1c2e7f86dd29
+ size 55835502
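nanogpt.pth is stored as a Git LFS pointer; the roughly 56 MB payload is the trained model's state_dict, which app.py restores with load_state_dict. A hedged sketch of how such a checkpoint would typically be written (the training loop itself is not part of this commit):

# Sketch only: saving trained weights as the state_dict checkpoint that app.py expects.
import torch
from model import BigramLanguageModel

model = BigramLanguageModel()
# ... training loop would go here ...
torch.save(model.state_dict(), "nanogpt.pth")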
requirements.txt ADDED
@@ -0,0 +1,81 @@
+ aiofiles==23.2.1
+ altair==5.3.0
+ annotated-types==0.7.0
+ anyio==4.4.0
+ attrs==23.2.0
+ certifi==2024.6.2
+ charset-normalizer==3.3.2
+ click==8.1.7
+ colorama==0.4.6
+ contourpy==1.2.1
+ cycler==0.12.1
+ dnspython==2.6.1
+ email_validator==2.1.1
+ fastapi==0.111.0
+ fastapi-cli==0.0.4
+ ffmpy==0.3.2
+ filelock==3.13.1
+ fonttools==4.53.0
+ fsspec==2024.2.0
+ gradio==4.36.1
+ gradio_client==1.0.1
+ h11==0.14.0
+ httpcore==1.0.5
+ httptools==0.6.1
+ httpx==0.27.0
+ huggingface-hub==0.23.4
+ idna==3.7
+ importlib_resources==6.4.0
+ intel-openmp==2021.4.0
+ Jinja2==3.1.3
+ jsonschema==4.22.0
+ jsonschema-specifications==2023.12.1
+ kiwisolver==1.4.5
+ markdown-it-py==3.0.0
+ MarkupSafe==2.1.5
+ matplotlib==3.9.0
+ mdurl==0.1.2
+ mkl==2021.4.0
+ mpmath==1.3.0
+ networkx==3.2.1
+ numpy==1.26.3
+ orjson==3.10.5
+ packaging==24.1
+ pandas==2.2.2
+ pillow==10.2.0
+ pydantic==2.7.4
+ pydantic_core==2.18.4
+ pydub==0.25.1
+ Pygments==2.18.0
+ pyparsing==3.1.2
+ python-dateutil==2.9.0.post0
+ python-dotenv==1.0.1
+ python-multipart==0.0.9
+ pytz==2024.1
+ PyYAML==6.0.1
+ referencing==0.35.1
+ requests==2.32.3
+ rich==13.7.1
+ rpds-py==0.18.1
+ ruff==0.4.9
+ semantic-version==2.10.0
+ shellingham==1.5.4
+ six==1.16.0
+ sniffio==1.3.1
+ starlette==0.37.2
+ sympy==1.12
+ tbb==2021.11.0
+ tomlkit==0.12.0
+ toolz==0.12.1
+ torch==2.3.1
+ torchaudio==2.3.1
+ torchvision==0.18.1
+ tqdm==4.66.4
+ typer==0.12.3
+ typing_extensions==4.9.0
+ tzdata==2024.1
+ ujson==5.10.0
+ urllib3==2.2.1
+ uvicorn==0.30.1
+ watchfiles==0.22.0
+ websockets==11.0.3