|
import streamlit as st |
|
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, pipeline |
|
from threading import Thread |
|
|
|
|
|
# Hugging Face Hub repo id of the fine-tuned Amharic instruct model.
model_id = "rasyosef/Llama-3.2-180M-Amharic-Instruct"

# Static page header: title plus a short markdown description with a link
# back to the model card.
st.title("Llama 3.2 180M Amharic Chatbot Demo")
st.write("""
This chatbot was created using [Llama-3.2-180M-Amharic-Instruct](https://huggingface.co/rasyosef/Llama-3.2-180M-Amharic-Instruct),
a finetuned version of the 180 million parameter Llama 3.2 Amharic transformer model.
""")
|
|
|
|
|
@st.cache_resource
def load_model():
    """Download the tokenizer and model for `model_id` and wrap them in a
    text-generation pipeline.

    Decorated with `st.cache_resource`, so the (expensive) download and
    model construction happen once per server process, not on every rerun.

    Returns:
        A `(tokenizer, pipeline)` tuple.
    """
    tok = AutoTokenizer.from_pretrained(model_id)
    lm = AutoModelForCausalLM.from_pretrained(model_id)
    text_generator = pipeline(
        task="text-generation",
        model=lm,
        tokenizer=tok,
        pad_token_id=tok.pad_token_id,
        eos_token_id=tok.eos_token_id,
    )
    return tok, text_generator
|
|
|
# Fetch the cached tokenizer and generation pipeline (built on first run).
tokenizer, llama_pipeline = load_model()
|
|
|
|
|
def generate_response(prompt, chat_history, max_new_tokens):
    """Stream a chatbot reply for `prompt`, conditioned on prior turns.

    Args:
        prompt: The new user message.
        chat_history: List of `(user_message, assistant_reply)` tuples for
            the previous turns (must NOT already include `prompt`).
        max_new_tokens: Upper bound on the number of tokens to generate.

    Yields:
        The partial response text, growing as tokens arrive, so the caller
        can render it incrementally. If the templated conversation exceeds
        the 512-token budget, yields a single error message instead.
    """
    # Rebuild the conversation in the chat-template message format.
    history = []
    for sent, received in chat_history:
        history.append({"role": "user", "content": sent})
        history.append({"role": "assistant", "content": received})
    history.append({"role": "user", "content": prompt})

    # BUG FIX: this function is a generator, so the original
    # `return "Chat history is too long."` ended iteration without ever
    # delivering the message — callers saw an empty stream. The message
    # must be yielded to reach them.
    if len(tokenizer.apply_chat_template(history)) > 512:
        yield "Chat history is too long."
        return

    # Run generation in a background thread; the streamer hands tokens to
    # this generator as they are produced.
    streamer = TextIteratorStreamer(
        tokenizer=tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
        timeout=300.0
    )
    thread = Thread(target=llama_pipeline, kwargs={
        "text_inputs": history,
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": 1.15,
        "streamer": streamer
    })
    thread.start()

    generated_text = ""
    for word in streamer:
        generated_text += word
        yield generated_text.strip()
    # Ensure the worker thread has fully finished before the generator ends.
    thread.join()
|
|
|
|
|
# Sidebar controls for generation settings.
st.sidebar.header("Chatbot Configuration")
max_tokens = st.sidebar.slider("Maximum new tokens", min_value=8, max_value=256, value=64, help="Larger values result in longer responses.")

st.subheader("Chat with the Amharic Chatbot")
# NOTE: a dead `chat_history = st.session_state.get("chat_history", [])`
# assignment was removed here — no code read the bare name; every consumer
# accesses st.session_state.chat_history directly.

# Message entry box for the next user turn.
user_input = st.text_input("Your message:", placeholder="Type your message here...")
|
|
|
if st.button("Send"):
    if user_input:
        # Ensure the per-session history list exists, then append a
        # placeholder turn that is filled in once generation finishes.
        st.session_state.chat_history = st.session_state.get("chat_history", [])
        st.session_state.chat_history.append((user_input, ""))

        # BUG FIX: the placeholder (user_input, "") turn was previously fed
        # back into generate_response, which appends the prompt itself —
        # duplicating the user message and injecting an empty assistant
        # turn into the model context. Pass only the *previous* turns.
        responses = generate_response(user_input, st.session_state.chat_history[:-1], max_tokens)

        with st.spinner("Generating response..."):
            final_response = ""
            # Drain the stream; only the final (complete) text is kept since
            # the page rerenders once at the end anyway.
            for response in responses:
                final_response = response
            # Replace the placeholder with the completed reply.
            st.session_state.chat_history[-1] = (user_input, final_response)
        # NOTE(review): st.experimental_rerun() is deprecated (removed in
        # recent Streamlit releases) — switch to st.rerun() once the deployed
        # Streamlit version is confirmed to be >= 1.27.
        st.experimental_rerun()
|
|
|
|
|
# Render the full conversation, oldest turn first.
if "chat_history" in st.session_state:
    turn = 0
    for sent_msg, reply in st.session_state.chat_history:
        turn += 1
        st.write(f"**User {turn}:** {sent_msg}")
        st.write(f"**Bot:** {reply}")
|
|
|
|