streamlit
xRunda committed · Commit 7c12b15 · Parent(s): 5dcfa45
sample.py CHANGED
@@ -7,6 +7,7 @@ from contextlib import nullcontext
 import torch
 import tiktoken
 from model import GPTConfig, GPT
+import streamlit as st
 
 # -----------------------------------------------------------------------------
 init_from = 'resume' # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')
@@ -80,10 +81,14 @@ if start.startswith('FILE:'):
 start_ids = encode(start)
 x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
 
-# run generation
-with torch.no_grad():
-    with ctx:
-        for k in range(num_samples):
-            y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
-            print(decode(y[0].tolist()))
-            print('---------------')
+def main():
+    st.header("古诗词GPT")
+    # run generation
+    with torch.no_grad():
+        with ctx:
+            for k in range(num_samples):
+                y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
+                st.write(decode(y[0].tolist()))
+                st.write('---------------')
+if __name__ == "__main__":
+    main()