import torch
import gradio as gr

from config.model_config import ModelConfig
from src.data.tokenizer import CharacterTokenizer
from src.utils.helpers import generate, setup_logging

# Set up logging
logger = setup_logging()


def load_model():
    """Load the tokenizer and model checkpoint, returning (model, tokenizer, device)."""
    config = ModelConfig()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Using device: {device}")

    # Rebuild the character-level tokenizer from the training corpus
    with open(config.data_path) as f:
        text = f.read()
    tokenizer = CharacterTokenizer(text)

    # Load the pickled model checkpoint onto the selected device.
    # weights_only=False is required on newer PyTorch (>= 2.6) to unpickle a full nn.Module.
    try:
        model = torch.load(config.checkpoint_path, map_location=device, weights_only=False)
        model.eval()
        return model, tokenizer, device
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        raise


def generate_text(prompt, max_tokens=200, temperature=0.8):
    """Generate a continuation of `prompt` and return prompt + continuation."""
    try:
        # NOTE: temperature is exposed in the UI; forward it here as well
        # if src.utils.helpers.generate accepts a temperature argument.
        result = generate(model, tokenizer, prompt, int(max_tokens), device)
        return prompt + result
    except Exception as e:
        logger.error(f"Error during generation: {e}")
        return f"Error: {e}"


# Load the model once at import time so every Gradio request reuses it
try:
    model, tokenizer, device = load_model()
    logger.info("Model loaded successfully")
except Exception as e:
    logger.error(f"Failed to load model: {e}")
    raise

# Create the Gradio interface. A Temperature slider is included so each
# three-value example row matches the declared inputs one-to-one.
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Enter your prompt", placeholder="Type your prompt here..."),
        gr.Slider(minimum=10, maximum=1000, value=200, step=10, label="Max Tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature"),
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="Shakespeare GPT",
    description="Enter a prompt and generate text using a custom GPT model",
    examples=[
        ["Hello, my name is", 200, 0.8],
        ["Once upon a time", 500, 0.8],
        ["The meaning of life is", 300, 0.8],
    ],
)

if __name__ == "__main__":
    demo.launch()