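"""Gradio demo app for ShakespeareGPT: loads a character-level GPT
checkpoint and serves interactive text generation."""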
import torch
import gradio as gr
from config.model_config import ModelConfig
from src.data.tokenizer import CharacterTokenizer
from src.utils.helpers import generate, setup_logging
# Setup logging
logger = setup_logging()
def load_model():
    config = ModelConfig()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Using device: {device}")

    # Build the character-level tokenizer from the training text
    with open(config.data_path) as f:
        text = f.read()
    tokenizer = CharacterTokenizer(text)

    # Load the checkpoint (assumed to be a fully pickled nn.Module rather
    # than a bare state_dict, since it is used directly after torch.load)
    try:
        model = torch.load(config.checkpoint_path, map_location=device)
        model.eval()
        return model, tokenizer, device
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        raise
def generate_text(prompt, max_tokens=200, temperature=0.8):
    # model, tokenizer, and device are module-level globals set at startup below
    try:
        # temperature was previously accepted but never used; it is forwarded
        # here on the assumption that the generate() helper in
        # src.utils.helpers accepts it as a keyword argument
        result = generate(model, tokenizer, prompt, max_tokens, device, temperature=temperature)
        return prompt + result
    except Exception as e:
        logger.error(f"Error during generation: {e}")
        return f"Error: {str(e)}"
# Load model globally
try:
    model, tokenizer, device = load_model()
    logger.info("Model loaded successfully")
except Exception as e:
    logger.error(f"Failed to load model: {e}")
    raise
# Create Gradio interface
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Enter your prompt", placeholder="Type your prompt here..."),
        gr.Slider(minimum=10, maximum=1000, value=200, step=10, label="Max Tokens"),
        # Temperature slider added so the three-value examples below match
        # the inputs and the generate_text signature
        gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature"),
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="Shakespeare GPT",
    description="Enter a prompt and generate text using a custom GPT model",
    examples=[
        ["Hello, my name is", 200, 0.8],
        ["Once upon a time", 500, 0.8],
        ["The meaning of life is", 300, 0.8],
    ],
)
if __name__ == "__main__":
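    # On Hugging Face Spaces, launch() needs no arguments; for local use,
    # demo.launch(server_name="0.0.0.0") exposes the app on the network.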
    demo.launch()