# Author: lewiswu1209 — commit 1e5785c ("initial commit"), 937 bytes.
import gradio as gr
from transformers import BertTokenizer, TFGPT2LMHeadModel
from transformers import TextGenerationPipeline
# Hub checkpoint for the GPT-2 language model; it is paired with a
# BertTokenizer, so its vocabulary is presumably BERT-style — confirm on the
# model card if swapping tokenizers.
checkpoint = "mymusise/EasternFantasyNoval"
tokenizer = BertTokenizer.from_pretrained(checkpoint)
model = TFGPT2LMHeadModel.from_pretrained(checkpoint)

# Single module-level pipeline reused by every generate() call.
# NOTE: the spelling "generater" is kept because other code refers to it.
text_generater = TextGenerationPipeline(model=model, tokenizer=tokenizer)
def generate(prefix, length):
    """Continue ``prefix`` with about ``length`` newly generated tokens.

    Args:
        prefix: Opening text typed by the user; generation continues from it.
        length: Requested number of new tokens. Gradio's ``"number"`` input
            may deliver a float or ``None``; the value is truncated with
            ``int`` and treated as 0 when missing or negative.

    Returns:
        The text produced by the module-level ``text_generater`` pipeline
        (prompt plus continuation) with every space character removed.
    """
    # A cleared Number box yields None; a negative count would push
    # max_length below the prompt length.
    new_tokens = max(int(length or 0), 0)

    # max_length is measured in tokens and includes the prompt. Count the
    # prompt in tokens rather than characters: spaces and Latin words take
    # up characters but few or no tokens, so len(prefix) over-counts.
    # NOTE(review): assumes the pipeline encodes the prompt without special
    # tokens (as tokenize() does) — confirm for the pinned transformers version.
    prompt_tokens = len(tokenizer.tokenize(prefix))

    output = text_generater(
        prefix,
        max_length=prompt_tokens + new_tokens,
        do_sample=True,
        top_k=8,
    )
    # "generated_text" is already a single string. Spaces are stripped
    # presumably because BertTokenizer decoding separates tokens with spaces.
    return output[0]["generated_text"].replace(" ", "")
if __name__ == "__main__":
    # Inputs: the opening text to continue, and a number passed to
    # generate() as the continuation length.
    prompt_box = gr.Textbox(lines=10, placeholder="在这里输入一个开头。")
    result_box = gr.Textbox(lines=12, placeholder="这里会输出一段文字。")

    demo = gr.Interface(
        fn=generate,
        inputs=[prompt_box, "number"],
        outputs=result_box,
    )
    demo.launch()