# Source: am_con/app.py — uploaded by berito (commit 3b8e387, verified, 4.43 kB)
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, pipeline
from threading import Thread
# Model Initialization
# Hugging Face Hub id of the 180M-parameter Amharic instruction-tuned Llama model.
model_id = "rasyosef/Llama-3.2-180M-Amharic-Instruct"
# Page title and intro blurb rendered at the top of the Streamlit app.
st.title("Llama 3.2 180M Amharic Chatbot Demo")
st.write("""
This chatbot was created using [Llama-3.2-180M-Amharic-Instruct](https://huggingface.co/rasyosef/Llama-3.2-180M-Amharic-Instruct),
a finetuned version of the 180 million parameter Llama 3.2 Amharic transformer model.
""")
# Load the tokenizer and model
@st.cache_resource
def load_model():
    """Download and cache the Amharic Llama model and its tokenizer.

    Cached via ``st.cache_resource`` so the heavyweight model load happens
    once per server process, not on every Streamlit rerun.

    Returns:
        tuple: ``(tokenizer, pipeline)`` — the tokenizer and a configured
        text-generation pipeline ready for inference.
    """
    tok = AutoTokenizer.from_pretrained(model_id)
    lm = AutoModelForCausalLM.from_pretrained(model_id)
    # Wire pad/eos ids explicitly so generation terminates cleanly.
    gen = pipeline(
        "text-generation",
        model=lm,
        tokenizer=tok,
        pad_token_id=tok.pad_token_id,
        eos_token_id=tok.eos_token_id,
    )
    return tok, gen

tokenizer, llama_pipeline = load_model()
# Generate text
# Generate text
def generate_response(prompt, chat_history, max_new_tokens):
    """Stream a chatbot reply for ``prompt`` given prior conversation turns.

    Args:
        prompt: The latest user message.
        chat_history: List of ``(user_message, bot_response)`` tuples.
        max_new_tokens: Upper bound on generated tokens.

    Yields:
        str: The accumulated (stripped) response text, growing as tokens
        stream in from the model; or a single error string when the
        conversation exceeds the model's context budget.
    """
    history = []
    # Rebuild the chat-template message list from prior turns.
    for sent, received in chat_history:
        history.append({"role": "user", "content": sent})
        history.append({"role": "assistant", "content": received})
    history.append({"role": "user", "content": prompt})

    # Refuse conversations longer than the 512-token context budget.
    if len(tokenizer.apply_chat_template(history)) > 512:
        # BUG FIX: this function is a generator, so a plain
        # `return "..."` silently discards the message — the caller
        # iterating the generator never sees it. Yield it instead.
        yield "Chat history is too long."
        return

    streamer = TextIteratorStreamer(
        tokenizer=tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
        timeout=300.0
    )
    # Run generation in a worker thread so this generator can consume
    # tokens from the streamer as they are produced.
    thread = Thread(target=llama_pipeline, kwargs={
        "text_inputs": history,
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": 1.15,
        "streamer": streamer
    })
    thread.start()

    generated_text = ""
    for word in streamer:
        generated_text += word
        yield generated_text.strip()
    # Ensure the worker has finished before the generator completes.
    thread.join()
