Spaces:
Runtime error
Runtime error
File size: 6,376 Bytes
65d886a 6546822 65d886a b8c1cf7 65d886a 46a607d 65d886a 46a607d cc30221 46a607d 65d886a 46a607d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
import time
import torch
import lightning as L
from torch.utils.data import DataLoader
from lightning.fabric.loggers import CSVLogger
from lightning.fabric.strategies import FSDPStrategy
from tsai_gpt.model import GPT, Block, Config
from tsai_gpt.tokenizer import Tokenizer
from tsai_gpt.packed_dataset import CombinedDataset, PackedDataset
from tsai_gpt.speed_monitor import SpeedMonitorBase, estimate_flops, measure_flops
from tsai_gpt.speed_monitor import SpeedMonitorFabric as SpeedMonitor
from tsai_gpt.utils import chunked_cross_entropy, get_default_supported_precision, num_parameters, load_checkpoint, gptq_quantization
import torch.nn as nn
from pathlib import Path
import sys
import random
from torch import nn
import lightning.pytorch as pl
from torch.nn import functional as F
# Name of the model configuration; passed to `Config.from_name` when the model is built below.
model_name = "pythia-160m"
# NOTE(review): not referenced anywhere else in the visible source; presumably a
# label for the training dataset — confirm before relying on it.
name = "redpajama"
def _init_weights(module: nn.Module) -> None:
"""Meant to be used with `gpt.apply(gpt._init_weights)`."""
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
# ---------------------------------------------------------------------------
# Model / runtime setup. Runs once at import time.
#
# An earlier revision first built a second, throw-away `GPT(config)` on the
# default device, applied `_init_weights` to it and evaluated the bare
# attribute `model.load_state_dict` (never called). That model was discarded
# as soon as `model` was rebound below, so the dead work has been removed.
# ---------------------------------------------------------------------------
config = Config.from_name(model_name)

# NOTE(review): despite the variable name this path ends in `.pth`, so it looks
# like a single checkpoint file rather than a directory — confirm.
checkpoint_dir = Path("final-gpt-model-ckpt.pth")
strategy = "auto"
quantize = None  # compared against "gptq.int4" below; None leaves that flag False
devices = 1
precision = get_default_supported_precision(training=False)
plugins = None

fabric = L.Fabric(devices=devices, precision=precision, strategy=strategy, plugins=plugins)
fabric.launch()
fabric.print(f"Loading model {str(checkpoint_dir)!r} with {config.__dict__}", file=sys.stderr)

# `empty_init=True` asks Fabric to skip parameter initialisation; the weights
# are expected to come from the checkpoint loaded right after.
with fabric.init_module(empty_init=True), gptq_quantization(quantize == "gptq.int4"):
    model = GPT(config)
model.eval()
model = fabric.setup_module(model)
load_checkpoint(fabric, model, checkpoint_dir)

