Chris4K committed on
Commit
ca6ed76
·
verified ·
1 Parent(s): 4c2be05

Update services/chat_service.py

Browse files
Files changed (1) hide show
  1. services/chat_service.py +97 -114
services/chat_service.py CHANGED
@@ -1,8 +1,10 @@
1
- # services/chat_service.py
 
2
  from typing import List, Dict, Any, Optional, Tuple
3
  from datetime import datetime
4
  import logging
5
  from config.config import settings
 
6
 
7
  logger = logging.getLogger(__name__)
8
 
@@ -29,7 +31,6 @@ class ConversationManager:
29
  'context': context
30
  })
31
 
32
- # Trim history if needed
33
  if len(self.conversations[session_id]) > self.max_history:
34
  self.conversations[session_id] = self.conversations[session_id][-self.max_history:]
35
 
@@ -41,8 +42,6 @@ class ConversationManager:
41
  del self.conversations[session_id]
42
 
43
  class ChatService:
44
- """Main chat service that coordinates responses"""
45
- """Main chat service that coordinates responses"""
46
  def __init__(
47
  self,
48
  model_service,
@@ -57,6 +56,33 @@ class ChatService:
57
  self.faq_service = faq_service
58
  self.conversation_manager = ConversationManager()
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  def construct_system_prompt(self, context: str) -> str:
61
  """Constructs the system message."""
62
  return (
@@ -73,72 +99,68 @@ class ChatService:
73
  chat_history: List[Tuple[str, str]],
74
  max_history_turns: int = 1
75
  ) -> str:
76
- """Constructs the full prompt."""
77
- # System message
78
- system_message = self.construct_system_prompt(context)
79
-
80
- # Start with system message
81
- prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>"
82
-
83
- # Add chat history (limit to last `max_history_turns` interactions)
84
- for user_msg, assistant_msg in chat_history[-max_history_turns:]:
85
- prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{user_msg}<|eot_id|>"
86
- prompt += f"<|start_header_id|>assistant<|end_header_id|>\n\n{assistant_msg}<|eot_id|>"
87
-
88
- # Add the current user input
89
- prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{user_input}<|eot_id|>"
90
- prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
91
-
92
- return prompt
93
 
94
  def build_context(
95
- self,
96
- search_results: Dict[str, List[Dict[str, Any]]],
97
- chat_history: List[Dict[str, Any]]
98
- ) -> str:
99
- """Build context for the model from search results and chat history"""
100
- context_parts = []
101
-
102
- # Add relevant products
103
- if search_results.get('products'):
104
- products = search_results['products'][:2] # Limit to top 2 products
105
- for product in products:
106
- context_parts.append(
107
- f"Produkt: {product['Name']}\n"
108
- f"Beschreibung: {product['Description']}\n"
109
- f"Preis: {product['Price']}€\n"
110
- f"Kategorie: {product['ProductCategory']}"
111
- )
112
-
113
- # Add relevant PDF content
114
- if search_results.get('documents'):
115
- docs = search_results['documents'][:2]
116
- for doc in docs:
117
- context_parts.append(
118
- f"Aus Dokument '{doc['source']}' (Seite {doc['page']}):\n"
119
- f"{doc['text']}"
120
- )
121
-
122
- # Add relevant FAQs
123
- if search_results.get('faqs'):
124
- faqs = search_results['faqs'][:2]
125
- for faq in faqs:
126
- context_parts.append(
127
- f"FAQ:\n"
128
- f"Frage: {faq['question']}\n"
129
- f"Antwort: {faq['answer']}"
130
- )
131
-
132
- # Add recent chat history
133
- if chat_history:
134
- recent_history = chat_history[-3:] # Last 3 interactions
135
- history_text = "\n".join(
136
- f"User: {h['user_input']}\nAssistant: {h['response']}"
137
- for h in recent_history
138
  )
139
- context_parts.append(f"Letzte Interaktionen:\n{history_text}")
140
- print("\n\n".join(context_parts))
141
- return "\n\n".join(context_parts)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
  async def chat(
144
  self,
@@ -148,29 +170,21 @@ class ChatService:
148
  ) -> Tuple[str, List[Tuple[str, str]], Dict[str, List[Dict[str, Any]]]]:
149
  """Main chat method that coordinates the entire conversation flow."""
150
  try:
151
- # Ensure session_id is a string
152
  if not isinstance(session_id, str):
153
  session_id = str(session_id)
154
 
155
- # Get chat history
156
  chat_history_raw = self.conversation_manager.get_history(session_id)
157
  chat_history = [
158
  (entry['user_input'], entry['response']) for entry in chat_history_raw
159
  ]
160
 
161
- # Search all sources
162
- search_results = await self.search_all_sources(user_input)
163
  print(search_results)
164
- # Build context
165
  context = self.build_context(search_results, chat_history_raw)
166
-
167
- # Construct the prompt
168
  prompt = self.construct_prompt(user_input, context, chat_history)
 
169
 
170
- # Generate response
171
- response = await self.generate_response(prompt, max_length)
172
-
173
- # Store interaction
174
  self.conversation_manager.add_interaction(
175
  session_id,
176
  user_input,
@@ -178,9 +192,9 @@ class ChatService:
178
  {'search_results': search_results}
179
  )
180
 
181
- # Prepare the chat history for Gradio
182
  formatted_history = [
183
- (entry['user_input'], entry['response']) for entry in self.conversation_manager.get_history(session_id)
 
184
  ]
185
 
186
  return response, formatted_history, search_results
@@ -189,36 +203,7 @@ class ChatService:
189
  logger.error(f"Error in chat: {e}")
190
  raise
191
 
192
-
193
- async def search_all_sources(
194
- self,
195
- query: str,
196
- top_k: int = 3
197
- ) -> Dict[str, List[Dict[str, Any]]]:
198
- """Search across all available data sources"""
199
- try:
200
- print("-----------------------------")
201
- print("starting searches .... ")
202
-
203
- products = self.data_service.search(query, top_k)
204
- pdfs = self.pdf_service.search(query, top_k)
205
- faqs = self.faq_service.search_faqs(query, top_k)
206
-
207
-
208
- results = {
209
- 'products': products or [],
210
- 'documents': pdfs or [],
211
- 'faqs': faqs or []
212
- }
213
-
214
- print("Search results:", results)
215
- return results
216
-
217
- except Exception as e:
218
- logger.error(f"Error searching sources: {e}")
219
- return {'products': [], 'documents': [], 'faqs': []}
220
-
221
- async def generate_response(
222
  self,
223
  prompt: str,
224
  max_length: int = 1000
@@ -241,8 +226,7 @@ class ChatService:
241
  top_p=0.9,
242
  do_sample=True,
243
  no_repeat_ngram_size=3,
244
- early_stopping=False, # True num_beams=3, # Increase number of beams if beam search is needed
245
- #pad_token_id=self.tokenizer.eos_token_id, #eos_token_id
246
  )
247
 
248
  response = self.tokenizer.decode(
@@ -254,5 +238,4 @@ class ChatService:
254
 
255
  except Exception as e:
256
  logger.error(f"Error generating response: {e}")
257
- raise
258
-
 
1
+
2
+ #chat_service.py
3
  from typing import List, Dict, Any, Optional, Tuple
4
  from datetime import datetime
5
  import logging
6
  from config.config import settings
7
+ import asyncio
8
 
9
  logger = logging.getLogger(__name__)
10
 
 
31
  'context': context
32
  })
