"""Interactive text generation with a trained character-level model.

Builds the tokenizer from the training text and loads the model checkpoint
named by ``ModelConfig``. It then repeatedly reads prompts from stdin and logs
each prompt followed by the generated continuation.
"""

import torch

from config.model_config import ModelConfig
from src.data.tokenizer import CharacterTokenizer
from src.utils.helpers import generate, setup_logging

# Number of new tokens passed to ``generate`` for every prompt.
MAX_NEW_TOKENS = 200


def _build_tokenizer(data_path):
    """Rebuild the character tokenizer from the text file at *data_path*."""
    # NOTE(review): the vocabulary is derived from this text, so presumably it
    # must be decoded the same way as when the model was trained. UTF-8 is
    # used explicitly instead of the platform-dependent default; confirm it
    # matches the training script.
    with open(data_path, encoding="utf-8") as f:
        text = f.read()
    return CharacterTokenizer(text)


def _load_model(checkpoint_path, device, logger):
    """Load the model stored at *checkpoint_path* onto *device* in eval mode.

    Returns the model, or ``None`` after logging the error and traceback if
    loading fails.
    """
    try:
        # NOTE(review): torch.load unpickles arbitrary objects, so only load
        # trusted checkpoints. The result is used as a whole module (.eval()
        # is called on it). PyTorch >= 2.6 defaults to weights_only=True,
        # which rejects such checkpoints. This call may need
        # weights_only=False, or the checkpoint could be saved as a state_dict.
        model = torch.load(checkpoint_path, map_location=device)
        model.eval()
    except Exception as e:
        logger.exception("Error loading model: %s", e)
        return None
    return model


def main():
    """Run the prompt/generate loop until the user quits or input ends."""
    logger = setup_logging()
    config = ModelConfig()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info("Using device: %s", device)

    tokenizer = _build_tokenizer(config.data_path)
    model = _load_model(config.checkpoint_path, device, logger)
    if model is None:
        return

    while True:
        try:
            prompt = input("\nEnter a prompt (or 'quit' to exit): ")
            if prompt.strip().lower() == "quit":
                break

            logger.info("\nGenerating...")
            # Inference only: skip autograd bookkeeping to save memory.
            with torch.no_grad():
                result = generate(
                    model, tokenizer, prompt, MAX_NEW_TOKENS, device
                )

            logger.info("\nGenerated text:")
            logger.info("=" * 50)
            logger.info(prompt + result)
            logger.info("=" * 50)
        except (KeyboardInterrupt, EOFError):
            # EOFError (Ctrl-D / closed stdin) must end the loop: every later
            # input() call would fail again immediately.
            logger.info("\nExiting...")
            break
        except Exception as e:
            # Recoverable per-prompt failure: report it and ask again.
            logger.error(f"Error during generation: {e}")


if __name__ == "__main__":
    main()