Nixic committed · verified
Commit dd36385 · 1 parent: 6298a59

Update app.py

Files changed (1): app.py (+219, -69)
app.py CHANGED
@@ -1,28 +1,172 @@
 import streamlit as st
 import requests
 import logging
+import time
+from typing import Dict, Any, Optional, List
+import os
+from PIL import Image
+import pytesseract
+import fitz  # PyMuPDF
+from io import BytesIO
+import hashlib
+from sentence_transformers import SentenceTransformer
+import numpy as np
+from pathlib import Path
+import pickle
+import tempfile
 
 # Configure logging
-logging.basicConfig(level=logging.INFO)
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+)
 logger = logging.getLogger(__name__)
 
+# Initialize SBERT model for embeddings
+@st.cache_resource
+def load_embedding_model():
+    return SentenceTransformer('all-MiniLM-L6-v2')
+
+# Vector store class
+class SimpleVectorStore:
+    def __init__(self, file_path: str = "vector_store.pkl"):
+        self.file_path = file_path
+        self.documents = []
+        self.embeddings = []
+        self.load()
+
+    def load(self):
+        if os.path.exists(self.file_path):
+            with open(self.file_path, 'rb') as f:
+                data = pickle.load(f)
+                self.documents = data['documents']
+                self.embeddings = data['embeddings']
+
+    def save(self):
+        with open(self.file_path, 'wb') as f:
+            pickle.dump({
+                'documents': self.documents,
+                'embeddings': self.embeddings
+            }, f)
+
+    def add_document(self, text: str, embedding: np.ndarray):
+        self.documents.append(text)
+        self.embeddings.append(embedding)
+        self.save()
+
+    def search(self, query_embedding: np.ndarray, top_k: int = 3) -> List[str]:
+        if not self.embeddings:
+            return []
+
+        similarities = np.dot(self.embeddings, query_embedding)
+        top_indices = np.argsort(similarities)[-top_k:][::-1]
+        return [self.documents[i] for i in top_indices]
+
+# Document processing functions
+def process_text(text: str) -> List[str]:
+    """Split text into chunks."""
+    # Simple splitting by sentences (can be improved with better chunking)
+    chunks = text.split('. ')
+    return [chunk + '.' for chunk in chunks if chunk]
+
+def process_image(image) -> str:
+    """Extract text from image using OCR."""
+    try:
+        text = pytesseract.image_to_string(image)
+        return text
+    except Exception as e:
+        logger.error(f"Error processing image: {str(e)}")
+        return ""
+
+def process_pdf(pdf_file) -> str:
+    """Extract text from PDF."""
+    try:
+        with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
+            tmp_file.write(pdf_file.read())
+            tmp_file.flush()
+
+            doc = fitz.open(tmp_file.name)
+            text = ""
+            for page in doc:
+                text += page.get_text()
+            doc.close()
+            os.unlink(tmp_file.name)
+            return text
+    except Exception as e:
+        logger.error(f"Error processing PDF: {str(e)}")
+        return ""
+
+# Initialize session state
+if "messages" not in st.session_state:
+    st.session_state.messages = []
+if "request_timestamps" not in st.session_state:
+    st.session_state.request_timestamps = []
+if "vector_store" not in st.session_state:
+    st.session_state.vector_store = SimpleVectorStore()
+
+# Rate limiting configuration
+RATE_LIMIT_PERIOD = 60
+MAX_REQUESTS_PER_PERIOD = 30
+
+def check_rate_limit() -> bool:
+    """Check if we're within rate limits."""
+    current_time = time.time()
+    st.session_state.request_timestamps = [
+        ts for ts in st.session_state.request_timestamps
+        if current_time - ts < RATE_LIMIT_PERIOD
+    ]
+
+    if len(st.session_state.request_timestamps) >= MAX_REQUESTS_PER_PERIOD:
+        return False
+
+    st.session_state.request_timestamps.append(current_time)
+    return True
+
+def query(payload: Dict[str, Any], api_url: str) -> Optional[Dict[str, Any]]:
+    """Query the Hugging Face API with error handling and rate limiting."""
+    if not check_rate_limit():
+        raise Exception(f"Rate limit exceeded. Please wait {RATE_LIMIT_PERIOD} seconds.")
+
+    try:
+        headers = {"Authorization": f"Bearer {st.secrets['HF_TOKEN']}"}
+        response = requests.post(api_url, headers=headers, json=payload, timeout=30)
+
+        if response.status_code == 429:
+            raise Exception("Too many requests. Please try again later.")
+
+        response.raise_for_status()
+        return response.json()
+    except Exception as e:
+        logger.error(f"API request failed: {str(e)}")
+        raise
+
+def process_response(response: Dict[str, Any]) -> str:
+    """Process and clean up the model's response."""
+    if not isinstance(response, list) or not response:
+        raise ValueError("Invalid response format")
+
+    text = response[0]['generated_text'].strip()
+    cleanup_patterns = [
+        "Assistant:", "AI:", "</think>", "<think>",
+        "\n\nHuman:", "\n\nUser:"
+    ]
+    for pattern in cleanup_patterns:
+        text = text.replace(pattern, "").strip()
+
+    return text
+
 # Page configuration
 st.set_page_config(
-    page_title="DeepSeek Chatbot",
+    page_title="RAG-Enabled DeepSeek Chatbot",
     page_icon="🤖",
-    layout="centered"
+    layout="wide"
 )
 
-# Initialize session state for chat history
-if "messages" not in st.session_state:
-    st.session_state.messages = []
-
 # Sidebar configuration
 with st.sidebar:
     st.header("Model Configuration")
     st.markdown("[Get HuggingFace Token](https://huggingface.co/settings/tokens)")
 
-    # Dropdown to select model
     model_options = [
         "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
     ]
@@ -30,58 +174,78 @@ with st.sidebar:
 
     system_message = st.text_area(
         "System Message",
-        value="You are a friendly chatbot created by ruslanmv.com. Provide clear, accurate, and brief answers. Keep responses polite, engaging, and to the point. If unsure, politely suggest alternatives.",
+        value="You are a friendly chatbot with RAG capabilities. Use the provided context to answer questions accurately. If the context doesn't contain relevant information, say so.",
         height=100
     )
 
-    max_tokens = st.slider(
-        "Max Tokens",
-        10, 4000, 100
-    )
-
-    temperature = st.slider(
-        "Temperature",
-        0.1, 4.0, 0.3
-    )
-
-    top_p = st.slider(
-        "Top-p",
-        0.1, 1.0, 0.6
-    )
+    max_tokens = st.slider("Max Tokens", 10, 4000, 100)
+    temperature = st.slider("Temperature", 0.1, 4.0, 0.3)
+    top_p = st.slider("Top-p", 0.1, 1.0, 0.6)
 
-# Function to query the Hugging Face API
-def query(payload, api_url):
-    headers = {"Authorization": f"Bearer {st.secrets['HF_TOKEN']}"}
-    logger.info(f"Sending request to {api_url} with payload: {payload}")
-    response = requests.post(api_url, headers=headers, json=payload)
-    logger.info(f"Received response: {response.status_code}, {response.text}")
-    try:
-        return response.json()
-    except requests.exceptions.JSONDecodeError:
-        logger.error(f"Failed to decode JSON response: {response.text}")
-        return None
+    # File upload section
+    st.header("Upload Knowledge Base")
+    uploaded_files = st.file_uploader(
+        "Upload files (PDF, Images, Text)",
+        type=['pdf', 'png', 'jpg', 'jpeg', 'txt'],
+        accept_multiple_files=True
+    )
 
-# Chat interface
-st.title("🤖 DeepSeek Chatbot")
-st.caption("Powered by Hugging Face Inference API - Configure in sidebar")
+    # Process uploaded files
+    if uploaded_files:
+        embedding_model = load_embedding_model()
+
+        for file in uploaded_files:
+            try:
+                if file.type == "application/pdf":
+                    text = process_pdf(file)
+                elif file.type.startswith("image/"):
+                    image = Image.open(file)
+                    text = process_image(image)
+                else:  # text files
+                    text = file.getvalue().decode()
+
+                chunks = process_text(text)
+                for chunk in chunks:
+                    embedding = embedding_model.encode(chunk)
+                    st.session_state.vector_store.add_document(chunk, embedding)
+
+                st.sidebar.success(f"Successfully processed {file.name}")
+            except Exception as e:
+                st.sidebar.error(f"Error processing {file.name}: {str(e)}")
+
+# Main chat interface
+st.title("🤖 RAG-Enabled DeepSeek Chatbot")
+st.caption("Upload documents in the sidebar to enhance the chatbot's knowledge")
 
 # Display chat history
 for message in st.session_state.messages:
     with st.chat_message(message["role"]):
         st.markdown(message["content"])
 
-# Handle input
+# Handle user input
 if prompt := st.chat_input("Type your message..."):
+    # Display user message
     st.session_state.messages.append({"role": "user", "content": prompt})
-
     with st.chat_message("user"):
         st.markdown(prompt)
 
     try:
         with st.spinner("Generating response..."):
-            # Prepare the payload for the API
-            # Combine system message and user input into a single prompt
-            full_prompt = f"{system_message}\n\nUser: {prompt}\nAssistant:"
+            # Get relevant context from vector store
+            embedding_model = load_embedding_model()
+            query_embedding = embedding_model.encode(prompt)
+            relevant_contexts = st.session_state.vector_store.search(query_embedding)
+
+            # Prepare context-enhanced prompt
+            context_text = "\n".join(relevant_contexts)
+            full_prompt = f"""Context information:
+{context_text}
+
+System: {system_message}
+
+User: {prompt}
+Assistant: Let me help you based on the provided context."""
+
             payload = {
                 "inputs": full_prompt,
                 "parameters": {
@@ -92,37 +256,23 @@ if prompt := st.chat_input("Type your message..."):
                 }
             }
 
-            # Dynamically construct the API URL based on the selected model
             api_url = f"https://api-inference.huggingface.co/models/{selected_model}"
-            logger.info(f"Selected model: {selected_model}, API URL: {api_url}")
-
-            # Query the Hugging Face API using the selected model
+
+            # Get and process response
             output = query(payload, api_url)
-
-            # Handle API response
-            if output is not None and isinstance(output, list) and len(output) > 0:
-                if 'generated_text' in output[0]:
-                    # Extract the assistant's response
-                    assistant_response = output[0]['generated_text'].strip()
-
-                    # Check for and remove duplicate responses
-                    responses = assistant_response.split("\n</think>\n")
-                    unique_response = responses[0].strip()
-
-                    logger.info(f"Generated response: {unique_response}")
-
-                    # Append response to chat only once
-                    with st.chat_message("assistant"):
-                        st.markdown(unique_response)
-
-                    st.session_state.messages.append({"role": "assistant", "content": unique_response})
-                else:
-                    logger.error(f"Unexpected API response structure: {output}")
-                    st.error("Error: Unexpected response from the model. Please try again.")
-            else:
-                logger.error(f"Empty or invalid API response: {output}")
-                st.error("Error: Unable to generate a response. Please check the model and try again.")
+            if output:
+                response_text = process_response(output)
+
+                # Display assistant response
+                with st.chat_message("assistant"):
+                    st.markdown(response_text)
+
+                # Update chat history
+                st.session_state.messages.append({
+                    "role": "assistant",
+                    "content": response_text
+                })
 
     except Exception as e:
-        logger.error(f"Application Error: {str(e)}", exc_info=True)
-        st.error(f"Application Error: {str(e)}")
+        logger.error(f"Error: {str(e)}", exc_info=True)
+        st.error(f"Error: {str(e)}")
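For reference, a minimal standalone sketch of the retrieval pieces this commit adds (load_embedding_model plus SimpleVectorStore.search), runnable outside Streamlit. It assumes sentence-transformers and numpy are installed; the documents and query are made up. np.dot on raw MiniLM embeddings ranks by unnormalized inner product; passing normalize_embeddings=True to encode, as below, makes the same dot product a cosine similarity.

# Standalone sketch (assumes: pip install sentence-transformers numpy)
import numpy as np
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('all-MiniLM-L6-v2')

docs = [
    "The HF_TOKEN secret authorizes Hugging Face API calls.",
    "Uploaded PDFs are converted to text with PyMuPDF.",
    "Rate limiting allows 30 requests per 60 seconds.",
]
# normalize_embeddings=True makes the dot product a cosine similarity
doc_vecs = model.encode(docs, normalize_embeddings=True)

query_vec = model.encode("How many requests are allowed?", normalize_embeddings=True)
scores = np.dot(doc_vecs, query_vec)    # one score per document, as in search()
top = np.argsort(scores)[-2:][::-1]     # top-2 indices, best first
for i in top:
    print(f"{scores[i]:.3f}  {docs[i]}")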
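process_text splits on '. ', which breaks on abbreviations and yields one-sentence chunks. A sketch of an alternative chunker with word-window overlap; the function name and window sizes are illustrative, not part of the commit:

from typing import List

def chunk_words(text: str, chunk_size: int = 120, overlap: int = 20) -> List[str]:
    """Split text into overlapping word windows (sizes are illustrative)."""
    words = text.split()
    step = chunk_size - overlap
    chunks = []
    for start in range(0, len(words), step):
        window = words[start:start + chunk_size]
        if window:
            chunks.append(" ".join(window))
    return chunks

print(len(chunk_words("word " * 300)))  # 3 windows of up to 120 words each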
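The sliding-window limiter in check_rate_limit is easy to verify in isolation. The same logic with a module-level list standing in for st.session_state.request_timestamps:

import time

RATE_LIMIT_PERIOD = 60          # seconds, same constant as the commit
MAX_REQUESTS_PER_PERIOD = 30

timestamps = []                 # stands in for st.session_state.request_timestamps

def check_rate_limit() -> bool:
    """Sliding window: drop timestamps older than the period, then count."""
    now = time.time()
    timestamps[:] = [ts for ts in timestamps if now - ts < RATE_LIMIT_PERIOD]
    if len(timestamps) >= MAX_REQUESTS_PER_PERIOD:
        return False
    timestamps.append(now)
    return True

allowed = sum(check_rate_limit() for _ in range(35))
print(allowed)  # 30: the 5 extra calls inside the same window are refused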
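The diff leaves the body of the parameters block outside the changed hunks, so the values below are illustrative only. The serverless Inference API's text-generation task accepts fields such as max_new_tokens, temperature, top_p, and return_full_text; a minimal sketch of the same request outside Streamlit, reading the token from an environment variable (an assumption; the app itself uses st.secrets['HF_TOKEN']):

import os
import requests

API_URL = "https://api-inference.huggingface.co/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
# Reading HF_TOKEN from the environment is an assumption for this sketch.
headers = {"Authorization": f"Bearer {os.environ['HF_TOKEN']}"}

payload = {
    "inputs": "User: Hello\nAssistant:",
    "parameters": {
        # Illustrative values; the commit's actual parameters block is
        # unchanged context that the diff does not show.
        "max_new_tokens": 100,
        "temperature": 0.3,
        "top_p": 0.6,
        "return_full_text": False,
    },
}

resp = requests.post(API_URL, headers=headers, json=payload, timeout=30)
resp.raise_for_status()
print(resp.json()[0]["generated_text"])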
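One caveat worth noting about process_response: it strips the <think> and </think> markers but not the reasoning text between them, which matters for a DeepSeek-R1 model. A quick check on a fabricated response shape:

# Fabricated response shape, mirroring the commit's cleanup loop.
response = [{"generated_text": "<think>plan</think>Assistant: The limit is 30 requests."}]

text = response[0]["generated_text"].strip()
for pattern in ["Assistant:", "AI:", "</think>", "<think>", "\n\nHuman:", "\n\nUser:"]:
    text = text.replace(pattern, "").strip()

print(text)  # "plan The limit is 30 requests." (the reasoning body survives)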