# Interactive generation script: loads a trained checkpoint and generates
# text from user prompts until the user quits.
import torch

from config.model_config import ModelConfig
from src.data.tokenizer import CharacterTokenizer
from src.utils.helpers import generate, setup_logging
def main():
    """Run an interactive text-generation loop against a trained checkpoint.

    Loads the model and character tokenizer described by ``ModelConfig``,
    then repeatedly reads a prompt from stdin, generates a continuation,
    and logs it. Exits when the user types ``quit``, sends EOF, or
    interrupts with Ctrl-C.
    """
    logger = setup_logging()
    config = ModelConfig()

    # Prefer GPU when available; generation is inference-only.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Using device: {device}")

    # Rebuild the character-level vocabulary from the same corpus used at
    # training time so token ids match the checkpoint. Explicit encoding
    # avoids platform-dependent text decoding.
    with open(config.data_path, encoding="utf-8") as f:
        text = f.read()
    tokenizer = CharacterTokenizer(text)

    # NOTE(review): torch.load unpickles an arbitrary model object — only
    # safe for checkpoints you produced yourself. torch >= 2.6 defaults to
    # weights_only=True, which would reject a full pickled model; confirm
    # the torch version this project targets.
    try:
        model = torch.load(config.checkpoint_path, map_location=device)
        model.eval()  # disable dropout/batch-norm updates for inference
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        return

    while True:
        try:
            prompt = input("\nEnter a prompt (or 'quit' to exit): ")
        except (EOFError, KeyboardInterrupt):
            # Bug fix: EOF on stdin (piped input / Ctrl-D) previously fell
            # into the generic handler below and looped forever; treat it
            # as a request to exit instead.
            logger.info("\nExiting...")
            break

        if prompt.strip().lower() == "quit":
            break

        try:
            max_tokens = 200  # fixed generation budget per prompt
            logger.info("\nGenerating...")
            result = generate(model, tokenizer, prompt, max_tokens, device)
            logger.info("\nGenerated text:")
            logger.info("=" * 50)
            logger.info(prompt + result)
            logger.info("=" * 50)
        except KeyboardInterrupt:
            logger.info("\nExiting...")
            break
        except Exception as e:
            # Keep the loop alive after a single failed generation, but
            # record the full traceback for debugging.
            logger.exception(f"Error during generation: {e}")
            continue
# Script entry point — keeps importing this module side-effect free.
if __name__ == "__main__":
    main()