File size: 3,301 Bytes
96c86c7
 
 
fa1cc91
96c86c7
 
 
 
 
 
 
 
 
 
 
9053f66
96c86c7
 
 
 
 
 
 
 
 
 
 
 
325822e
 
 
 
 
96c86c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92e18a8
96c86c7
 
 
 
 
 
 
 
 
 
 
 
0f4aa6d
96c86c7
 
 
 
 
 
 
 
 
fa1cc91
 
fff1b6d
 
 
 
96c86c7
 
 
 
 
 
 
 
0e3daa0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import gradio as gr
from transformers import pipeline, GPT2TokenizerFast

# Hugging Face Hub identifier shared by the tokenizer and the pipeline below.
modelId = "luel/gpt2-tigrinya-medium"

# Tokenizer for the model, loaded with its maximum length set to 128.
tokenizer = GPT2TokenizerFast.from_pretrained(
    modelId,
    model_max_length=128,
)

# Module-level text-generation pipeline reused by every request. The
# tokenizer's pad and EOS token ids are handed to the pipeline explicitly.
generator = pipeline(
    "text-generation",
    model=modelId,
    tokenizer=tokenizer,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
)

def generate_text(prompt, max_length, temperature):
    """Generate text from ``prompt`` using the module-level ``generator``.

    Sampling is always enabled and a repetition penalty of 1.4 is applied.

    Args:
        prompt: Text handed to the generator as the prompt.
        max_length: Forwarded to the generator as ``max_length``.
        temperature: Forwarded to the generator as ``temperature``.

    Returns:
        The ``generated_text`` field of the first result, or a human-readable
        error message if anything raised while generating.
    """
    sampling_options = {
        "max_length": max_length,
        "temperature": temperature,
        "do_sample": True,
        "repetition_penalty": 1.4,
    }
    try:
        outputs = generator(prompt, **sampling_options)
        first_result = outputs[0]
        return first_result['generated_text']
    except Exception as e:
        # Report the failure as the returned text instead of raising.
        return f"Something went wrong, try again. Error: {str(e)}"

def create_interface():
    """Build the Gradio Blocks UI for the Tigrinya text generator.

    The layout contains, in order: a title and description, a row with the
    temperature and maximum-length sliders, a row with the prompt and output
    textboxes, a row with the Generate and Clear buttons, and a list of
    clickable example prompts that fill the prompt textbox.

    Returns:
        The constructed ``gr.Blocks`` instance. It is not launched here.
    """
    # Tigrinya example prompts, written as proper Ethiopic-script text.
    example_prompts = [
        ["ክልል ትግራይ"],
        ["መረጻ ኣሜሪካ"],
        ["ሰላም"],
        ["ትካል ጥዕና ዓለም"],
        ["ጥንታዊ ሡልጣነ"],
        ["ዲሞክራሲ"],
    ]

    with gr.Blocks() as demo:
        gr.Markdown("# Tigrinya Text Generator (GPT-2)")
        gr.Markdown(
            "This is a GPT-2 model trained from scratch on Tigrinya text data, primarily from news sources. "
            "Enter your Tigrinya text prompt and adjust the parameters to generate text."
        )
        gr.Markdown(
            "**Parameters:**\n"
            "- **Temperature**: Controls the creativity of the output. Lower values (e.g., 0.2) make the text more focused and predictable, while higher values (e.g., 0.8) make it more diverse and creative.\n"
            "- **Maximum Length**: Defines how long the generated text will be. Higher values may take more time to generate but provide longer and more detailed results."
        )

        # Generation parameters.
        with gr.Row():
            input_temperature = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.7,
                step=0.1,
                label="Temperature",
            )
            input_max_length = gr.Slider(
                minimum=10,
                maximum=128,
                value=60,
                step=1,
                label="Maximum Length",
            )

        # Prompt on the left, generated text on the right.
        with gr.Row():
            with gr.Column(scale=1):
                input_prompt = gr.Textbox(
                    label="Enter your Tigrinya text prompt",
                    placeholder="ክልል ትግራይ",
                    lines=5,
                )

            with gr.Column(scale=1):
                output_text = gr.Textbox(
                    label="Generated Text",
                    lines=5,
                    interactive=True,
                )

        with gr.Row():
            generate_btn = gr.Button("Generate", variant="primary")
            # Clears both the prompt and the generated text.
            clear_btn = gr.ClearButton([input_prompt, output_text])

        # Input order must match generate_text(prompt, max_length, temperature).
        generate_btn.click(
            fn=generate_text,
            inputs=[input_prompt, input_max_length, input_temperature],
            outputs=output_text,
        )

        gr.Examples(
            examples=example_prompts,
            inputs=input_prompt,
        )

    return demo

if __name__ == "__main__":
    # Build the UI, enable queuing, then start the app with debug output.
    demo = create_interface()
    demo.queue()
    demo.launch(debug=True)