# GPT2-Amharic / app.py
# Author: rasyosef — commit c0a205f ("Create app.py", verified), 2.12 kB
import gradio as gr
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TextIteratorStreamer
# Hugging Face Hub checkpoint served by this demo.
model_id = "rasyosef/gpt2-small-amharic-128-v3"

# Tokenizer and causal-LM weights for that checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Special-token ids read from the tokenizer and handed to the pipeline below.
_special_token_ids = {
    "pad_token_id": tokenizer.pad_token_id,
    "eos_token_id": tokenizer.eos_token_id,
}

# One shared text-generation pipeline, invoked by every generation request.
gpt2_am = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    **_special_token_ids,
)
# Token budget shared by the prompt and the generated continuation.
CONTEXT_LENGTH = 128


def generate(prompt):
    """Stream a continuation of ``prompt`` from the ``gpt2_am`` pipeline.

    Generation runs on a worker thread. Its output is read incrementally
    through a ``TextIteratorStreamer``, so the caller receives partial text
    as it is produced.

    Args:
        prompt: Text to continue.

    Yields:
        The prompt followed by the text generated so far, with surrounding
        whitespace stripped, once per non-empty streamed chunk. If the prompt
        already uses ``CONTEXT_LENGTH`` tokens or more, a single message
        saying it is too long is yielded instead and no generation happens.

    Raises:
        Exception: any error raised by the pipeline on the worker thread is
            re-raised here once the stream has been closed, instead of being
            lost while this generator waits for the streamer timeout.
    """
    prompt_length = len(tokenizer.tokenize(prompt))
    if prompt_length >= CONTEXT_LENGTH:
        yield (
            prompt
            + f"\n\nPrompt is too long. It needs to be less than {CONTEXT_LENGTH} tokens."
        )
        return

    # The guard above guarantees at least one new token fits.
    max_new_tokens = CONTEXT_LENGTH - prompt_length

    streamer = TextIteratorStreamer(
        tokenizer=tokenizer,
        skip_prompt=False,  # echo the prompt so the output is prompt + continuation
        skip_special_tokens=True,
        timeout=300.0,
    )
    worker_errors = []

    def _run_pipeline():
        """Run generation; on failure, record the error and close the stream."""
        try:
            gpt2_am(
                text_inputs=prompt,
                max_new_tokens=max_new_tokens,
                temperature=0.8,
                do_sample=True,
                top_k=8,
                top_p=0.8,
                repetition_penalty=1.25,
                streamer=streamer,
            )
        except Exception as exc:  # thread boundary: re-raised in the caller below
            worker_errors.append(exc)
            # Send the end-of-stream signal so the consumer loop stops now
            # rather than blocking until the 300 s streamer timeout.
            streamer.end()

    thread = Thread(target=_run_pipeline, daemon=True)
    thread.start()

    generated_text = ""
    for chunk in streamer:
        # The streamer emits "" while it waits for a word boundary; skipping
        # those avoids redundant (or prompt-blanking) UI updates.
        if not chunk:
            continue
        generated_text += chunk
        yield generated_text.strip()

    thread.join()
    if worker_errors:
        raise worker_errors[0]
# Markdown header rendered at the top of the page.
DESCRIPTION = """
# GPT2 Amharic
This is a demo for a smaller version of the gpt2 decoder transformer model pretrained for 1.5 days on `290 million` tokens of **Amharic** text. The context size of `gpt2-small-amharic` is 128 tokens.
"""

# Sample Amharic prompts offered as one-click examples.
EXAMPLE_PROMPTS = [
    "የ አዲስ አበባ",
    "በ ኢንግሊዝ ፕሪምየር ሊግ",
    "ፕሬዚዳንት ዶናልድ ትራምፕ",
]

with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)

    # The same textbox is both the input and the output of `generate`.
    prompt = gr.Textbox(
        label="Prompt",
        placeholder="Enter prompt here",
        lines=4,
        interactive=True,
    )

    # Generate and Clear buttons side by side, each in its own column.
    with gr.Row():
        with gr.Column():
            generate_button = gr.Button("Generate")
        with gr.Column():
            clear_button = gr.ClearButton([prompt])

    generate_button.click(generate, inputs=[prompt], outputs=[prompt])

    examples = gr.Examples(examples=EXAMPLE_PROMPTS, inputs=[prompt])

# Enable Gradio's request queue, then start the server.
demo.queue()
demo.launch(debug=True)