# LexUz-GPT / gpt_load.py
import torch
import gradio as gr
import tiktoken # Import tiktoken for GPT-2 tokenization
from gpt_parts import GPTModel # Ensure gpt_parts.py contains your GPTModel definition

# Configuration for the GPT-2 model, the same as used during training
GPT_CONFIG_124M = {
    "vocab_size": 50257,     # Vocabulary size
    "context_length": 1024,  # Context length
    "emb_dim": 768,          # Embedding dimension
    "n_heads": 12,           # Number of attention heads
    "n_layers": 12,          # Number of layers
    "drop_rate": 0.1,        # Dropout rate
    "qkv_bias": False,       # Query-Key-Value bias
}
# Initialize the tokenizer using tiktoken's GPT-2 encoding
tokenizer = tiktoken.get_encoding("gpt2")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GPTModel(GPT_CONFIG_124M).to(device)
model.load_state_dict(torch.load("model.pth", map_location=device, weights_only=True))
model.eval() # Set model to evaluation mode
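
# Note (an assumption, not from the original file): torch.load's
# `weights_only=True` flag requires PyTorch >= 1.13; drop it when
# running on an older version.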

def text_to_token_ids(text, tokenizer):
    """Encode text to token IDs."""
    encoded = tokenizer.encode(text)
    return torch.tensor(encoded).unsqueeze(0)  # Add a batch dimension

def token_ids_to_text(token_ids, tokenizer):
    """Decode token IDs to text."""
    return tokenizer.decode(token_ids.squeeze(0).tolist())  # Drop the batch dimension
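
# Quick sanity check (not in the original app): tiktoken's GPT-2 BPE is
# lossless for plain text, so encoding followed by decoding should
# reproduce the input exactly.
assert token_ids_to_text(text_to_token_ids("hello world", tokenizer), tokenizer) == "hello world"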

def generate_text_simple(model, idx, max_new_tokens, context_size):
    """Autoregressively generate new tokens with greedy decoding."""
    for _ in range(max_new_tokens):
        # Crop the running sequence to the model's context window
        idx_cond = idx[:, -context_size:]
        with torch.no_grad():
            logits = model(idx_cond)
        # Keep only the logits for the last position
        logits = logits[:, -1, :]
        # Greedy decoding: always pick the highest-probability next token
        idx_next = torch.argmax(logits, dim=-1, keepdim=True)
        idx = torch.cat((idx, idx_next), dim=1)
    return idx
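
# An alternative to the greedy loop above: temperature scaling plus top-k
# filtering typically produces less repetitive output. This is a minimal
# sketch, not part of the original app; `temperature` and `top_k` are
# illustrative defaults.
def generate_text_sampled(model, idx, max_new_tokens, context_size,
                          temperature=1.0, top_k=50):
    """Autoregressively sample new tokens with temperature and top-k filtering."""
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]
        with torch.no_grad():
            logits = model(idx_cond)
        logits = logits[:, -1, :] / temperature
        if top_k is not None:
            # Mask everything below the k-th largest logit
            top_logits, _ = torch.topk(logits, top_k)
            logits[logits < top_logits[:, [-1]]] = float("-inf")
        probs = torch.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)
        idx = torch.cat((idx, idx_next), dim=1)
    return idx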

# Text generation function wired to the Gradio interface
def generate_text(start_context, max_new_tokens=50):
    # Gradio sliders may deliver floats, but range() needs an int
    max_new_tokens = int(max_new_tokens)
    # Encode the starting context
    encoded_input = text_to_token_ids(start_context, tokenizer).to(device)
    # Generate text
    generated_token_ids = generate_text_simple(
        model=model,
        idx=encoded_input,
        max_new_tokens=max_new_tokens,
        context_size=GPT_CONFIG_124M["context_length"],
    )
    # Decode the generated tokens to text
    generated_text = token_ids_to_text(generated_token_ids, tokenizer)
    return generated_text.replace("\n", " ")
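
# Example call from a Python shell (hypothetical prompt), once the model
# and tokenizer above are loaded:
#   generate_text("Some starting text", max_new_tokens=20)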

iface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter starting text here...", label="Start Context"),
        # value=50 matches the function's default so the UI starts in sync
        gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Max New Tokens"),
    ],
    outputs="text",
    title="GPT-2 Text Generation",
    description="Generate text using a fine-tuned GPT-2 model. Enter some starting text and choose the maximum number of tokens to generate.",
)
iface.launch(share=True)