import gradio as gr
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TextIteratorStreamer

# Load the pretrained Amharic GPT2 model and its tokenizer from the Hugging Face Hub.
model_id = "rasyosef/gpt2-small-amharic-128-v3"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Wrap the model in a text-generation pipeline.
gpt2_am = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id
)

def generate(prompt):
    # The model's context window is 128 tokens, so the prompt must leave
    # room for at least one generated token.
    prompt_length = len(tokenizer.tokenize(prompt))
    if prompt_length >= 128:
        yield prompt + "\n\nPrompt is too long. It needs to be less than 128 tokens."
    else:
        max_new_tokens = max(0, 128 - prompt_length)
        # TextIteratorStreamer yields decoded tokens as they are generated,
        # enabling token-by-token streaming in the UI.
        streamer = TextIteratorStreamer(
            tokenizer=tokenizer,
            skip_prompt=False,
            skip_special_tokens=True,
            timeout=300.0
        )
        # Run generation in a background thread so the main thread can
        # consume the streamer and update the UI incrementally.
        thread = Thread(
            target=gpt2_am,
            kwargs={
                "text_inputs": prompt,
                "max_new_tokens": max_new_tokens,
                "temperature": 0.8,
                "do_sample": True,
                "top_k": 8,
                "top_p": 0.8,
                "repetition_penalty": 1.25,
                "streamer": streamer
            })
        thread.start()

        # Accumulate streamed tokens and yield the partial response so
        # Gradio can display the text as it is generated.
        generated_text = ""
        for word in streamer:
            generated_text += word
            response = generated_text.strip()
            yield response

with gr.Blocks() as demo:
    gr.Markdown("""
    # GPT2 Amharic
    This is a demo for a smaller version of the GPT2 decoder transformer model, pretrained for 1.5 days on `290 million` tokens of **Amharic** text. The context size of `gpt2-small-amharic` is 128 tokens.
    """)

    prompt = gr.Textbox(label="Prompt", placeholder="Enter prompt here", lines=4, interactive=True)

    with gr.Row():
        with gr.Column():
            gen = gr.Button("Generate")
        with gr.Column():
            btn = gr.ClearButton([prompt])

    # The generator function streams its output back into the prompt textbox.
    gen.click(generate, inputs=[prompt], outputs=[prompt])

    examples = gr.Examples(
        examples=[
            "የ አዲስ አበባ",
            "በ ኢንግሊዝ ፕሪምየር ሊግ",
            "ፕሬዚዳንት ዶናልድ ትራምፕ"
        ],
        inputs=[prompt],
    )

# Queueing is required for streaming (generator) outputs.
demo.queue().launch(debug=True)