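# Streamlit chat UI for the Hugging Face model "joermd/llma-speedy":
# loads the model once, keeps the chat history in session state, and
# renders replies with a simple typing effect.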
import streamlit as st
import time
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

class LLAMAChatbot:
    def __init__(self):
        st.title("LLAMA Chatbot")
        self.initialize_model()
        self.initialize_session_state()
        
    def initialize_model(self):
        """Initialize the LLAMA model and tokenizer"""
        try:
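            # st.cache_resource caches the loaded model and tokenizer across
            # Streamlit reruns, so the weights are not reloaded on every message.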
            @st.cache_resource
            def load_model():
                tokenizer = AutoTokenizer.from_pretrained("joermd/llma-speedy")
                model = AutoModelForCausalLM.from_pretrained(
                    "joermd/llma-speedy",
                    torch_dtype=torch.float16,
                    device_map="auto"
                )
                return model, tokenizer
            
            self.model, self.tokenizer = load_model()
            st.success("Model loaded successfully!")
        except Exception as e:
            st.error(f"An error occurred while loading the model: {str(e)}")
            st.stop()
            
    def initialize_session_state(self):
        """Initialize chat history if it doesn't exist"""
        if "messages" not in st.session_state:
            st.session_state.messages = []
            
    def display_chat_history(self):
        """Display all messages from chat history"""
        for message in st.session_state.messages:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])
                
    def add_message(self, role, content):
        """Add a message to the chat history"""
        st.session_state.messages.append({
            "role": role,
            "content": content
        })
        
    def generate_response(self, user_input, max_new_tokens=512):
        """Generate response using LLAMA model"""
        try:
            # Build the prompt from recent chat history. The current user
            # message was already appended in run(), so exclude it here to
            # avoid repeating it, and keep the 4 messages before it as context.
            context = ""
            for message in st.session_state.messages[:-1][-4:]:
                if message["role"] == "user":
                    context += f"Human: {message['content']}\n"
                else:
                    context += f"Assistant: {message['content']}\n"
            
            context += f"Human: {user_input}\nAssistant:"
            
            # Tokenize input
            inputs = self.tokenizer(context, return_tensors="pt", truncation=True)
            inputs = inputs.to(self.model.device)
            
            # Generate the response. Passing **inputs supplies the attention
            # mask along with the input IDs, and max_new_tokens bounds only
            # the generated text rather than prompt plus reply combined.
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    num_return_sequences=1,
                    temperature=0.7,
                    top_p=0.9,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id
                )
            
            # Decode only the newly generated tokens, skipping the prompt,
            # and cut off any follow-on "Human:" turn the model invents.
            new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
            response = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
            response = response.split("Human:")[0].strip()
            
            return response
            
        except Exception as e:
            return f"عذراً، حدث خطأ أثناء توليد الإجابة: {str(e)}"
            
    def simulate_typing(self, message_placeholder, response):
        """Simulate typing effect for bot response"""
        full_response = ""
        for chunk in response.split():
            full_response += chunk + " "
            time.sleep(0.05)
            message_placeholder.markdown(full_response + "▌")
        message_placeholder.markdown(full_response)
        return full_response
        
    def run(self):
        """Main application loop"""
        # Display existing chat history
        self.display_chat_history()
        
        # Handle user input
        if user_input := st.chat_input("Type your message here..."):
            # Display and save user message
            self.add_message("user", user_input)
            with st.chat_message("user"):
                st.markdown(user_input)
            
            # Generate and display response
            with st.chat_message("assistant"):
                message_placeholder = st.empty()
                with st.spinner("Thinking..."):
                    assistant_response = self.generate_response(user_input)
                full_response = self.simulate_typing(message_placeholder, assistant_response)
                self.add_message("assistant", full_response)

if __name__ == "__main__":
    # Set page config
    st.set_page_config(
        page_title="LLAMA Chatbot",
        page_icon="🤖",
        layout="wide"
    )
    
    # Initialize and run the chatbot
    chatbot = LLAMAChatbot()
    chatbot.run()
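
# Example setup and launch (a sketch; the filename app.py is assumed, not
# given in the source):
#   pip install streamlit transformers torch accelerate   # accelerate backs device_map="auto"
#   streamlit run app.py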