# Tokenizer files are read from the local `tokenizer_config` path.
tokenizer = Tokenizer(Path('tokenizer_config'))
@torch.inference_mode()
def generate(
    model: GPT,
    idx: torch.Tensor,
    max_returned_tokens: int,
    *,
    temperature: float = 1.0,
    top_k: int = None,
    eos_id: int = None,
) -> torch.Tensor:
    """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.

    The implementation of this function is modified from A. Karpathy's nanoGPT.

    Args:
        model: The model to use. It is called as ``model(x, input_pos)`` and must expose ``max_seq_length``.
        idx: Tensor of shape (T) with indices of the prompt sequence.
        max_returned_tokens: The maximum number of tokens to return (given plus generated).
        temperature: Scales the predicted logits by 1 / temperature. A value of exactly ``0`` selects
            greedy decoding (argmax) instead of sampling; previously this divided by zero and made
            ``torch.multinomial`` fail on non-finite probabilities.
        top_k: If specified, only sample among the tokens with the k highest probabilities.
        eos_id: If specified, stop generating any more token once the <eos> token is triggered.

    Returns:
        A 1-D tensor holding the prompt followed by the generated tokens. When generation stops early
        on ``eos_id`` the returned slice ends just before the <eos> token.

    Raises:
        AssertionError: If ``max_returned_tokens`` is not larger than the prompt length.
        NotImplementedError: If ``model.max_seq_length`` is smaller than ``max_returned_tokens - 1``.
    """
    T = idx.size(0)
    assert max_returned_tokens > T
    if model.max_seq_length < max_returned_tokens - 1:
        # rolling the kv cache based on the `input_pos` value would be necessary. However, doing so would introduce a
        # data dependency on the `input_pos` tensor and impact model compilation. Since this setting is uncommon, we do
        # not support it to avoid negatively impacting the overall speed
        raise NotImplementedError(f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}")

    device, dtype = idx.device, idx.dtype
    # create an empty tensor of the expected final shape and fill in the current tokens
    empty = torch.empty(max_returned_tokens, dtype=dtype, device=device)
    empty[:T] = idx
    idx = empty
    # first step feeds the whole prompt; later steps feed only the newest position
    input_pos = torch.arange(0, T, device=device)

    # temperature == 0 cannot be used as a divisor; treat it as "always take the best token"
    greedy = temperature == 0

    # generate up to a fixed number of tokens
    for _ in range(max_returned_tokens - T):
        x = idx.index_select(0, input_pos).view(1, -1)

        # forward; keep only the logits of the last position
        logits = model(x, input_pos)
        logits = logits[0, -1]
        if not greedy:
            logits = logits / temperature

        # optionally crop the logits to only the top k options
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits = torch.where(logits < v[[-1]], -float("Inf"), logits)

        if greedy:
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)
        else:
            probs = torch.nn.functional.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
        idx_next = idx_next.to(dtype=dtype)

        # advance
        input_pos = input_pos[-1:] + 1

        # concatenate the new generation
        idx = idx.index_copy(0, input_pos, idx_next)

        # if <eos> token is triggered, return the output (stop generation)
        if idx_next == eos_id:
            # NOTE(review): this slice stops *before* `input_pos`, i.e. the <eos> token that was
            # just written is not part of the result (the old comment claimed the opposite).
            # Behaviour is left unchanged; decide whether `input_pos + 1` was intended.
            return idx[:input_pos]

    return idx
# NOTE(review): this variable is not referenced in the visible source —
# `generate_text` places its tensors via `fabric.device` instead.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def generate_text(input_text, temperature=0.8, max_tokens=200, top_k=None):
    """Encode a prompt, sample a continuation from the module-level model and decode it.

    Args:
        input_text: Prompt string to continue.
        temperature: Sampling temperature forwarded to `generate`.
        max_tokens: Number of new tokens to generate on top of the prompt.
        top_k: If given, restrict sampling to the k most likely tokens.

    Returns:
        The result of `tokenizer.decode` on the generated token tensor
        (prompt included).
    """
    # UI sliders may deliver their values as floats; tensor sizes and `k` must
    # be real integers or `torch.empty` / `torch.topk` reject them.
    max_tokens = int(max_tokens)
    if top_k is not None:
        top_k = int(top_k)

    encoded = tokenizer.encode(input_text, device=fabric.device)
    max_returned_tokens = encoded.size(0) + max_tokens

    with fabric.init_tensor():
        # set the max_seq_length to limit the memory usage to what we need
        model.max_seq_length = max_returned_tokens
        # (re)build the kv cache for a single sequence
        model.set_kv_cache(batch_size=1)

    y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k)
    return tokenizer.decode(y)
import gradio as gr
# ---------------------------------------------------------------------------
# Gradio front-end: one prompt box plus three sliders feeding `generate_text`.
# ---------------------------------------------------------------------------
title = "GPT from scratch"
description1 = "GPT implementation taken from <a href='https://github.com/Lightning-AI/lit-gpt'>Lit-GPT</a>. It is trained on a samples of the <a href='https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample'>RedPajama 1 trillion dataset</a> to understand how GPT's are trained and built. The github link can be found <a href='https://github.com/mkthoma/gpt_from_scratch'>here.</a>"

# Input widgets, listed in the positional order of `generate_text`'s parameters
# (input_text, temperature, max_tokens, top_k).
prompt_box = gr.Textbox(label="Enter any prompt ", type="text", value="Once upon a time,")
temperature_slider = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.8, label="Temperature")
max_tokens_slider = gr.Slider(minimum=200, maximum=1000, step=50, value=300, label="Max Tokens")
top_k_slider = gr.Slider(minimum=10, maximum=100, step=5, value=20, label="Top K")

# Single text area showing the decoded generation.
output_box = gr.Textbox(label="Text generated", type="text")

demo = gr.Interface(
    fn=generate_text,
    inputs=[prompt_box, temperature_slider, max_tokens_slider, top_k_slider],
    outputs=output_box,
    title=title,
    description=description1,
)
demo.launch()
|