nikhiljais commited on
Commit
6e8ccfb
·
verified ·
1 Parent(s): 70b7e96

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -0
app.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from config.model_config import ModelConfig
4
+ from src.data.tokenizer import CharacterTokenizer
5
+ from src.utils.helpers import generate, setup_logging
6
+
7
+ # Setup logging
8
+ logger = setup_logging()
9
+
10
+
11
+ def load_model():
12
+ config = ModelConfig()
13
+ device = "cuda" if torch.cuda.is_available() else "cpu"
14
+ logger.info(f"Using device: {device}")
15
+
16
+ # Load tokenizer
17
+ with open(config.data_path) as f:
18
+ text = f.read()
19
+ tokenizer = CharacterTokenizer(text)
20
+
21
+ # Load model
22
+ try:
23
+ model = torch.load(config.checkpoint_path, map_location=device)
24
+ model.eval()
25
+ return model, tokenizer, device
26
+ except Exception as e:
27
+ logger.error(f"Error loading model: {e}")
28
+ raise
29
+
30
+
31
+ def generate_text(prompt, max_tokens=200, temperature=0.8):
32
+ try:
33
+ result = generate(model, tokenizer, prompt, max_tokens, device)
34
+ return prompt + result
35
+ except Exception as e:
36
+ logger.error(f"Error during generation: {e}")
37
+ return f"Error: {str(e)}"
38
+
39
+
40
+ # Load model globally
41
+ try:
42
+ model, tokenizer, device = load_model()
43
+ logger.info("Model loaded successfully")
44
+ except Exception as e:
45
+ logger.error(f"Failed to load model: {e}")
46
+ raise
47
+
48
+ # Create Gradio interface
49
+ demo = gr.Interface(
50
+ fn=generate_text,
51
+ inputs=[
52
+ gr.Textbox(label="Enter your prompt", placeholder="Type your prompt here..."),
53
+ gr.Slider(minimum=10, maximum=1000, value=200, step=10, label="Max Tokens"),
54
+ ],
55
+ outputs=gr.Textbox(label="Generated Text"),
56
+ title="Shakespeare GPT",
57
+ description="Enter a prompt and generate text using a custom GPT model",
58
+ examples=[
59
+ ["Hello, my name is", 200, 0.8],
60
+ ["Once upon a time", 500, 0.8],
61
+ ["The meaning of life is", 300, 0.8],
62
+ ],
63
+ )
64
+
65
+ if __name__ == "__main__":
66
+ demo.launch()