File size: 1,479 Bytes
b57fe5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import torch
from config.model_config import ModelConfig
from src.data.tokenizer import CharacterTokenizer
from src.utils.helpers import generate, setup_logging


def main():
    """Run an interactive text-generation REPL for a trained model.

    Loads the tokenizer corpus and model checkpoint named in ``ModelConfig``,
    then loops reading prompts from stdin until the user types 'quit' or
    presses Ctrl-C. Each prompt produces up to 200 generated tokens.
    """
    # Setup logging
    logger = setup_logging()

    # Load config
    config = ModelConfig()

    # Setup device — prefer GPU when available; this is inference-only.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Using device: {device}")

    # Rebuild the tokenizer vocabulary from the training corpus. The file
    # must be the same one used at training time, or token ids won't match.
    # Explicit encoding avoids platform-dependent default-codec decoding.
    with open(config.data_path, encoding="utf-8") as f:
        text = f.read()
    tokenizer = CharacterTokenizer(text)

    # Load trained model.
    # NOTE(security): torch.load unpickles arbitrary objects — the checkpoint
    # stores the full model, so only load checkpoints from trusted sources.
    try:
        model = torch.load(config.checkpoint_path, map_location=device)
        model.eval()  # disable dropout / batch-norm updates for generation
    except Exception as e:
        # logger.exception also records the traceback, unlike logger.error.
        logger.exception("Error loading model: %s", e)
        return

    # Generate text from prompts until the user quits.
    while True:
        try:
            prompt = input("\nEnter a prompt (or 'quit' to exit): ")
            if prompt.lower() == "quit":
                break

            max_tokens = 200  # fixed per-prompt generation budget

            logger.info("\nGenerating...")
            result = generate(model, tokenizer, prompt, max_tokens, device)
            logger.info("\nGenerated text:")
            logger.info("=" * 50)
            logger.info(prompt + result)
            logger.info("=" * 50)
        except KeyboardInterrupt:
            logger.info("\nExiting...")
            break
        except Exception as e:
            # Keep the REPL alive on per-prompt failures; log with traceback.
            logger.exception("Error during generation: %s", e)
            continue


# Script entry point: only run the REPL when executed directly, not on import.
if __name__ == "__main__":
    main()