mkthoma commited on
Commit
65d886a
·
1 Parent(s): b3a0bd3
Files changed (1) hide show
  1. app.py +157 -0
app.py CHANGED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import torch
3
+ import lightning as L
4
+ from torch.utils.data import DataLoader
5
+ from lightning.fabric.loggers import CSVLogger
6
+ from lightning.fabric.strategies import FSDPStrategy
7
+ from tsai_gpt.model import GPT, Block, Config
8
+ from tsai_gpt.tokenizer import Tokenizer
9
+ from tsai_gpt.packed_dataset import CombinedDataset, PackedDataset
10
+ from tsai_gpt.speed_monitor import SpeedMonitorBase, estimate_flops, measure_flops
11
+ from tsai_gpt.speed_monitor import SpeedMonitorFabric as SpeedMonitor
12
+ from tsai_gpt.utils import chunked_cross_entropy, get_default_supported_precision, num_parameters, load_checkpoint, gptq_quantization
13
+ import torch.nn as nn
14
+ from pathlib import Path
15
+ import sys
16
+ import random
17
+ from torch import nn
18
+ import lightning.pytorch as pl
19
+ from torch.nn import functional as F
20
+
21
+
22
+
23
+ model_name = "pythia-160m"
24
+ name = "redpajama"
25
+
26
+ def _init_weights(module: nn.Module) -> None:
27
+ """Meant to be used with `gpt.apply(gpt._init_weights)`."""
28
+ if isinstance(module, nn.Linear):
29
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
30
+ if module.bias is not None:
31
+ torch.nn.init.zeros_(module.bias)
32
+ elif isinstance(module, nn.Embedding):
33
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
34
+
35
+ config = Config.from_name(model_name)
36
+ model = GPT(config)
37
+
38
+ next(model.parameters()).sum() #-25 -2 -860
39
+
40
+ model.apply(_init_weights)
41
+ model.load_state_dict
42
+
43
+
44
+ checkpoint_dir = Path("out/redpajama/intermediate-ckpt-3_9.pth")
45
+ strategy = "auto"
46
+ devices = 1
47
+ precision = None
48
+
49
+
50
+ precision = get_default_supported_precision(training=False)
51
+ plugins = None
52
+ fabric = L.Fabric(devices=devices, precision=precision, strategy=strategy, plugins=plugins)
53
+ fabric.launch()
54
+ fabric.print(f"Loading model {str(checkpoint_dir)!r} with {config.__dict__}", file=sys.stderr)
55
+
56
+ with fabric.init_module(empty_init=True), gptq_quantization(quantize=="gptq.int4"):
57
+ model = GPT(config)
58
+
59
+ model.eval()
60
+ model = fabric.setup_module(model)
61
+ load_checkpoint(fabric, model, checkpoint_dir)
62
+
63
+ tokenizer = Tokenizer(Path('tokenizer_config'))
64
+ encoded = tokenizer.encode(prompt, device=fabric.device)
65
+ prompt_length = encoded.size(0)
66
+ max_returned_tokens = prompt_length + max_new_tokens
67
+
68
+ with fabric.init_tensor():
69
+ # set the max_seq_length to limit the memory usage to what we need
70
+ model.max_seq_length = max_returned_tokens
71
+
72
+ @torch.inference_mode()
73
+ def generate(
74
+ model: GPT,
75
+ idx: torch.Tensor,
76
+ max_returned_tokens: int,
77
+ *,
78
+ temperature: float = 1.0,
79
+ top_k:int = None,
80
+ eos_id:int = None,
81
+ ) -> torch.Tensor:
82
+ """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
83
+
84
+ The implementation of this function is modified from A. Karpathy's nanoGPT.
85
+
86
+ Args:
87
+ model: The model to use.
88
+ idx: Tensor of shape (T) with indices of the prompt sequence.
89
+ max_returned_tokens: The maximum number of tokens to return (given plus generated).
90
+ temperature: Scales the predicted logits by 1 / temperature.
91
+ top_k: If specified, only sample among the tokens with the k highest probabilities.
92
+ eos_id: If specified, stop generating any more token once the <eos> token is triggered.
93
+ """
94
+ T = idx.size(0)
95
+ assert max_returned_tokens > T
96
+ if model.max_seq_length < max_returned_tokens - 1:
97
+ # rolling the kv cache based on the `input_pos` value would be necessary. However, doing so would introduce a
98
+ # data dependency on the `input_pos` tensor and impact model compilation. Since this setting is uncommon, we do
99
+ # not support it to avoid negatively impacting the overall speed
100
+ raise NotImplementedError(f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}")
101
+
102
+ device, dtype = idx.device, idx.dtype
103
+ # create an empty tensor of the expected final shape and fill in the current tokens
104
+ empty = torch.empty(max_returned_tokens, dtype=dtype, device=device)
105
+ empty[:T] = idx
106
+ idx = empty
107
+ input_pos = torch.arange(0, T, device=device)
108
+
109
+ # generate up to a fixed number of tokens
110
+ for _ in range(max_returned_tokens - T):
111
+ x = idx.index_select(0, input_pos).view(1, -1)
112
+
113
+ # forward
114
+ logits = model(x, input_pos)
115
+ logits = logits[0, -1] / temperature
116
+
117
+ # optionally crop the logits to only the top k options
118
+ if top_k is not None:
119
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
120
+ logits = torch.where(logits < v[[-1]], -float("Inf"), logits)
121
+
122
+ probs = torch.nn.functional.softmax(logits, dim=-1)
123
+ idx_next = torch.multinomial(probs, num_samples=1).to(dtype=dtype)
124
+
125
+ # advance
126
+ input_pos = input_pos[-1:] + 1
127
+
128
+ # concatenate the new generation
129
+ idx = idx.index_copy(0, input_pos, idx_next)
130
+
131
+ # if <eos> token is triggered, return the output (stop generation)
132
+ if idx_next == eos_id:
133
+ return idx[:input_pos] # include the EOS token
134
+
135
+ return idx
136
+
137
+
138
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
139
+
140
+ def generate_dialogue(input_text, temperature=0.8, max_tokens=200, top_k=1):
141
+ encoded = tokenizer.encode(input_text, device=fabric.device)
142
+ max_returned_tokens = encoded.size(0) + max_tokens
143
+
144
+
145
+ with fabric.init_tensor():
146
+ # set the max_seq_length to limit the memory usage to what we need
147
+ model.max_seq_length = max_returned_tokens
148
+
149
+
150
+ with fabric.init_tensor():
151
+ model.set_kv_cache(batch_size=1)
152
+
153
+ y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k)
154
+
155
+ return(tokenizer.decode(y))
156
+
157
+