import gradio as gr
from transformers import pipeline, GPT2TokenizerFast

modelId = "luel/gpt2-tigrinya-small"
tokenizer = GPT2TokenizerFast.from_pretrained(modelId, model_max_length=128)
generator = pipeline(
    "text-generation", model=modelId, tokenizer=tokenizer,
    pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id,
)

def generate_text(prompt, max_length, temperature):
    """Generate Tigrinya text from the prompt with the given length and temperature."""
    try:
        generated = generator(
            prompt,
            max_length=max_length,
            temperature=temperature,
            do_sample=True,
            repetition_penalty=1.5,
        )
        return generated[0]['generated_text']
    except Exception as e:
        return f"Something went wrong, please try again. Error: {e}"

def create_interface():
    """Build the Gradio Blocks interface for the generator."""
    with gr.Blocks() as demo:
        gr.Markdown("# Tigrinya Text Generator (GPT-2)")
        gr.Markdown(
            "This is a GPT-2 model trained from scratch on Tigrinya text data, primarily from news sources. "
            "Enter your Tigrinya text prompt and adjust the parameters to generate text."
        )

        # Sampling controls.
        with gr.Row():
            input_temperature = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.7,
                step=0.1,
                label="Temperature",
            )
            input_max_length = gr.Slider(
                minimum=10,
                maximum=128,
                value=60,
                step=1,
                label="Maximum Length",
            )

        # Prompt input and generated output, side by side.
        with gr.Row():
            with gr.Column(scale=1):
                input_prompt = gr.Textbox(
                    label="Enter your Tigrinya text prompt",
                    placeholder="ትግራይ",
                    lines=5,
                )
            with gr.Column(scale=1):
                output_text = gr.Textbox(
                    label="Generated Text",
                    lines=5,
                    interactive=True,
                )

        with gr.Row():
            generate_btn = gr.Button("Generate", variant="primary")
            clear_btn = gr.ClearButton([input_prompt, output_text])

        generate_btn.click(
            fn=generate_text,
            inputs=[input_prompt, input_max_length, input_temperature],
            outputs=output_text,
        )

        gr.Examples(
            examples=[
                ["ትግራይ"],
                ["ኣዲስ ኣበባ"],
                ["ሰላም"],
            ],
            inputs=input_prompt,
        )

    return demo

if __name__ == "__main__":
    demo = create_interface()
    demo.queue().launch(debug=True)