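"""Gradio demo for a character-level GPT.

Loads a trained checkpoint and rebuilds a CharacterTokenizer from the
training corpus, then serves text generation through a simple web UI.
"""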
import torch
import gradio as gr

from config.model_config import ModelConfig
from src.data.tokenizer import CharacterTokenizer
from src.utils.helpers import generate, setup_logging
# Set up logging
logger = setup_logging()
def load_model():
    """Load the model checkpoint and rebuild the tokenizer from the corpus."""
    config = ModelConfig()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Using device: {device}")

    # Rebuild the character-level tokenizer from the training text
    with open(config.data_path) as f:
        text = f.read()
    tokenizer = CharacterTokenizer(text)

    # Load the model checkpoint; it was saved as a full model object,
    # so weights_only=False is needed on PyTorch >= 2.6
    try:
        model = torch.load(config.checkpoint_path, map_location=device, weights_only=False)
        model.eval()
        return model, tokenizer, device
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        raise
def generate_text(prompt, max_tokens=200, temperature=0.8):
    """Generate a continuation of `prompt` and return prompt + completion."""
    try:
        # temperature is received from the UI but not passed on here;
        # forward it as well if the generate helper supports a temperature argument
        result = generate(model, tokenizer, prompt, max_tokens, device)
        return prompt + result
    except Exception as e:
        logger.error(f"Error during generation: {e}")
        return f"Error: {str(e)}"
# Load model globally
try:
    model, tokenizer, device = load_model()
    logger.info("Model loaded successfully")
except Exception as e:
    logger.error(f"Failed to load model: {e}")
    raise
# Create Gradio interface
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Enter your prompt", placeholder="Type your prompt here..."),
        gr.Slider(minimum=10, maximum=1000, value=200, step=10, label="Max Tokens"),
        # Temperature slider added so the three-value examples below match the
        # number of inputs (Gradio rejects the mismatch); range is a reasonable default
        gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature"),
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="Shakespeare GPT",
    description="Enter a prompt and generate text using a custom GPT model",
    examples=[
        ["Hello, my name is", 200, 0.8],
        ["Once upon a time", 500, 0.8],
        ["The meaning of life is", 300, 0.8],
    ],
)
if __name__ == "__main__":
    demo.launch()
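
# Run locally (assuming this file is saved as app.py at the repo root):
#   pip install torch gradio
#   python app.py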