KittyCat00 committed
Commit a0b0bb4 · verified · 1 Parent(s): ad443fa

Create app.py

Files changed (1)
  1. app.py +520 -0
app.py ADDED
@@ -0,0 +1,520 @@
import math
import re
import time

import tiktoken
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import gradio as gr

class GPTModel(nn.Module):

    def __init__(self, cfg):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
        self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
        self.drop_emb = nn.Dropout(cfg["drop_rate"])

        self.trf_blocks = nn.Sequential(
            *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
        )

        self.final_norm = LayerNorm(cfg["emb_dim"])
        self.out_head = nn.Linear(
            cfg["emb_dim"], cfg["vocab_size"], bias=False
        )

    def forward(self, in_idx):
        batch_size, seq_len = in_idx.shape
        tok_embeds = self.tok_emb(in_idx)
        pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
        x = tok_embeds + pos_embeds  # Shape [batch_size, num_tokens, emb_size]
        x = self.drop_emb(x)
        x = self.trf_blocks(x)
        x = self.final_norm(x)
        logits = self.out_head(x)
        return logits

class TransformerBlock(nn.Module):

    def __init__(self, cfg):
        super().__init__()
        self.att = MultiHeadAttention(
            d_in=cfg["emb_dim"],
            d_out=cfg["emb_dim"],
            context_length=cfg["context_length"],
            num_heads=cfg["n_heads"],
            dropout=cfg["drop_rate"],
            qkv_bias=cfg["qkv_bias"]
        )
        self.ff = FeedForward(cfg)
        self.norm1 = LayerNorm(cfg["emb_dim"])
        self.norm2 = LayerNorm(cfg["emb_dim"])
        self.drop_shortcut = nn.Dropout(cfg["drop_rate"])

    def forward(self, x):
        # Shortcut connection for attention block
        shortcut = x
        x = self.norm1(x)
        x = self.att(x)  # Shape [batch_size, num_tokens, emb_size]
        x = self.drop_shortcut(x)
        x = x + shortcut  # Add the original input back

        # Shortcut connection for feed forward block
        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        x = self.drop_shortcut(x)
        x = x + shortcut  # Add the original input back

        return x

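# The attention layer below registers a causal (upper-triangular) mask as a buffer
# so each position can only attend to itself and earlier tokens; attention scores
# are scaled by sqrt(head_dim) before the softmax.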
class MultiHeadAttention(nn.Module):

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # Reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
                       diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape

        keys = self.W_key(x)  # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Use the mask to fill attention scores
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)  # optional projection

        return context_vec

class FeedForward(nn.Module):

    def __init__(self, cfg):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
            GELU(),
            nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"])
        )

    def forward(self, x):
        return self.layers(x)

class GELU(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(
            torch.sqrt(torch.tensor(2.0 / torch.pi)) *
            (x + 0.044715 * torch.pow(x, 3))
        ))

class LayerNorm(nn.Module):

    def __init__(self, emb_dim):
        super().__init__()
        self.eps = 1e-5
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x - mean) / torch.sqrt(var + self.eps)
        return self.scale * norm_x + self.shift

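# Model configuration matching the GPT-2 "small" (124M-parameter) architecture,
# with the context length shortened from 1024 to 256 tokens for training.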
GPT_CONFIG_124M = {
    "vocab_size": 50257,     # Vocabulary size
    "context_length": 256,   # Shortened context length (orig: 1024)
    "emb_dim": 768,          # Embedding dimension
    "n_heads": 12,           # Number of attention heads
    "n_layers": 12,          # Number of layers
    "drop_rate": 0.1,        # Dropout rate
    "qkv_bias": False        # Query-key-value bias
}

# Note: the Gradio app below builds its own model in main(); this instance is unused there.
model = GPTModel(GPT_CONFIG_124M)

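# Decoding helper: optionally keeps only the top_k logits, samples with temperature
# scaling when temperature > 0 (greedy argmax otherwise), and appends "Meow." after
# every period in the decoded output.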
def generate(model, idx, max_new_tokens, context_size, tokenizer, text_to_token_ids, temperature=0.0, top_k=None, eos_id=None):

    # For-loop is the same as before: get logits, and only focus on the last time step
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]
        with torch.no_grad():
            logits = model(idx_cond)
        logits = logits[:, -1, :]

        # New: Filter logits with top_k sampling
        if top_k is not None:
            # Keep only top_k values
            top_logits, _ = torch.topk(logits, top_k)
            min_val = top_logits[:, -1]
            logits = torch.where(logits < min_val, torch.tensor(float("-inf")).to(logits.device), logits)

        # New: Apply temperature scaling
        if temperature > 0.0:
            logits = logits / temperature

            # Apply softmax to get probabilities
            probs = torch.softmax(logits, dim=-1)  # (batch_size, vocab_size)

            # Sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (batch_size, 1)

        # Otherwise, same as before: get the idx of the vocab entry with the highest logits value
        else:
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)

        # Stop generating early if the end-of-sequence token is encountered and eos_id is specified
        if eos_id is not None and idx_next.item() == eos_id:
            break

        # GPT-2 token ids for sentence-ending punctuation (assumes batch_size == 1):
        # "." -> 13, "?" -> 30, "!" -> 0
        if idx_next.item() == 13:
            print("\nperiod\n")

        if idx_next.item() == 30:
            print("\nquestion mark\n")

        if idx_next.item() == 0:
            print("\nexclamation mark\n")

        # Same as before: append sampled index to the running sequence
        idx = torch.cat((idx, idx_next), dim=1)  # (batch_size, num_tokens+1)

    # Decode the generated ids and append "Meow." after every period
    text = token_ids_to_text(idx, tokenizer)
    new_text = re.sub(r"\.", ". Meow.", text)

    return new_text

def text_to_token_ids(text, tokenizer):
    encoded = tokenizer.encode(text, allowed_special={'<|endoftext|>'})
    encoded_tensor = torch.tensor(encoded).unsqueeze(0)  # add batch dimension
    return encoded_tensor

def token_ids_to_text(token_ids, tokenizer):
    flat = token_ids.squeeze(0)  # remove batch dimension
    return tokenizer.decode(flat.tolist())

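# Training loop with a linear learning-rate warmup for the first `warmup_steps`
# steps, cosine annealing from the optimizer's peak LR down to `min_lr` afterwards,
# and gradient clipping once the warmup phase is over.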
def train_model(model, train_loader, val_loader, optimizer, device,
                n_epochs, eval_freq, eval_iter, start_context, tokenizer,
                warmup_steps, initial_lr=3e-05, min_lr=1e-6):

    train_losses, val_losses, track_tokens_seen, track_lrs = [], [], [], []
    tokens_seen, global_step = 0, -1

    # Retrieve the maximum learning rate from the optimizer
    peak_lr = optimizer.param_groups[0]["lr"]

    # Calculate the total number of iterations in the training process
    total_training_steps = len(train_loader) * n_epochs

    # Calculate the learning rate increment during the warmup phase
    lr_increment = (peak_lr - initial_lr) / warmup_steps

    for epoch in range(n_epochs):
        model.train()
        for input_batch, target_batch in train_loader:
            optimizer.zero_grad()
            global_step += 1

            # Adjust the learning rate based on the current phase (warmup or cosine annealing)
            if global_step < warmup_steps:
                # Linear warmup
                lr = initial_lr + global_step * lr_increment
            else:
                # Cosine annealing after warmup
                progress = ((global_step - warmup_steps) /
                            (total_training_steps - warmup_steps))
                lr = min_lr + (peak_lr - min_lr) * 0.5 * (1 + math.cos(math.pi * progress))

            # Apply the calculated learning rate to the optimizer
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr
            track_lrs.append(lr)  # Store the current learning rate

            # Calculate and backpropagate the loss
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            loss.backward()

            # Apply gradient clipping after the warmup phase to avoid exploding gradients
            if global_step > warmup_steps:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()
            tokens_seen += input_batch.numel()

            # Periodically evaluate the model on the training and validation sets
            if global_step % eval_freq == 0:
                train_loss, val_loss = evaluate_model(
                    model, train_loader, val_loader,
                    device, eval_iter
                )
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                track_tokens_seen.append(tokens_seen)
                # Print the current losses
                print(f"Ep {epoch+1} (Iter {global_step:06d}): "
                      f"Train loss {train_loss:.3f}, "
                      f"Val loss {val_loss:.3f}")

        # Generate and print a sample from the model to monitor progress
        generate_and_print_sample(
            model, tokenizer, device, start_context
        )

    return train_losses, val_losses, track_tokens_seen, track_lrs

def create_dataloader_v1(txt, batch_size=4, max_length=256, stride=128, shuffle=True, drop_last=True, num_workers=0):
    tokenizer = tiktoken.get_encoding("gpt2")  # A - Initialize the tokenizer
    dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)  # B - Create dataset
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,  # C - drop_last=True drops the last batch if it is shorter than the specified batch_size to prevent loss spikes during training
        num_workers=num_workers  # D - The number of CPU processes to use for preprocessing
    )

    return dataloader

class GPTDatasetV1(Dataset):
    def __init__(self, txt, tokenizer, max_length, stride):
        self.tokenizer = tokenizer
        self.input_ids = []
        self.target_ids = []

        token_ids = tokenizer.encode(txt)  # A

        for i in range(0, len(token_ids) - max_length, stride):  # B
            input_chunk = token_ids[i:i + max_length]
            target_chunk = token_ids[i + 1: i + max_length + 1]
            self.input_ids.append(torch.tensor(input_chunk))
            self.target_ids.append(torch.tensor(target_chunk))

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.target_ids[idx]

def evaluate_model(model, train_loader, val_loader, device, eval_iter):
    model.eval()
    with torch.no_grad():
        train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)
        val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)
    model.train()
    return train_loss, val_loss

def generate_and_print_sample(model, tokenizer, device, start_context):
    model.eval()
    context_size = model.pos_emb.weight.shape[0]
    encoded = text_to_token_ids(start_context, tokenizer).to(device)
    with torch.no_grad():
        token_ids = generate_text_simple(
            model=model, idx=encoded,
            max_new_tokens=50, context_size=context_size
        )
    decoded_text = token_ids_to_text(token_ids, tokenizer)
    print(decoded_text.replace("\n", " "))  # Compact print format
    model.train()

def calc_loss_batch(input_batch, target_batch, model, device):
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)
    logits = model(input_batch)
    loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), target_batch.flatten())
    return loss

def calc_loss_loader(data_loader, model, device, num_batches=None):
    total_loss = 0.
    if len(data_loader) == 0:
        return float("nan")
    elif num_batches is None:
        num_batches = len(data_loader)
    else:
        # Reduce the number of batches to match the total number of batches in the data loader
        # if num_batches exceeds the number of batches in the data loader
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            total_loss += loss.item()
        else:
            break
    return total_loss / num_batches

def generate_text_simple(model, idx, max_new_tokens, context_size):
    # idx is (batch, n_tokens) array of indices in the current context
    for _ in range(max_new_tokens):

        # Crop current context if it exceeds the supported context size
        idx_cond = idx[:, -context_size:]

        # Get the predictions
        with torch.no_grad():
            logits = model(idx_cond)

        # Focus only on the last time step
        # (batch, n_tokens, vocab_size) becomes (batch, vocab_size)
        logits = logits[:, -1, :]

        # Apply softmax to get the probabilities
        probas = torch.softmax(logits, dim=-1)  # (batch, vocab_size)

        # Get the idx of the vocab entry with the highest probability value
        idx_next = torch.argmax(probas, dim=-1, keepdim=True)  # (batch, 1)

        # Append sampled index to the running sequence
        idx = torch.cat((idx, idx_next), dim=1)  # (batch, n_tokens+1)

    return idx

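# Gradio callback: loads the trained checkpoint, rebuilds the model on the best
# available device (CUDA, MPS, or CPU), and generates a continuation of the
# user-provided starting context.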
def main(input_text, max_new_tokens):

    tokenizer = tiktoken.get_encoding("gpt2")

    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    weights = torch.load("model_and_optimizer.pth", map_location=device)

    model = GPTModel({
        "vocab_size": 50257,     # Vocabulary size
        "context_length": 512,   # Shortened context length (orig: 1024)
        "emb_dim": 768,          # Embedding dimension
        "n_heads": 12,           # Number of attention heads
        "n_layers": 12,          # Number of layers
        "drop_rate": 0.3,        # Dropout rate
        "qkv_bias": False        # Query-key-value bias
    }).to(device)
    model.load_state_dict(weights)
    model.eval()

    context_size = model.pos_emb.weight.shape[0]
    encoded = torch.tensor(tokenizer.encode(input_text.strip())).unsqueeze(0).to(device)

    with torch.no_grad():
        # Use generate(), which accepts the top_k/temperature sampling arguments
        # (generate_text_simple() does not) and returns the decoded text directly.
        output_text = generate(
            model=model, idx=encoded,
            max_new_tokens=int(max_new_tokens),  # gr.Number passes a float
            context_size=context_size,
            tokenizer=tokenizer, text_to_token_ids=text_to_token_ids,
            top_k=25, temperature=1.4
        )
    return output_text

demo = gr.Interface(
    fn=main,
    theme=gr.themes.Soft(primary_hue="pink", secondary_hue="stone"),
    inputs=[gr.Textbox(label="Starting context"), gr.Number(label="Maximum output tokens")],
    outputs=[gr.Textbox(label="Response:")],
    title="CatGPT",
    article="Meow"
)

if __name__ == "__main__":
    demo.launch()