import streamlit as st
from threading import Thread

from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, pipeline

# Model Initialization
model_id = "rasyosef/Llama-3.2-180M-Amharic-Instruct"

st.title("Llama 3.2 180M Amharic Chatbot Demo")
st.write("""
This chatbot was created using [Llama-3.2-180M-Amharic-Instruct](https://huggingface.co/rasyosef/Llama-3.2-180M-Amharic-Instruct), a finetuned version of the 180 million parameter Llama 3.2 Amharic transformer model.
""")


# Load the tokenizer and model (cached so the heavy download/initialization
# happens only once per server process, not on every rerun).
@st.cache_resource
def load_model():
    """Return the (tokenizer, text-generation pipeline) pair for the model."""
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)
    llama_pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    return tokenizer, llama_pipeline


tokenizer, llama_pipeline = load_model()


# Generate text
def generate_response(prompt, chat_history, max_new_tokens):
    """Yield progressively longer partial responses for *prompt*.

    Parameters:
        prompt: the current user message (must NOT already be in chat_history).
        chat_history: list of (user, assistant) pairs of *completed* turns.
        max_new_tokens: generation length cap passed to the pipeline.

    Yields:
        The stripped generated text so far, once per streamed token chunk,
        or a single "too long" notice when the prompt would overflow context.
    """
    history = []

    # Build chat history in the chat-template message format.
    for sent, received in chat_history:
        history.append({"role": "user", "content": sent})
        history.append({"role": "assistant", "content": received})
    history.append({"role": "user", "content": prompt})

    if len(tokenizer.apply_chat_template(history)) > 512:
        # BUG FIX: this function is a generator (it contains `yield`), so the
        # original `return "Chat history is too long."` silently produced an
        # empty iterator — the caller never saw the message. Yield it instead.
        yield "Chat history is too long."
        return

    streamer = TextIteratorStreamer(
        tokenizer=tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
        timeout=300.0,
    )

    # Run generation in a background thread; the streamer feeds us tokens
    # as they are produced so the UI can show incremental output.
    thread = Thread(
        target=llama_pipeline,
        kwargs={
            "text_inputs": history,
            "max_new_tokens": max_new_tokens,
            "repetition_penalty": 1.15,
            "streamer": streamer,
        },
    )
    thread.start()

    generated_text = ""
    for word in streamer:
        generated_text += word
        yield generated_text.strip()

    # BUG FIX: join the worker so the pipeline call has fully finished
    # before the generator is exhausted.
    thread.join()


# Streamlit Input and Chat Interface
st.sidebar.header("Chatbot Configuration")
max_tokens = st.sidebar.slider(
    "Maximum new tokens",
    min_value=8,
    max_value=256,
    value=64,
    help="Larger values result in longer responses.",
)

st.subheader("Chat with the Amharic Chatbot")
chat_history = st.session_state.get("chat_history", [])

# User Input
user_input = st.text_input("Your message:", placeholder="Type your message here...")

if st.button("Send"):
    if user_input:
        st.session_state.chat_history = st.session_state.get("chat_history", [])

        # BUG FIX: build the response from the *completed* turns only. The
        # original appended a placeholder (user_input, "") to the history
        # first and then passed that history to generate_response, which
        # appended the prompt again — duplicating the user message in the
        # chat template and inserting a bogus empty assistant turn.
        responses = generate_response(user_input, st.session_state.chat_history, max_tokens)

        # Stream output (keep only the latest, most complete partial).
        with st.spinner("Generating response..."):
            final_response = ""
            for response in responses:
                final_response = response

        st.session_state.chat_history.append((user_input, final_response))
        # BUG FIX: st.experimental_rerun() is deprecated/removed in current
        # Streamlit releases; st.rerun() is the supported replacement.
        st.rerun()

# Display Chat History
if "chat_history" in st.session_state:
    for i, (user_msg, bot_response) in enumerate(st.session_state.chat_history):
        st.write(f"**User {i+1}:** {user_msg}")
        st.write(f"**Bot:** {bot_response}")