# Spaces:
# Runtime error
# Runtime error
import gradio as gr
from transformers import BertTokenizer, TFGPT2LMHeadModel
from transformers import TextGenerationPipeline

# Identifier passed to both from_pretrained() calls below
# (presumably a Hugging Face Hub model id -- TODO confirm).
checkpoint = "mymusise/EasternFantasyNoval"

# Everything here runs at import time, so the tokenizer and model are loaded
# once and the resulting pipeline is reused by every call to generate().
tokenizer = BertTokenizer.from_pretrained(checkpoint)
model = TFGPT2LMHeadModel.from_pretrained(checkpoint)
text_generater = TextGenerationPipeline(model, tokenizer)
def generate(prefix, length):
    """Continue *prefix* using the module-level ``text_generater`` pipeline.

    Args:
        prefix: Opening text handed to the pipeline as the prompt.
        length: Requested amount of additional text; converted with
            ``int()``, so any value ``int()`` accepts works (e.g. the float
            produced by a numeric input widget).

    Returns:
        The ``"generated_text"`` of the pipeline's first result, with every
        ASCII space character removed.

    Raises:
        TypeError, ValueError: If ``length`` cannot be converted by ``int()``.
    """
    # The budget handed to the pipeline is the requested length plus the
    # prompt's character count.
    # NOTE(review): assumes max_length covers prompt + continuation and that
    # one character is roughly one token for this tokenizer -- confirm.
    max_length = int(length) + len(prefix)
    output = text_generater(prefix, max_length=max_length, do_sample=True, top_k=8)
    generated_text = output[0]["generated_text"]
    # Drop the spaces in the decoded text (presumably inserted between tokens
    # by the tokenizer's decode step -- TODO confirm).
    return generated_text.replace(" ", "")
if __name__ == '__main__':
    # Build the web UI only when run as a script: a multi-line textbox for the
    # opening text and a number field for the length feed generate(); its
    # return value is shown in the output textbox.
    gr.Interface(
        fn=generate,
        inputs=[
            # Placeholder: "Enter an opening here."
            gr.Textbox(lines=10, placeholder="在这里输入一个开头。"),
            "number",
        ],
        # Placeholder: "A passage of text will be output here."
        outputs=gr.Textbox(lines=12, placeholder="这里会输出一段文字。"),
    ).launch()