33
 
 
34
  if len(self.conversations[session_id]) > self.max_history:
35
  self.conversations[session_id] = self.conversations[session_id][-self.max_history:]
36
 
 
42
  del self.conversations[session_id]
43
 
44
  class ChatService:
 
 
45
  def __init__(
46
  self,
47
  model_service,
 
56
  self.faq_service = faq_service
57
  self.conversation_manager = ConversationManager()
58
 
59
def search_all_sources(
    self,
    query: str,
    top_k: int = 3
) -> Dict[str, List[Dict[str, Any]]]:
    """Search products, PDF documents and FAQs for entries matching *query*.

    Args:
        query: Free-text search string.
        top_k: Maximum number of hits requested from each source.

    Returns:
        Dict with keys 'products', 'documents' and 'faqs'; each value is a
        (possibly empty) list of result dicts. On any search failure the
        error is logged and all three lists come back empty, so callers
        never have to handle an exception from this method.
    """
    try:
        products = self.data_service.search(query, top_k)
        pdfs = self.pdf_service.search(query, top_k)
        faqs = self.faq_service.search_faqs(query, top_k)

        # Normalise None (source returned "no hits") to an empty list so
        # downstream code can iterate each bucket unconditionally.
        return {
            'products': products or [],
            'documents': pdfs or [],
            'faqs': faqs or []
        }

    except Exception as e:
        # Best-effort: a failing source must not break the chat flow.
        logger.error(f"Error searching sources: {e}")
        return {'products': [], 'documents': [], 'faqs': []}
85
+
86
  def construct_system_prompt(self, context: str) -> str:
87
  """Constructs the system message."""
88
  return (
 
99
  chat_history: List[Tuple[str, str]],
100
  max_history_turns: int = 1
101
  ) -> str:
102
+ """Constructs the full prompt."""
103
+ system_message = self.construct_system_prompt(context)
104
+ prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>"
105
+
106
+ for user_msg, assistant_msg in chat_history[-max_history_turns:]:
107
+ prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{user_msg}<|eot_id|>"
108
+ prompt += f"<|start_header_id|>assistant<|end_header_id|>\n\n{assistant_msg}<|eot_id|>"
109
+
110
+ prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{user_input}<|eot_id|>"
111
+ prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
112
+
113
+ return prompt
 
 
 
 
 
114
 
115
def build_context(
    self,
    search_results: Dict[str, List[Dict[str, Any]]],
    chat_history: List[Dict[str, Any]]
) -> str:
    """Assemble the German-language context string fed to the model.

    Args:
        search_results: Mapping with optional 'products', 'documents' and
            'faqs' lists, as produced by search_all_sources(). Product
            dicts need 'Name'/'Description'/'Price'/'ProductCategory',
            document dicts 'source'/'page'/'text', FAQ dicts
            'question'/'answer' — TODO confirm against the data services.
        chat_history: Raw history entries; each needs 'user_input' and
            'response' keys.

    Returns:
        Sections separated by blank lines: up to 2 products, up to 2
        document snippets, up to 2 FAQs, then the last 3 chat turns.
        Empty string when there is nothing to include.
    """
    context_parts = []

    # Limit every bucket to the top 2 hits to keep the prompt short.
    for product in (search_results.get('products') or [])[:2]:
        context_parts.append(
            f"Produkt: {product['Name']}\n"
            f"Beschreibung: {product['Description']}\n"
            f"Preis: {product['Price']}€\n"
            f"Kategorie: {product['ProductCategory']}"
        )

    for doc in (search_results.get('documents') or [])[:2]:
        context_parts.append(
            f"Aus Dokument '{doc['source']}' (Seite {doc['page']}):\n"
            f"{doc['text']}"
        )

    for faq in (search_results.get('faqs') or [])[:2]:
        context_parts.append(
            f"FAQ:\n"
            f"Frage: {faq['question']}\n"
            f"Antwort: {faq['answer']}"
        )

    # Recent conversation turns give the model short-term memory.
    if chat_history:
        recent_history = chat_history[-3:]  # last 3 interactions
        history_text = "\n".join(
            f"User: {h['user_input']}\nAssistant: {h['response']}"
            for h in recent_history
        )
        context_parts.append(f"Letzte Interaktionen:\n{history_text}")

    return "\n\n".join(context_parts)
164
 
165
  async def chat(
166
  self,
 
170
  ) -> Tuple[str, List[Tuple[str, str]], Dict[str, List[Dict[str, Any]]]]:
171
  """Main chat method that coordinates the entire conversation flow."""
172
  try:
 
173
  if not isinstance(session_id, str):
174
  session_id = str(session_id)
175
 
 
176
  chat_history_raw = self.conversation_manager.get_history(session_id)
177
  chat_history = [
178
  (entry['user_input'], entry['response']) for entry in chat_history_raw
179
  ]
180
 
181
+ search_results = self.search_all_sources(user_input)
 
182
  print(search_results)
183
+
184
  context = self.build_context(search_results, chat_history_raw)
 
 
185
  prompt = self.construct_prompt(user_input, context, chat_history)
186
+ response = self.generate_response(prompt, max_length)
187
 
 
 
 
 
188
  self.conversation_manager.add_interaction(
189
  session_id,
190
  user_input,
 
192
  {'search_results': search_results}
193
  )
194
 
 
195
  formatted_history = [
196
+ (entry['user_input'], entry['response'])
197
+ for entry in self.conversation_manager.get_history(session_id)
198
  ]
199
 
200
  return response, formatted_history, search_results
 
203
  logger.error(f"Error in chat: {e}")
204
  raise
205
 
206
+ def generate_response(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
  self,
208
  prompt: str,
209
  max_length: int = 1000
 
226
  top_p=0.9,
227
  do_sample=True,
228
  no_repeat_ngram_size=3,
229
+ early_stopping=False
 
230
  )
231
 
232
  response = self.tokenizer.decode(
 
238
 
239
  except Exception as e:
240
  logger.error(f"Error generating response: {e}")
241
+ raise