eikarna committed
Commit d9760ae · 1 Parent(s): 853734d

Fix: File Upload Session

Files changed (1):
  1. app.py +126 -243
app.py CHANGED
@@ -1,264 +1,147 @@
  import streamlit as st
  import requests
  import logging
- import time
- from typing import Dict, Any, Optional, List
- import os
- from PIL import Image
- import pytesseract
- import fitz  # PyMuPDF
- from io import BytesIO
- import hashlib
- from sentence_transformers import SentenceTransformer
- import numpy as np
- from pathlib import Path
- import pickle
- import tempfile

  # Configure logging
- logging.basicConfig(
-     level=logging.INFO,
-     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
- )
  logger = logging.getLogger(__name__)

- # Initialize SBERT model for embeddings
- @st.cache_resource
- def load_embedding_model():
-     return SentenceTransformer('all-MiniLM-L6-v2')
-
- # Modified Vector Store Class
- class SimpleVectorStore:
-     def __init__(self):
-         self.documents = []
-         self.embeddings = []
-
-     def add_document(self, text: str, embedding: np.ndarray):
-         self.documents.append(text)
-         self.embeddings.append(embedding)
-
-     def search(self, query_embedding: np.ndarray, top_k: int = 3) -> List[str]:
-         if not self.embeddings:
-             return []
-
-         similarities = np.dot(self.embeddings, query_embedding)
-         top_indices = np.argsort(similarities)[-top_k:][::-1]
-         return [self.documents[i] for i in top_indices]
-
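
Aside: the removed `search` ranks by raw dot product, which equals cosine similarity only when the embeddings are unit-normalized, and `SentenceTransformer.encode` does not normalize unless `normalize_embeddings=True` is passed. A normalized variant, as a minimal sketch (illustrative only, not part of this commit):

```python
import numpy as np

def cosine_search(doc_embeddings, query_embedding, top_k=3):
    # Normalize rows and query so the dot product becomes cosine similarity
    docs = np.asarray(doc_embeddings, dtype=float)
    docs /= np.linalg.norm(docs, axis=1, keepdims=True)
    q = query_embedding / np.linalg.norm(query_embedding)
    sims = docs @ q
    return np.argsort(sims)[-top_k:][::-1]  # indices, best match first
```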
- # Document processing functions
- def process_text(text: str) -> List[str]:
-     """Split text into chunks."""
-     # Simple splitting by sentences (can be improved with better chunking)
-     chunks = text.split('. ')
-     return [chunk + '.' for chunk in chunks if chunk]
-
- def process_image(image) -> str:
-     """Extract text from image using OCR."""
-     try:
-         text = pytesseract.image_to_string(image)
-         return text
-     except Exception as e:
-         logger.error(f"Error processing image: {str(e)}")
-         return ""
-
- def process_pdf(pdf_file) -> str:
-     """Extract text from PDF."""
-     try:
-         with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
-             tmp_file.write(pdf_file.read())
-             tmp_file.flush()
-
-             doc = fitz.open(tmp_file.name)
-             text = ""
-             for page in doc:
-                 text += page.get_text()
-             doc.close()
-             os.unlink(tmp_file.name)
-             return text
-     except Exception as e:
-         logger.error(f"Error processing PDF: {str(e)}")
-         return ""
-
- # Initialize session state
- if "messages" not in st.session_state:
-     st.session_state.messages = []
- if "request_timestamps" not in st.session_state:
-     st.session_state.request_timestamps = []
- if "vector_store" not in st.session_state:
-     st.session_state.vector_store = SimpleVectorStore()
-
- # Rate limiting configuration
- RATE_LIMIT_PERIOD = 60
- MAX_REQUESTS_PER_PERIOD = 30
-
- def check_rate_limit() -> bool:
-     """Check if we're within rate limits."""
-     current_time = time.time()
-     st.session_state.request_timestamps = [
-         ts for ts in st.session_state.request_timestamps
-         if current_time - ts < RATE_LIMIT_PERIOD
-     ]
-
-     if len(st.session_state.request_timestamps) >= MAX_REQUESTS_PER_PERIOD:
-         return False
-
-     st.session_state.request_timestamps.append(current_time)
-     return True
-
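
The removed limiter is a sliding-window counter kept in `st.session_state`: drop timestamps older than the window, refuse once the window is full. The same technique as a framework-free sketch (hypothetical helper, shown for clarity only):

```python
import time

class SlidingWindowLimiter:
    """Allow at most max_requests calls in any rolling period of seconds."""
    def __init__(self, max_requests: int = 30, period: float = 60.0):
        self.max_requests = max_requests
        self.period = period
        self.timestamps = []

    def allow(self) -> bool:
        now = time.time()
        # Keep only timestamps still inside the rolling window
        self.timestamps = [ts for ts in self.timestamps if now - ts < self.period]
        if len(self.timestamps) >= self.max_requests:
            return False
        self.timestamps.append(now)
        return True
```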
- def query(payload: Dict[str, Any], api_url: str) -> Optional[Dict[str, Any]]:
-     """Query the Hugging Face API with error handling and rate limiting."""
-     if not check_rate_limit():
-         raise Exception(f"Rate limit exceeded. Please wait {RATE_LIMIT_PERIOD} seconds.")
-
-     try:
-         headers = {"Authorization": f"Bearer {st.secrets['HF_TOKEN']}"}
-         response = requests.post(api_url, headers=headers, json=payload, timeout=30)
-
-         if response.status_code == 429:
-             raise Exception("Too many requests. Please try again later.")
-
-         response.raise_for_status()
-         print(response.request.url)
-         print(response.request.headers)
-         print(response.request.body)
-         print(response)
-         return response.json()
-     except requests.exceptions.JSONDecodeError as e:
-         logger.error(f"API request failed: {str(e)}")
-         raise
-
- # Enhanced response validation
- def process_response(response: Dict[str, Any]) -> str:
-     if not isinstance(response, list) or not response:
-         raise ValueError("Invalid response format")
-
-     if 'generated_text' not in response[0]:
-         raise ValueError("Unexpected response structure")
-
-     text = response[0]['generated_text'].strip()

  # Page configuration
  st.set_page_config(
-     page_title="RAG-Enabled DeepSeek Chatbot",
      page_icon="🤖",
-     layout="wide"
  )

- # Sidebar configuration
- with st.sidebar:
-     st.header("Model Configuration")
-     st.markdown("[Get HuggingFace Token](https://huggingface.co/settings/tokens)")
-
-     model_options = [
-         "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
-     ]
-     selected_model = st.selectbox("Select Model", model_options, index=0)
-
-     system_message = st.text_area(
-         "System Message",
-         value="You are a friendly chatbot with RAG capabilities. Use the provided context to answer questions accurately. If the context doesn't contain relevant information, say so.",
-         height=100
-     )
-
-     max_tokens = st.slider("Max Tokens", 10, 4000, 100)
-     temperature = st.slider("Temperature", 0.1, 4.0, 0.3)
-     top_p = st.slider("Top-p", 0.1, 1.0, 0.6)
-
-     # File upload section
-     st.header("Upload Knowledge Base")
-     uploaded_files = st.file_uploader(
-         "Upload files (PDF, Images, Text)",
-         type=['pdf', 'png', 'jpg', 'jpeg', 'txt'],
-         accept_multiple_files=True
-     )

- # Process uploaded files
- if uploaded_files:
-     embedding_model = load_embedding_model()

-     for file in uploaded_files:
-         try:
-             if file.type == "application/pdf":
-                 text = process_pdf(file)
-             elif file.type.startswith("image/"):
-                 image = Image.open(file)
-                 text = process_image(image)
-             else:  # text files
-                 text = file.getvalue().decode()
-
-             chunks = process_text(text)
-             for chunk in chunks:
-                 embedding = embedding_model.encode(chunk)
-                 st.session_state.vector_store.add_document(chunk, embedding)
-
-             st.sidebar.success(f"Successfully processed {file.name}")
-         except Exception as e:
-             st.sidebar.error(f"Error processing {file.name}: {str(e)}")
-
- # Main chat interface
- st.title("🤖 RAG-Enabled DeepSeek Chatbot")
- st.caption("Upload documents in the sidebar to enhance the chatbot's knowledge")
-
- # Display chat history
- for message in st.session_state.messages:
-     with st.chat_message(message["role"]):
-         st.markdown(message["content"])
-
- # Handle user input
- if prompt := st.chat_input("Type your message..."):
-     # Display user message
-     st.session_state.messages.append({"role": "user", "content": prompt})
-     with st.chat_message("user"):
-         st.markdown(prompt)
-
      try:
-         with st.spinner("Generating response..."):
-             embedding_model = load_embedding_model()
-             query_embedding = embedding_model.encode(prompt)
-             relevant_contexts = st.session_state.vector_store.search(query_embedding)
-
-             # Dynamic context handling
-             context_text = "\n".join(relevant_contexts) if relevant_contexts else ""
-             system_msg = (
-                 f"{system_message} Use the provided context to answer accurately."
-                 if context_text
-                 else system_message
-             )
-
-             # Format for DeepSeek model
-             full_prompt = f"""<|beginofutterance|>System: {system_msg}
- {context_text if context_text else ''}
- <|endofutterance|>
- <|beginofutterance|>User: {prompt}<|endofutterance|>
- <|beginofutterance|>Assistant:"""
-
-             payload = {
-                 "inputs": full_prompt,
-                 "parameters": {
-                     "max_new_tokens": max_tokens,
-                     "temperature": temperature,
-                     "top_p": top_p,
-                     "return_full_text": False
                  }
-             }
-
-             api_url = f"https://api-inference.huggingface.co/models/{selected_model}"
-
-             # Get and process response
-             output = query(payload, api_url)
-             if output:
-                 response_text = process_response(output)
-
-                 # Display assistant response
-                 with st.chat_message("assistant"):
-                     st.markdown(response_text)
-
-                 # Update chat history
-                 st.session_state.messages.append({
-                     "role": "assistant",
-                     "content": response_text
-                 })
-
-     except Exception as e:
-         logger.error(f"Error: {str(e)}", exc_info=True)
-         st.error(f"Error: {str(e)}")
 
  import streamlit as st
  import requests
  import logging
+ from typing import Optional, Dict, Any

  # Configure logging
+ logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)

+ # Constants
+ DEFAULT_SYSTEM_PROMPT = """You are a friendly Assistant. Provide clear, accurate, and brief answers.
+ Keep responses polite, engaging, and to the point. If unsure, politely suggest alternatives."""

+ MODEL_OPTIONS = ["deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"]
+ API_BASE_URL = "https://api-inference.huggingface.co/models/"

  # Page configuration
  st.set_page_config(
+     page_title="DeepSeek-AI R1 (32B)",
      page_icon="🤖",
+     layout="centered"
  )

+ def initialize_session_state():
+     """Initialize all session state variables"""
+     if "messages" not in st.session_state:
+         st.session_state.messages = []
+     if "api_failures" not in st.session_state:
+         st.session_state.api_failures = 0
+
+ def configure_sidebar() -> Dict[str, Any]:
+     """Create sidebar components and return settings"""
+     with st.sidebar:
+         st.header("Model Configuration")
+         st.markdown("[Get HuggingFace Token](https://huggingface.co/settings/tokens)")
+
+         return {
+             "model": st.selectbox("Select Model", MODEL_OPTIONS, index=0),
+             "system_message": st.text_area(
+                 "System Message",
+                 value=DEFAULT_SYSTEM_PROMPT,
+                 height=100
+             ),
+             "max_tokens": st.slider("Max Tokens", 10, 4000, 100),
+             "temperature": st.slider("Temperature", 0.1, 4.0, 0.3),
+             "top_p": st.slider("Top-p", 0.1, 1.0, 0.6)
+         }
+
+ def format_deepseek_prompt(system_message: str, user_input: str) -> str:
+     """Format the prompt according to DeepSeek's required structure"""
+     return f"""<|beginofutterance|>System: {system_message}
+ <|endofutterance|>
+ <|beginofutterance|>User: {user_input}<|endofutterance|>
+ <|beginofutterance|>Assistant:"""
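
For reference, with the default system prompt and the user input "Hello", `format_deepseek_prompt` returns a string of the form:

```
<|beginofutterance|>System: You are a friendly Assistant. Provide clear, accurate, and brief answers.
Keep responses polite, engaging, and to the point. If unsure, politely suggest alternatives.
<|endofutterance|>
<|beginofutterance|>User: Hello<|endofutterance|>
<|beginofutterance|>Assistant:
```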
 
+ def query_hf_api(payload: Dict[str, Any], api_url: str) -> Optional[Dict[str, Any]]:
+     """Handle API requests with improved error handling"""
+     headers = {"Authorization": f"Bearer {st.secrets['HF_TOKEN']}"}

      try:
+         response = requests.post(
+             api_url,
+             headers=headers,
+             json=payload,
+             timeout=30
+         )
+         response.raise_for_status()
+         return response.json()
+     except requests.exceptions.HTTPError as e:
+         logger.error(f"HTTP Error: {e.response.status_code} - {e.response.text}")
+         st.error(f"API Error: {e.response.status_code} - {e.response.text[:200]}")
+     except requests.exceptions.RequestException as e:
+         logger.error(f"Request failed: {str(e)}")
+         st.error("Connection error. Please check your internet connection.")
+     return None
+
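
A usage sketch for the helper above; the list-of-dicts response shape is what the success path in `handle_chat_interaction` below checks for, and `HF_TOKEN` must exist in `st.secrets`:

```python
payload = {
    "inputs": format_deepseek_prompt(DEFAULT_SYSTEM_PROMPT, "Hello"),
    "parameters": {"max_new_tokens": 100, "temperature": 0.3, "top_p": 0.6}
}
output = query_hf_api(payload, f"{API_BASE_URL}{MODEL_OPTIONS[0]}")
if output and isinstance(output, list):
    # On success the Inference API returns e.g. [{"generated_text": "..."}]
    print(output[0]["generated_text"])
```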
+ def handle_chat_interaction(settings: Dict[str, Any]):
+     """Manage chat input/output and API communication"""
+     if prompt := st.chat_input("Type your message..."):
+         # Add user message to history
+         st.session_state.messages.append({"role": "user", "content": prompt})
+
+         with st.chat_message("user"):
+             st.markdown(prompt)
+
+         try:
+             with st.spinner("Generating response..."):
+                 # Format prompt according to model requirements
+                 full_prompt = format_deepseek_prompt(
+                     system_message=settings["system_message"],
+                     user_input=prompt
+                 )
+
+                 payload = {
+                     "inputs": full_prompt,
+                     "parameters": {
+                         "max_new_tokens": settings["max_tokens"],
+                         "temperature": settings["temperature"],
+                         "top_p": settings["top_p"]
+                     }
                  }

+                 api_url = f"{API_BASE_URL}{settings['model']}"
+                 output = query_hf_api(payload, api_url)
+
+                 if output and isinstance(output, list):
+                     if 'generated_text' in output[0]:
+                         response_text = output[0]['generated_text'].strip()
+                         # Remove any remaining special tokens
+                         response_text = response_text.replace("<|endofutterance|>", "").strip()
+
+                         # Display and store response
+                         with st.chat_message("assistant"):
+                             st.markdown(response_text)
+                         st.session_state.messages.append(
+                             {"role": "assistant", "content": response_text}
+                         )
+                         return
+
+                 # Handle failed responses
+                 st.session_state.api_failures += 1
+                 if st.session_state.api_failures > 2:
+                     st.error("Persistent API failures. Please check your API token and model selection.")
+
+         except Exception as e:
+             logger.error(f"Unexpected error: {str(e)}", exc_info=True)
+             st.error("An unexpected error occurred. Please try again.")
+
+ def display_chat_history():
+     """Render chat message history"""
+     for message in st.session_state.messages:
+         with st.chat_message(message["role"]):
+             st.markdown(message["content"])
+
+ def main():
+     """Main application flow"""
+     initialize_session_state()
+     settings = configure_sidebar()
+
+     st.title("🤖 DeepSeek Chatbot")
+     st.caption("Powered by Hugging Face Inference API - Configure in sidebar")
+
+     display_chat_history()
+     handle_chat_interaction(settings)

+ if __name__ == "__main__":
+     main()
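
To try the updated app, the token read from `st.secrets['HF_TOKEN']` has to be provided, e.g. in Streamlit's default `.streamlit/secrets.toml` as `HF_TOKEN = "hf_..."`, after which `streamlit run app.py` starts the chatbot.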