# Sidebar: Configuration
st.sidebar.header("Chatbot Configuration")
# Slider controlling the generation length cap passed to generate_response.
max_tokens = st.sidebar.slider("Maximum new tokens", min_value=8, max_value=256, value=64, help="Larger values result in longer responses.")
# Examples
# Canned Amharic prompts (greetings, geography/history questions, writing
# tasks, a sentiment-classification probe) offered in the example selector.
examples = [
"แˆฐแˆ‹แˆแฃ แŠฅแŠ•แ‹ดแ‰ต แŠแˆ…?",
"แ‹จแŠขแ‰ตแ‹ฎแŒตแ‹ซ แ‹‹แŠ“ แŠจแ‰ฐแˆ› แˆตแˆ แˆแŠ•แ‹ตแŠ• แŠแ‹?",
"แ‹จแŠขแ‰ตแ‹ฎแŒตแ‹ซ แ‹จแˆ˜แŒจแˆจแˆปแ‹ แŠ•แŒ‰แˆต แˆ›แŠ• แŠแ‰ แˆฉ?",
"แ‹จแŠ แˆ›แˆญแŠ› แŒแŒฅแˆ แƒแแˆแŠ",
"แ‰ฐแˆจแ‰ต แŠ•แŒˆแˆจแŠ\n\nแŒ…แ‰ฅแŠ“ แŠ แŠ•แ‰ แˆณ",
"แŠ แŠ•แ‹ต แŠ แˆตแ‰‚แŠ แ‰€แˆแ‹ต แŠ•แŒˆแˆจแŠ",
"แ‹จแ‰ฐแˆฐแŒ แ‹ แŒฝแˆ‘แ แŠ แˆตแ‰ฐแ‹ซแ‹จแ‰ต แˆแŠ• แŠ แ‹ญแŠแ‰ต แŠแ‹? 'แŠ แ‹ŽแŠ•แ‰ณแ‹Š'แฃ 'แŠ แˆ‰แ‰ณแ‹Š' แ‹ˆแ‹ญแˆ 'แŒˆแˆˆแˆแ‰ฐแŠ›' แ‹จแˆšแˆ แˆแˆ‹แˆฝ แˆตแŒฅแข 'แŠ แˆชแ แŠแˆแˆ แŠแ‰ แˆญ'",
"แ‹จแˆแˆจแŠ•แˆณแ‹ญ แ‹‹แŠ“ แŠจแ‰ฐแˆ› แˆตแˆ แˆแŠ•แ‹ตแŠ• แŠแ‹?",
"แŠ แˆแŠ• แ‹จแŠ แˆœแˆชแŠซ แ•แˆฌแ‹šแ‹ณแŠ•แ‰ต แˆ›แŠ• แŠแ‹?",
"แˆถแˆตแ‰ต แ‹จแŠ แแˆชแŠซ แˆ€แŒˆแˆซแ‰ต แŒฅแ‰€แˆตแˆแŠ",
"3 แ‹จแŠ แˆœแˆชแŠซ แˆ˜แˆชแ‹Žแ‰ฝแŠ• แˆตแˆ แŒฅแ‰€แˆต",
"5 แ‹จแŠ แˆœแˆชแŠซ แŠจแ‰ฐแˆ›แ‹Žแ‰ฝแŠ• แŒฅแ‰€แˆต",
"แŠ แˆแˆตแ‰ต แ‹จแŠ แ‹แˆฎแ“ แˆ€แŒˆแˆฎแ‰ฝแŠ• แŒฅแ‰€แˆตแˆแŠ",
"แ‰  แ‹“แˆˆแˆ แˆ‹แ‹ญ แ‹ซแˆ‰แ‰ตแŠ• 7 แŠ แˆ…แŒ‰แˆซแ‰ต แŠ•แŒˆแˆจแŠ"
]
st.subheader("Chat with the Amharic Chatbot")
chat_history = st.session_state.get("chat_history", [])

# Let the user pick a canned example or type their own message; a picked
# example pre-fills the text input.
example = st.selectbox("Choose an example:", ["Type your own message"] + examples)
prefill = "" if example == "Type your own message" else example
user_input = st.text_input("Your message:", value=prefill, placeholder="Type your message here...")

if st.button("Send") and user_input:
    # Append a placeholder turn, stream the reply, then fill it in.
    st.session_state.chat_history = st.session_state.get("chat_history", [])
    st.session_state.chat_history.append((user_input, ""))
    stream = generate_response(user_input, st.session_state.chat_history, max_tokens)
    with st.spinner("Generating response..."):
        answer = ""
        # Drain the generator; each item is the full response so far.
        for chunk in stream:
            answer = chunk
        st.session_state.chat_history[-1] = (user_input, answer)
    st.rerun()

# Display Chat History
if "chat_history" in st.session_state:
    for turn, (user_msg, bot_response) in enumerate(st.session_state.chat_history):
        st.write(f"**User {turn+1}:** {user_msg}")
        st.write(f"**Bot:** {bot_response}")