Saif Rehman Nasir commited on
Commit
f8e3be7
·
1 Parent(s): 71c03ca

Revert streaming logic

Browse files
Files changed (2) hide show
  1. app.py +6 -10
  2. model.py +2 -2
app.py CHANGED
@@ -6,13 +6,13 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
6
 
7
  model = torch.load('saved_model.pth', map_location= torch.device(device), weights_only=False)
8
 
9
- # def generate_text(context, num_of_tokens, temperature=1.0):
10
- # if context == None or context == '':
11
- # idx = torch.zeros((1,1), dtype=torch.long)
12
- # else:
13
- # idx = torch.tensor(encode(context), dtype=torch.long).unsqueeze(0)
14
 
15
- # yield model.generate(idx, max_new_tokens=num_of_tokens,temperature=temperature)
16
 
17
 
18
 
@@ -29,10 +29,6 @@ with gr.Blocks() as demo:
29
  context,
30
  num_of_tokens,tmp
31
  ]
32
- if context == None or context == '':
33
- idx = torch.zeros((1,1), dtype=torch.long)
34
- else:
35
- idx = torch.tensor(encode(context), dtype=torch.long).unsqueeze(0)
36
  generate_btn = gr.Button(value="Generate")
37
  outputs = [gr.Textbox(label= "Generated text: ")]
38
  generate_btn.click(fn = model.generate, inputs= inputs, outputs= outputs)
 
6
 
7
  model = torch.load('saved_model.pth', map_location= torch.device(device), weights_only=False)
8
 
9
+ def generate_text(context, num_of_tokens, temperature=1.0):
10
+ if context == None or context == '':
11
+ idx = torch.zeros((1,1), dtype=torch.long)
12
+ else:
13
+ idx = torch.tensor(encode(context), dtype=torch.long).unsqueeze(0)
14
 
15
+ yield model.generate(idx, max_new_tokens=num_of_tokens,temperature=temperature)
16
 
17
 
18
 
 
29
  context,
30
  num_of_tokens,tmp
31
  ]
 
 
 
 
32
  generate_btn = gr.Button(value="Generate")
33
  outputs = [gr.Textbox(label= "Generated text: ")]
34
  generate_btn.click(fn = model.generate, inputs= inputs, outputs= outputs)
model.py CHANGED
@@ -206,11 +206,11 @@ class BigramLM(nn.Module):
206
  # sample from the distribution (pick the best)
207
  idx_next = torch.multinomial(probs, num_samples=1)
208
  # GPT like output
209
- yield decode(idx_next[0].tolist())
210
  # append sampled index to running sequence
211
  idx = torch.cat((idx, idx_next), dim=1)
212
 
213
- yield decode(idx_next[0].tolist())
214
 
215
  def train():
216
 
 
206
  # sample from the distribution (pick the best)
207
  idx_next = torch.multinomial(probs, num_samples=1)
208
  # GPT like output
209
+ # yield decode(idx_next[0].tolist())
210
  # append sampled index to running sequence
211
  idx = torch.cat((idx, idx_next), dim=1)
212
 
213
+ return idx
214
 
215
  def train():
216