Saif Rehman Nasir committed on
Commit
705eec3
·
1 Parent(s): 8ca4f8d

Add UI interface code

Browse files
Files changed (2) hide show
  1. app.py +32 -4
  2. model.py +1 -1
app.py CHANGED
@@ -1,7 +1,35 @@
1
  import gradio as gr
 
 
2
 
3
def greet(name=""):
    """Return a fixed greeting string.

    Args:
        name: Text passed in from the Gradio input box. The Interface in
            this file is declared with ``inputs="text"``, so Gradio calls
            the function with the textbox value as a positional argument;
            the original zero-parameter signature raised a TypeError at
            click time. A default keeps direct ``greet()`` calls working.

    Returns:
        The constant greeting ``"Hello world!"``.
    """
    return "Hello world!"
5
 
6
# Wire the greeting function into a minimal text-in / text-out Interface
# and start the local Gradio server.
demo = gr.Interface(
    fn=greet,
    inputs="text",
    outputs="text",
)
demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import gradio as gr
import torch
from model import BigramLM, encode, decode

# Run on the GPU when one is available; checkpoint tensors are remapped
# onto whichever device was selected.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# The checkpoint stores the whole pickled BigramLM object, not a
# state_dict, so weights_only must be False explicitly (the default
# flipped to True in torch 2.6 and would break this load).
# NOTE(review): unpickling executes arbitrary code -- only ship trusted
# checkpoints alongside this app.
model = torch.load('saved_model.pth',
                   map_location=torch.device(device),
                   weights_only=False)
8
+
9
+ def generate_text(context, num_of_tokens, temperature=1.0):
10
+ if context == None or context == '':
11
+ idx = torch.zeros((1,1), dtype=torch.long)
12
+ else:
13
+ idx = torch.tensor(encode(context), dtype=torch.long).unsqueeze(0)
14
+
15
+ return decode(model.generate(idx, max_new_tokens=num_of_tokens,temperature=temperature)[0].tolist())
16
+
17
+
18
+ with gr.Blocks as demo:
19
+ gr.HTML("<h1 align='center'> Shakespeare Text Generator</h1>")
20
+
21
+ context = gr.Textbox(label = "Enter context (optional)")
22
+
23
+ with gr.Row():
24
+ num_of_tokens = gr.Number( label = "Max tokens to generate", value = 100)
25
+ tmp = gr.Slider(label= "Temperature", minimum = 0.0, maximum = 1.0, value = 1.0 )
26
+
27
+ inputs = [
28
+ context,
29
+ num_of_tokens,tmp
30
+ ]
31
+ generate_btn = gr.Button(value="Generate")
32
+ outputs = [gr.Textbox(label= "Generated text")]
33
+ generate_btn.click(fn = generate_text, inputs= inputs, outputs= outputs)
34
+
35
+ demo.launch()
model.py CHANGED
@@ -206,7 +206,7 @@ class BigramLM(nn.Module):
206
  # sample from the distribution (pick the best)
207
  idx_next = torch.multinomial(probs, num_samples=1)
208
  # GPT like output
209
- print(decode(idx_next[0].tolist()), end='')
210
  # append sampled index to running sequence
211
  idx = torch.cat((idx, idx_next), dim=1)
212
 
 
206
  # sample from the distribution (pick the best)
207
  idx_next = torch.multinomial(probs, num_samples=1)
208
  # GPT like output
209
+ #print(decode(idx_next[0].tolist()), end='')
210
  # append sampled index to running sequence
211
  idx = torch.cat((idx, idx_next), dim=1)
212