"""Gradio demo that continues a user-supplied text opening with a GPT-2 model.

The ``mymusise/EasternFantasyNoval`` checkpoint is loaded as a TensorFlow
``TFGPT2LMHeadModel`` together with a ``BertTokenizer``, wrapped in a
``TextGenerationPipeline``, and exposed through a ``gradio.Interface``.
"""
import gradio as gr
from transformers import BertTokenizer, TFGPT2LMHeadModel
from transformers import TextGenerationPipeline

# Hugging Face Hub checkpoint id, passed verbatim to ``from_pretrained``.
checkpoint = "mymusise/EasternFantasyNoval"

# Loaded once at import time and shared by every ``generate`` call.
tokenizer = BertTokenizer.from_pretrained(checkpoint)
model = TFGPT2LMHeadModel.from_pretrained(checkpoint)
text_generater = TextGenerationPipeline(model, tokenizer)


def generate(prefix, length):
    """Continue ``prefix`` by sampling roughly ``length`` more tokens.

    Args:
        prefix: Opening text typed by the user.
        length: How many additional tokens to request. It comes from the
            Gradio ``"number"`` input, so it may be a float or ``None`` (empty
            box). It is truncated with ``int`` and clamped to be at least 0.

    Returns:
        The first sample's ``generated_text`` with every space removed.
    """
    # Guard against an empty number box (int(None) would raise TypeError)
    # and against negative values, which would request a max_length shorter
    # than the prompt itself.
    extra_tokens = max(int(length or 0), 0)
    # max_length bounds the whole sequence, prompt included. len(prefix) is
    # used as the prompt's size (assumes ~one token per Chinese character
    # with BertTokenizer -- TODO confirm for mixed-script input).
    max_length = extra_tokens + len(prefix)
    output = text_generater(prefix, max_length=max_length, do_sample=True, top_k=8)
    generated_text = output[0]["generated_text"]
    # Remove spaces, presumably inserted between characters when the BERT
    # tokenizer decodes Chinese text.
    return generated_text.replace(" ", "")


if __name__ == '__main__':
    gr.Interface(
        fn=generate,
        inputs=[
            # Placeholder: "Enter an opening here."
            gr.Textbox(lines=10, placeholder="在这里输入一个开头。"),
            "number",
        ],
        # Placeholder: "A passage of text will be output here."
        outputs=gr.Textbox(lines=12, placeholder="这里会输出一段文字。"),
    ).launch()