import torch
import gradio as gr
import tiktoken  # tiktoken provides the GPT-2 BPE tokenizer
from gpt_parts import GPTModel  # gpt_parts.py must contain your GPTModel definition
# Configuration for the GPT-2 (124M) model, same as used during training
GPT_CONFIG_124M = {
    "vocab_size": 50257,     # Vocabulary size
    "context_length": 1024,  # Context length
    "emb_dim": 768,          # Embedding dimension
    "n_heads": 12,           # Number of attention heads
    "n_layers": 12,          # Number of layers
    "drop_rate": 0.1,        # Dropout rate
    "qkv_bias": False        # Query-Key-Value bias
}
# Initialize the tokenizer using tiktoken's GPT-2 encoding
tokenizer = tiktoken.get_encoding("gpt2")

# Load the trained weights and put the model in inference mode
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GPTModel(GPT_CONFIG_124M).to(device)
model.load_state_dict(torch.load("model.pth", map_location=device, weights_only=True))
model.eval()  # Set model to evaluation mode (disables dropout)
def text_to_token_ids(text, tokenizer):
    """Encode text to a batched tensor of token IDs."""
    encoded = tokenizer.encode(text)
    return torch.tensor(encoded).unsqueeze(0)  # Add a batch dimension: (1, n_tokens)

def token_ids_to_text(token_ids, tokenizer):
    """Decode a batched tensor of token IDs back to text."""
    return tokenizer.decode(token_ids.squeeze(0).tolist())  # Drop the batch dimension
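
# A quick round-trip check of the two helpers above (the prompt here is
# hypothetical, shown for illustration only):
#   >>> ids = text_to_token_ids("Hello, world", tokenizer)  # shape (1, n_tokens)
#   >>> token_ids_to_text(ids, tokenizer)
#   'Hello, world'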
def generate_text_simple(model, idx, max_new_tokens, context_size):
    """Autoregressively generate new tokens via greedy (argmax) decoding."""
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]  # Crop the context to the model's maximum length
        with torch.no_grad():
            logits = model(idx_cond)
        logits = logits[:, -1, :]  # Keep only the logits for the last position
        idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # Pick the most likely token
        idx = torch.cat((idx, idx_next), dim=1)  # Append it to the running sequence
    return idx
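
# An alternative decoding loop with temperature scaling and top-k sampling.
# This is a minimal sketch, not used by the app below: greedy argmax decoding
# tends to repeat itself, and swapping this in for generate_text_simple gives
# more varied output. The temperature and top_k defaults are illustrative.
def generate_text_sampled(model, idx, max_new_tokens, context_size,
                          temperature=1.0, top_k=25):
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]  # Crop the context as above
        with torch.no_grad():
            logits = model(idx_cond)
        logits = logits[:, -1, :] / temperature  # Higher temperature flattens the distribution
        if top_k is not None:
            # Mask out everything below the k-th largest logit
            top_logits, _ = torch.topk(logits, top_k)
            logits = torch.where(logits < top_logits[:, -1:],
                                 torch.full_like(logits, float("-inf")),
                                 logits)
        probs = torch.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)  # Sample instead of argmax
        idx = torch.cat((idx, idx_next), dim=1)
    return idx
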
# Define the text generation function for Gradio
def generate_text(start_context, max_new_tokens=50):
    # Gradio sliders may pass a float, so coerce the token count to an int
    max_new_tokens = int(max_new_tokens)
    # Encode the starting context
    encoded_input = text_to_token_ids(start_context, tokenizer).to(device)
    # Generate text
    generated_token_ids = generate_text_simple(
        model=model,
        idx=encoded_input,
        max_new_tokens=max_new_tokens,
        context_size=GPT_CONFIG_124M["context_length"]
    )
    # Decode the generated tokens to text
    generated_text = token_ids_to_text(generated_token_ids, tokenizer)
    return generated_text.replace("\n", " ")  # Flatten newlines for single-box display
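
# Sanity check before wiring up the UI (the continuation shown is
# illustrative, not an actual model output):
#   >>> generate_text("Every effort moves you", max_new_tokens=10)
#   'Every effort moves you ...'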
iface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter starting text here...", label="Start Context"),
        gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Max New Tokens")
    ],
    outputs="text",
    title="GPT-2 Text Generation",
    description="Generate text using a fine-tuned GPT-2 model. Enter some starting text, then choose the maximum number of tokens to generate."
)

iface.launch(share=True)