SeemG commited on
Commit
b9b54b8
·
verified ·
1 Parent(s): a9825df

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -30
app.py CHANGED
@@ -1,30 +1,30 @@
1
- import torch
2
- from model import BigramLanguageModel, decode
3
- import gradio as gr
4
-
5
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
6
-
7
- model = BigramLanguageModel()
8
- model.load_state_dict(torch.load("/content/drive/MyDrive/ERA V2/S19/neo_gpt.pth", map_location=device))
9
- def generate_text(max_new_tokens):
10
- context = torch.zeros((1, 1), dtype=torch.long)
11
- return decode(model.generate(context, max_new_tokens=max_new_tokens)[0].tolist())
12
-
13
-
14
- # Define the application components
15
- title = "Text Generation: Write Like Shakespeare"
16
- description = "This Gradio app uses a large language model (LLM) to generate text in the style of William Shakespeare."
17
-
18
-
19
- # Create a Gradio interface
20
- g_app = gr.Interface(
21
- fn = generate_text,
22
- inputs = [gr.Number(value = 10,label = "Number of Output Tokens",info = "Specify the desired length of the text to be generated.")],
23
- outputs = [gr.TextArea(lines = 5,label="Generated Text")],
24
- title = title,
25
- description = description
26
-
27
- )
28
-
29
- # Launch the Gradio app
30
- g_app.launch()
 
1
+ import torch
2
+ from model import BigramLanguageModel, decode
3
+ import gradio as gr
4
+
5
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
6
+
7
+ model = BigramLanguageModel()
8
+ model.load_state_dict(torch.load("./neo_gpt.pth", map_location=device))
9
+ def generate_text(max_new_tokens):
10
+ context = torch.zeros((1, 1), dtype=torch.long)
11
+ return decode(model.generate(context, max_new_tokens=max_new_tokens)[0].tolist())
12
+
13
+
14
+ # Define the application components
15
+ title = "Text Generation: Write Like Shakespeare"
16
+ description = "This Gradio app uses a large language model (LLM) to generate text in the style of William Shakespeare."
17
+
18
+
19
+ # Create a Gradio interface
20
+ g_app = gr.Interface(
21
+ fn = generate_text,
22
+ inputs = [gr.Number(value = 10,label = "Number of Output Tokens",info = "Specify the desired length of the text to be generated.")],
23
+ outputs = [gr.TextArea(lines = 5,label="Generated Text")],
24
+ title = title,
25
+ description = description
26
+
27
+ )
28
+
29
+ # Launch the Gradio app
30
+ g_app.launch()