import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, pipeline
from threading import Thread
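
# A 400M-parameter Llama 3.2 decoder fine-tuned for Amharic chat on
# instructions, poems, stories, and Wikipedia text (per the checkpoint name).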
model_id = "rasyosef/Llama-3.2-400M-Amharic-Instruct-Poems-Stories-Wikipedia"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float32,
)
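
# float32 keeps the demo runnable on CPU-only hardware; on a GPU you could
# presumably load in torch.float16 instead (an assumption, not something
# this demo has been tested with).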

llama3_am = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    eos_token_id=tokenizer.eos_token_id,
)
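
# Minimal one-off usage sketch (assumes the chat-message input format that
# transformers text-generation pipelines accept):
#   out = llama3_am([{"role": "user", "content": "แฐแแ"}], max_new_tokens=32)
#   # out[0]["generated_text"] then contains the model's reply.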


def generate(message, chat_history, max_new_tokens=64):
    # Rebuild the conversation in the role/content message format expected
    # by the tokenizer's chat template.
    history = []
    for sent, received in chat_history:
        history.append({"role": "user", "content": sent})
        history.append({"role": "assistant", "content": received})
    history.append({"role": "user", "content": message})

    # apply_chat_template returns the prompt's token ids, so this caps the
    # templated conversation at a 512-token budget.
    if len(tokenizer.apply_chat_template(history)) > 512:
        yield "Chat history is too long."
    else:
        # Run generation on a worker thread; TextIteratorStreamer hands the
        # decoded text back to this generator chunk by chunk.
        streamer = TextIteratorStreamer(
            tokenizer=tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=300.0
        )
        thread = Thread(
            target=llama3_am,
            kwargs={
                "text_inputs": history,
                "max_new_tokens": max_new_tokens,
                "repetition_penalty": 1.1,
                "streamer": streamer,
            },
        )
        thread.start()

        # Re-yield the accumulated text on every chunk so the UI updates live.
        generated_text = ""
        for word in streamer:
            generated_text += word
            response = generated_text.strip()
            yield response
        thread.join()  # reap the worker once streaming ends
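
# Because generate() is a generator, gr.ChatInterface streams each yielded
# string to the chat window instead of waiting for the final answer.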


with gr.Blocks() as demo:
    gr.Markdown("""
    # Llama 3.2 400M Amharic Chatbot Demo
    """)

    tokens_slider = gr.Slider(
        8, 256, value=64, label="Maximum new tokens",
        info="A larger `max_new_tokens` value produces longer responses but takes longer to generate.",
    )
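
    # ChatInterface passes each component in additional_inputs as an extra
    # positional argument to fn, so the slider's value arrives as
    # generate()'s max_new_tokens parameter.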
    chatbot = gr.ChatInterface(
        chatbot=gr.Chatbot(height=400),
        fn=generate,
        additional_inputs=[tokens_slider],
        stop_btn=None,
        cache_examples=False,
        # Amharic example prompts: greetings, requests for poems, stories,
        # and jokes, and factual questions.
        examples=[
            ["แฐแแ"],
            ["แฐแแแฃ แฅแแดแต แแ
?"],
            ["แ แแฐ แแแ
?"],
            ["แแฅแ แแแแ"],
            ["แตแ แญแ
แญแณ แแฅแ แปแแแ"],
            ["แ แแต แฐแจแต แ แซแแฐแ"],
            ["แตแ แ
แฅแ แ แแ แณ แฐแจแต แแแจแ"],
            ["แแแต แแแจแ"],
            ["แตแ แตแซ แ แฅแแต แ แแต แแแต แแแจแ"],
            ["แณแแแ แดแแตแฎแต แแ แแ?"],
            ["แณแแแ แแแแญ แแ แแ?"],
            ["แตแ แ แฒแต แ แ แฃ แฉแแจแญแตแฒ แฅแแต แฅแแแณแแฝแ แ แซแแฐแ"],
            ["แตแ แแแ แฅแแต แฅแแแณแแฝแ แแแจแ"],
            ["แตแ แแญแญแฎแถแแต แฅแแต แฅแแแณแแฝแ แแแจแ"],
            ["แแแ แแแตแ แแ?"],
            ["แขแตแฎแญแ แแแตแ แแ?"],
        ],
    )
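
# queue() enables the request queue that generator (streaming) outputs rely
# on; debug=True keeps server logs in the foreground.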
demo.queue().launch(debug=True)