berito committed on
Commit
0c03e17
·
verified ·
1 Parent(s): 55302f2

app edited

Browse files
Files changed (1) hide show
  1. app.py +93 -0
app.py CHANGED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, pipeline
3
+ from threading import Thread
4
+
5
+ # Model Initialization
6
+ model_id = "rasyosef/Llama-3.2-180M-Amharic-Instruct"
7
+
8
+ st.title("Llama 3.2 180M Amharic Chatbot Demo")
9
+ st.write("""
10
+ This chatbot was created using [Llama-3.2-180M-Amharic-Instruct](https://huggingface.co/rasyosef/Llama-3.2-180M-Amharic-Instruct),
11
+ a finetuned version of the 180 million parameter Llama 3.2 Amharic transformer model.
12
+ """)
13
+
14
+ # Load the tokenizer and model
15
+ @st.cache_resource
16
+ def load_model():
17
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
18
+ model = AutoModelForCausalLM.from_pretrained(model_id)
19
+ llama_pipeline = pipeline(
20
+ "text-generation",
21
+ model=model,
22
+ tokenizer=tokenizer,
23
+ pad_token_id=tokenizer.pad_token_id,
24
+ eos_token_id=tokenizer.eos_token_id
25
+ )
26
+ return tokenizer, llama_pipeline
27
+
28
+ tokenizer, llama_pipeline = load_model()
29
+
30
+ # Generate text
31
+ def generate_response(prompt, chat_history, max_new_tokens):
32
+ history = []
33
+
34
+ # Build chat history
35
+ for sent, received in chat_history:
36
+ history.append({"role": "user", "content": sent})
37
+ history.append({"role": "assistant", "content": received})
38
+
39
+ history.append({"role": "user", "content": prompt})
40
+
41
+ if len(tokenizer.apply_chat_template(history)) > 512:
42
+ return "Chat history is too long."
43
+ else:
44
+ streamer = TextIteratorStreamer(
45
+ tokenizer=tokenizer,
46
+ skip_prompt=True,
47
+ skip_special_tokens=True,
48
+ timeout=300.0
49
+ )
50
+ thread = Thread(target=llama_pipeline, kwargs={
51
+ "text_inputs": history,
52
+ "max_new_tokens": max_new_tokens,
53
+ "repetition_penalty": 1.15,
54
+ "streamer": streamer
55
+ })
56
+ thread.start()
57
+
58
+ generated_text = ""
59
+ for word in streamer:
60
+ generated_text += word
61
+ response = generated_text.strip()
62
+ yield response
63
+
64
+ # Streamlit Input and Chat Interface
65
+ st.sidebar.header("Chatbot Configuration")
66
+ max_tokens = st.sidebar.slider("Maximum new tokens", min_value=8, max_value=256, value=64, help="Larger values result in longer responses.")
67
+
68
+ st.subheader("Chat with the Amharic Chatbot")
69
+ chat_history = st.session_state.get("chat_history", [])
70
+
71
+ # User Input
72
+ user_input = st.text_input("Your message:", placeholder="Type your message here...")
73
+
74
+ if st.button("Send"):
75
+ if user_input:
76
+ st.session_state.chat_history = st.session_state.get("chat_history", [])
77
+ st.session_state.chat_history.append((user_input, ""))
78
+ responses = generate_response(user_input, st.session_state.chat_history, max_tokens)
79
+
80
+ # Stream output
81
+ with st.spinner("Generating response..."):
82
+ final_response = ""
83
+ for response in responses:
84
+ final_response = response
85
+ st.session_state.chat_history[-1] = (user_input, final_response)
86
+ st.experimental_rerun()
87
+
88
+ # Display Chat History
89
+ if "chat_history" in st.session_state:
90
+ for i, (user_msg, bot_response) in enumerate(st.session_state.chat_history):
91
+ st.write(f"**User {i+1}:** {user_msg}")
92
+ st.write(f"**Bot:** {bot_response}")
93
+