xRunda commited on
Commit
7c12b15
·
1 Parent(s): 5dcfa45
Files changed (1) hide show
  1. sample.py +12 -7
sample.py CHANGED
@@ -7,6 +7,7 @@ from contextlib import nullcontext
7
  import torch
8
  import tiktoken
9
  from model import GPTConfig, GPT
 
10
 
11
  # -----------------------------------------------------------------------------
12
  init_from = 'resume' # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')
@@ -80,10 +81,14 @@ if start.startswith('FILE:'):
80
  start_ids = encode(start)
81
  x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
82
 
83
- # run generation
84
- with torch.no_grad():
85
- with ctx:
86
- for k in range(num_samples):
87
- y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
88
- print(decode(y[0].tolist()))
89
- print('---------------')
 
 
 
 
 
7
  import torch
8
  import tiktoken
9
  from model import GPTConfig, GPT
10
+ import streamlit as st
11
 
12
  # -----------------------------------------------------------------------------
13
  init_from = 'resume' # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')
 
81
  start_ids = encode(start)
82
  x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
83
 
84
+ def main():
85
+ st.header("古诗词GPT")
86
+ # run generation
87
+ with torch.no_grad():
88
+ with ctx:
89
+ for k in range(num_samples):
90
+ y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
91
+ st.write(decode(y[0].tolist()))
92
+ st.write('---------------')
93
+ if __name__ == "__main__":
94
+ main()