app edited
app.py
CHANGED
@@ -0,0 +1,93 @@
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, pipeline
from threading import Thread

# Model Initialization
model_id = "rasyosef/Llama-3.2-180M-Amharic-Instruct"

st.title("Llama 3.2 180M Amharic Chatbot Demo")
st.write("""
This chatbot was created using [Llama-3.2-180M-Amharic-Instruct](https://huggingface.co/rasyosef/Llama-3.2-180M-Amharic-Instruct),
a fine-tuned version of the 180-million-parameter Llama 3.2 Amharic transformer model.
""")

# Load the tokenizer and model
@st.cache_resource
def load_model():
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)
    llama_pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id
    )
    return tokenizer, llama_pipeline

tokenizer, llama_pipeline = load_model()
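
# Note: st.cache_resource caches the returned tokenizer and pipeline across
# Streamlit reruns, so the model weights are downloaded and loaded into memory
# only once per server process rather than on every user interaction.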

# Generate text
def generate_response(prompt, chat_history, max_new_tokens):
    history = []

    # Rebuild prior turns in the role/content messages format. The caller
    # appends a placeholder (prompt, "") pair before invoking this generator,
    # so the last entry is skipped to avoid duplicating the prompt and
    # inserting an empty assistant turn.
    for sent, received in chat_history[:-1]:
        history.append({"role": "user", "content": sent})
        history.append({"role": "assistant", "content": received})

    history.append({"role": "user", "content": prompt})
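
    # tokenizer.apply_chat_template renders the messages into the model's
    # prompt format and (by default) returns token ids, so the check below
    # compares the full prompt length against a 512-token context budget.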
    if len(tokenizer.apply_chat_template(history)) > 512:
        # yield rather than return: inside a generator, a returned value is
        # swallowed as StopIteration, so the caller would never see the message.
        yield "Chat history is too long."
    else:
        streamer = TextIteratorStreamer(
            tokenizer=tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
            timeout=300.0
        )
        thread = Thread(target=llama_pipeline, kwargs={
            "text_inputs": history,
            "max_new_tokens": max_new_tokens,
            "repetition_penalty": 1.15,
            "streamer": streamer
        })
        thread.start()

        generated_text = ""
        for word in streamer:
            generated_text += word
            response = generated_text.strip()
            yield response
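
# Streaming mechanics: the pipeline call runs on a background Thread and pushes
# decoded text fragments into the TextIteratorStreamer's queue; iterating over
# the streamer on the main thread drains that queue as generation proceeds,
# letting generate_response yield progressively longer partial responses.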

# Streamlit Input and Chat Interface
st.sidebar.header("Chatbot Configuration")
max_tokens = st.sidebar.slider(
    "Maximum new tokens", min_value=8, max_value=256, value=64,
    help="Larger values result in longer responses."
)

st.subheader("Chat with the Amharic Chatbot")
chat_history = st.session_state.get("chat_history", [])

# User Input
user_input = st.text_input("Your message:", placeholder="Type your message here...")

if st.button("Send"):
    if user_input:
        st.session_state.chat_history = st.session_state.get("chat_history", [])
        st.session_state.chat_history.append((user_input, ""))
        responses = generate_response(user_input, st.session_state.chat_history, max_tokens)

        # Stream output
        with st.spinner("Generating response..."):
            final_response = ""
            for response in responses:
                final_response = response
            st.session_state.chat_history[-1] = (user_input, final_response)
        st.rerun()
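
# The (user_input, "") placeholder appended above reserves the last slot in
# the chat history; after streaming completes, that slot is overwritten with
# the final response and st.rerun() redraws the page so the updated history
# is rendered below.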

# Display Chat History
if "chat_history" in st.session_state:
    for i, (user_msg, bot_response) in enumerate(st.session_state.chat_history):
        st.write(f"**User {i+1}:** {user_msg}")
        st.write(f"**Bot:** {bot_response}")
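
# To run the app locally (assuming streamlit, transformers, and torch are
# installed): streamlit run app.py
#
# A minimal sanity check of the model without the UI (a sketch; requires a
# transformers version whose text-generation pipeline accepts chat messages,
# and the prompt text is illustrative):
#
#   from transformers import pipeline
#   pipe = pipeline("text-generation", model="rasyosef/Llama-3.2-180M-Amharic-Instruct")
#   print(pipe([{"role": "user", "content": "ሰላም"}], max_new_tokens=64))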