Saif Rehman Nasir committed
Commit b9f58f3 · 1 Parent(s): 1a0060e

Add streaming output logic

Files changed (2):
  1. app.py +2 -1
  2. model.py +2 -2
app.py CHANGED
@@ -12,7 +12,8 @@ def generate_text(context, num_of_tokens, temperature=1.0):
     else:
         idx = torch.tensor(encode(context), dtype=torch.long).unsqueeze(0)
 
-    return decode(model.generate(idx, max_new_tokens=num_of_tokens,temperature=temperature)[0].tolist())
+    yield model.generate(idx, max_new_tokens=num_of_tokens,temperature=temperature)[0].tolist()
+
 
 
 with gr.Blocks() as demo:
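
Gradio treats a generator event handler as a streaming function and refreshes the output component on every yield, so an app-side wrapper usually accumulates the yielded pieces before re-yielding them. The sketch below illustrates that wiring only; fake_generate stands in for the repository's generate_text generator, and the widget names are illustrative rather than taken from app.py.

import time
import gradio as gr

# Stand-in for the streaming generator in app.py (hypothetical, for illustration).
def fake_generate(context, num_of_tokens):
    for ch in "streamed output...":
        time.sleep(0.05)
        yield ch

def stream_text(context, num_of_tokens):
    shown = ""
    for piece in fake_generate(context, int(num_of_tokens)):
        shown += str(piece)   # accumulate each streamed chunk
        yield shown           # Gradio pushes the growing text to the textbox on every yield

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    tokens = gr.Slider(1, 200, value=50, step=1, label="Max new tokens")
    output = gr.Textbox(label="Generated text")
    gr.Button("Generate").click(stream_text, inputs=[prompt, tokens], outputs=output)

demo.launch()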
model.py CHANGED
@@ -206,11 +206,11 @@ class BigramLM(nn.Module):
             # sample from the distribution (pick the best)
             idx_next = torch.multinomial(probs, num_samples=1)
             # GPT like output
-            #print(decode(idx_next[0].tolist()), end='')
+            yield decode(idx_next[0].tolist())
             # append sampled index to running sequence
             idx = torch.cat((idx, idx_next), dim=1)
 
-        return idx
+        yield decode(idx_next[0].tolist())
 
 def train():
 
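
Taken together, the change turns the model's generate loop from a function that returns the full index tensor into a generator that decodes and yields each sampled token as it is produced. Below is a minimal sketch of that pattern as a standalone function; the generate_stream name, the (logits, loss) return convention of the model call, and the block_size crop are assumptions for illustration, not code from this commit.

import torch
import torch.nn.functional as F

def generate_stream(model, idx, max_new_tokens, decode, temperature=1.0, block_size=256):
    # Yield-based sampling loop in the spirit of this commit.
    # Assumes model(idx) returns (logits, loss) with logits of shape (B, T, vocab_size).
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -block_size:]                      # crop to the context window
        logits, _ = model(idx_cond)
        logits = logits[:, -1, :] / temperature              # last time step, scaled by temperature
        probs = F.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)   # sample the next token
        yield decode(idx_next[0].tolist())                   # stream the decoded piece immediately
        idx = torch.cat((idx, idx_next), dim=1)              # append to the running sequence

A caller consumes it like any generator, e.g. `for piece in generate_stream(model, idx, 100, decode): print(piece, end='')`, which is what makes token-by-token output possible in the UI.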