import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, pipeline
from threading import Thread

# Model Initialization
model_id = "rasyosef/Llama-3.2-180M-Amharic-Instruct"

st.title("Llama 3.2 180M Amharic Chatbot Demo")
st.write("""
This chatbot was created using [Llama-3.2-180M-Amharic-Instruct](https://huggingface.co/rasyosef/Llama-3.2-180M-Amharic-Instruct), 
a fine-tuned version of the 180-million-parameter Llama 3.2 Amharic transformer model.
""")

# Load the tokenizer and model
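# @st.cache_resource keeps a single tokenizer/model instance alive across
# Streamlit reruns, so the weights are downloaded and loaded only once per process.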
@st.cache_resource
def load_model():
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)
    llama_pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id
    )
    return tokenizer, llama_pipeline

tokenizer, llama_pipeline = load_model()

# Generate text
def generate_response(prompt, chat_history, max_new_tokens):
    history = []

    # Build chat history
    for sent, received in chat_history:
        history.append({"role": "user", "content": sent})
        history.append({"role": "assistant", "content": received})

    history.append({"role": "user", "content": prompt})

    if len(tokenizer.apply_chat_template(history)) > 512:
        # Yield rather than return: inside a generator function, a returned
        # string never reaches the caller.
        yield "Chat history is too long."
    else:
        streamer = TextIteratorStreamer(
            tokenizer=tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
            timeout=300.0
        )
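        # Generate in a background thread; TextIteratorStreamer hands the
        # decoded tokens to the main thread as they are produced.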
        thread = Thread(target=llama_pipeline, kwargs={
            "text_inputs": history,
            "max_new_tokens": max_new_tokens,
            "repetition_penalty": 1.15,
            "streamer": streamer
        })
        thread.start()

        generated_text = ""
        for new_text in streamer:
            generated_text += new_text
            yield generated_text.strip()
        thread.join()  # generation is finished once the streamer is exhausted

# Streamlit Input and Chat Interface
st.sidebar.header("Chatbot Configuration")
max_tokens = st.sidebar.slider("Maximum new tokens", min_value=8, max_value=256, value=64, help="Larger values result in longer responses.")

st.subheader("Chat with the Amharic Chatbot")
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

# User Input
user_input = st.text_input("Your message:", placeholder="Type your message here...")

if st.button("Send"):
    if user_input:
        # Snapshot the prior turns before appending the placeholder entry:
        # generate_response appends the current prompt itself, so it must
        # not also see it inside the history.
        prior_turns = list(st.session_state.chat_history)
        st.session_state.chat_history.append((user_input, ""))
        responses = generate_response(user_input, prior_turns, max_tokens)

        # Stream output into a placeholder as tokens arrive; rerunning the
        # script mid-stream would abort generation after the first token.
        with st.spinner("Generating response..."):
            placeholder = st.empty()
            final_response = ""
            for response in responses:
                final_response = response
                placeholder.markdown(final_response)
            st.session_state.chat_history[-1] = (user_input, final_response)

# Display Chat History
if "chat_history" in st.session_state:
    for i, (user_msg, bot_response) in enumerate(st.session_state.chat_history):
        st.write(f"**User {i+1}:** {user_msg}")
        st.write(f"**Bot:** {bot_response}")
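
# Launch the demo locally with (assuming this file is saved as app.py):
#   streamlit run app.py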