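"""Gradio chat interface for retrieval-augmented question answering.

Answers questions over a local Chroma vector store using GGUF Llama models
served through llama-cpp-python.

Assumed (not pinned here) dependencies: gradio, llama-cpp-python, langchain,
langchain-community, chromadb, sentence-transformers, huggingface-hub.
"""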
import os
import gradio as gr
from llama_cpp import Llama
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.prompts import PromptTemplate

class RAGInterface:
    def __init__(self):
        # Initialize embedding model
        self.embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2",
            model_kwargs={'device': 'cpu'},
            encode_kwargs={'normalize_embeddings': True}
        )
        
        # Load vector store
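        # The 'mydb' directory is assumed to hold a pre-built Chroma index;
        # this script only queries it and does not ingest documents.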
        persist_directory = os.path.join(os.path.dirname(__file__), 'mydb')
        self.vectorstore = Chroma(
            persist_directory=persist_directory,
            embedding_function=self.embeddings
        )
        
        # Model configurations
        self.model_configs = {
            "Llama 3.2 3B (Fast, Less Accurate)": {
                "repo_id": "bartowski/Llama-3.2-3B-Instruct-GGUF",
                "filename": "Llama-3.2-3B-Instruct-Q6_K.gguf",
            },
            "Llama 3.1 8B (Slower, More Accurate)": {
                "repo_id": "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
                "filename": "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf",
            }
        }
        
        # Initialize with default model
        self.current_model = "Llama 3.1 8B (Slower, More Accurate)"
        self.load_model(self.current_model)
        
        # Define RAG prompt template (kept flush-left so the literal
        # indentation of the triple-quoted string is not sent to the model)
        self.template = """Answer the question based only on the following context:
{context}

Question: {question}

Answer the question in a clear way. If you cannot find the answer in the context, just say "I don't have enough information to answer this question."

Make sure to:
1. Only use information from the provided context
2. If you're unsure, acknowledge it
"""
        self.prompt = PromptTemplate.from_template(self.template)

    def load_model(self, model_name):
        config = self.model_configs[model_name]
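        # Llama.from_pretrained downloads the GGUF file from the Hugging Face
        # Hub on first use and caches it locally. n_ctx=2048 caps combined
        # prompt + retrieved context + response tokens, so long histories or
        # large retrieved chunks may need a bigger window.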
        self.llm = Llama.from_pretrained(
            repo_id=config["repo_id"],
            filename=config["filename"],
            n_ctx=2048
        )
        self.current_model = model_name

    def respond(self, message, history, system_message, model_choice, temperature, max_tokens=2048):
        # Load new model if different from current
        if model_choice != self.current_model:
            self.load_model(model_choice)
        
        # Build messages list
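        # History arrives as (user, assistant) tuples (Gradio's tuple-format
        # chat history); skip empty slots such as an unanswered user turn.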
        messages = [{"role": "system", "content": system_message}]
        for user_msg, assistant_msg in history:
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})

        # Retrieve the top-5 most relevant chunks for the current question
        # (retriever.invoke is the current LangChain API; get_relevant_documents
        # is deprecated)
        retriever = self.vectorstore.as_retriever(search_kwargs={"k": 5})
        docs = retriever.invoke(message)
        context = "\n\n".join(doc.page_content for doc in docs)

        # Format prompt and add to messages
        final_prompt = self.prompt.format(context=context, question=message)
        messages.append({"role": "user", "content": final_prompt})

        # Generate response
        response = self.llm.create_chat_completion(
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
        )

        return response['choices'][0]['message']['content']

    def create_interface(self):
        # Raw CSS passed to gr.Blocks(css=...); Gradio injects it into the
        # page itself, so the string must not contain <style> tags.
        custom_css = """
            /* Global Styles */
            body, #root {
                font-family: Helvetica, Arial, sans-serif;
                background-color: #1a1a1a;
                color: #fafafa;
            }
            
            /* Header Styles */
            .app-header {
                background: linear-gradient(45deg, #1a1a1a 0%, #333333 100%);
                padding: 24px;
                border-radius: 8px;
                margin-bottom: 24px;
                text-align: center;
            }
            
            .app-title {
                font-size: 36px;
                margin: 0;
                color: #fafafa;
            }
            
            .app-subtitle {
                font-size: 18px;
                margin: 8px 0;
                color: #fafafa;
                opacity: 0.8;
            }
            
            /* Chat Container */
            .chat-container {
                background-color: #2a2a2a;
                border-radius: 8px;
                padding: 20px;
                margin-bottom: 20px;
            }
            
            /* Control Panel */
            .control-panel {
                background-color: #333;
                padding: 16px;
                border-radius: 8px;
                margin-top: 16px;
            }
            
            /* Gradio Component Overrides */
            .gr-button {
                background-color: #4a4a4a;
                color: #fff;
                border: none;
                border-radius: 4px;
                padding: 8px 16px;
                transition: background-color 0.3s;
            }
            
            .gr-button:hover {
                background-color: #5a5a5a;
            }
            
            .gr-input, .gr-dropdown {
                background-color: #3a3a3a;
                color: #fff;
                border: 1px solid #4a4a4a;
                border-radius: 4px;
                padding: 8px;
            }
        """

        # Header HTML (the CSS itself is applied via gr.Blocks(css=...) below)
        header_html = """
        <div class="app-header">
            <h1 class="app-title">Document-Based Question Answering</h1>
            <h2 class="app-subtitle">Powered by Llama and RAG</h2>
        </div>
        """

        # Create Gradio interface
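        # additional_inputs are passed positionally to respond() after
        # (message, history): system_message, model_choice, temperature.
        # max_tokens keeps its default since no fourth input is exposed.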
        demo = gr.ChatInterface(
            fn=self.respond,
            additional_inputs=[
                gr.Textbox(
                    value="You are a friendly chatbot.",
                    label="System Message",
                    elem_classes="control-panel"
                ),
                gr.Dropdown(
                    choices=list(self.model_configs.keys()),
                    value=self.current_model,
                    label="Select Model",
                    elem_classes="control-panel"
                ),
                gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature",
                    elem_classes="control-panel"
                ),
            ],
            title="",  # Title is handled in custom HTML
            description="Ask questions about computers and get AI-powered answers.",
            theme=gr.themes.Default(),
        )

        # Wrap the interface with custom CSS
        with gr.Blocks(css=custom_css) as wrapper:
            gr.HTML(header_html)
            demo.render()
        return wrapper

def main():
    interface = RAGInterface()
    demo = interface.create_interface()
    demo.launch(debug=True)

if __name__ == "__main__":
    main()
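
# Usage sketch (assumes the Chroma store in ./mydb already exists, built by a
# separate ingestion step): run this file with Python and open the local URL
# that Gradio prints, e.g. http://127.0.0.1:7860.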