# am_con / app.py — Streamlit Amharic chatbot demo.
# (Header metadata from the Hugging Face Space page: commit 0c03e17,
# "app edited" by berito, 3.27 kB — kept here as a comment so the file
# remains valid Python.)
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, pipeline
from threading import Thread
# Model Initialization
# Hugging Face Hub repo id of the 180M-parameter Amharic instruction-tuned
# Llama 3.2 model used throughout this app.
model_id = "rasyosef/Llama-3.2-180M-Amharic-Instruct"
# Page header and a short description shown above the chat UI.
st.title("Llama 3.2 180M Amharic Chatbot Demo")
st.write("""
This chatbot was created using [Llama-3.2-180M-Amharic-Instruct](https://huggingface.co/rasyosef/Llama-3.2-180M-Amharic-Instruct),
a finetuned version of the 180 million parameter Llama 3.2 Amharic transformer model.
""")
# Load the tokenizer and model
@st.cache_resource
def load_model():
    """Load the tokenizer and model once and wrap them in a cached
    text-generation pipeline.

    Returns:
        tuple: ``(tokenizer, pipeline)`` — the tokenizer is also needed
        separately (e.g. for ``apply_chat_template`` length checks).
    """
    tok = AutoTokenizer.from_pretrained(model_id)
    lm = AutoModelForCausalLM.from_pretrained(model_id)
    generator = pipeline(
        "text-generation",
        model=lm,
        tokenizer=tok,
        pad_token_id=tok.pad_token_id,
        eos_token_id=tok.eos_token_id,
    )
    return tok, generator

tokenizer, llama_pipeline = load_model()
# Generate text
def generate_response(prompt, chat_history, max_new_tokens):
history = []
# Build chat history
for sent, received in chat_history:
history.append({"role": "user", "content": sent})
history.append({"role": "assistant", "content": received})
history.append({"role": "user", "content": prompt})
if len(tokenizer.apply_chat_template(history)) > 512:
return "Chat history is too long."
else:
streamer = TextIteratorStreamer(
tokenizer=tokenizer,
skip_prompt=True,
skip_special_tokens=True,
timeout=300.0
)
thread = Thread(target=llama_pipeline, kwargs={
"text_inputs": history,
"max_new_tokens": max_new_tokens,
"repetition_penalty": 1.15,
"streamer": streamer
})
thread.start()
generated_text = ""
for word in streamer:
generated_text += word
response = generated_text.strip()
yield response
# Streamlit Input and Chat Interface
# --- Streamlit input and chat interface ---------------------------------
st.sidebar.header("Chatbot Configuration")
max_tokens = st.sidebar.slider(
    "Maximum new tokens",
    min_value=8,
    max_value=256,
    value=64,
    help="Larger values result in longer responses.",
)

st.subheader("Chat with the Amharic Chatbot")

# Initialize the history once so all later code can use it directly.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

# User Input
user_input = st.text_input("Your message:", placeholder="Type your message here...")

if st.button("Send"):
    if user_input:
        # Append a placeholder turn; the bot response is filled in below.
        st.session_state.chat_history.append((user_input, ""))
        # BUG FIX: exclude the placeholder pair we just appended — passing
        # the full history made the current prompt appear twice in the
        # model's chat template (once with an empty assistant reply).
        responses = generate_response(
            user_input, st.session_state.chat_history[:-1], max_tokens
        )
        # Stream output: keep only the latest (longest) partial response.
        with st.spinner("Generating response..."):
            final_response = ""
            for response in responses:
                final_response = response
            st.session_state.chat_history[-1] = (user_input, final_response)
        st.experimental_rerun()
# Display Chat History
if "chat_history" in st.session_state:
for i, (user_msg, bot_response) in enumerate(st.session_state.chat_history):
st.write(f"**User {i+1}:** {user_msg}")
st.write(f"**Bot:** {bot_response}")