pathakDev10 committed on
Commit 50fadef · 1 Parent(s): 6e41231
Files changed (1)
  1. app.py +624 -25
app.py CHANGED
@@ -1,34 +1,633 @@
- from fastapi import FastAPI
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
- from langchain.llms import HuggingFacePipeline
- import torch
-
- app = FastAPI()
-
- # --- LLM Initialization using Hugging Face ---
- model_id = "Qwen/Qwen2.5-1.5B-Instruct"
- tokenizer = AutoTokenizer.from_pretrained(model_id)
- model = AutoModelForCausalLM.from_pretrained(
-     model_id,
-     device_map="auto",
-     torch_dtype=torch.float16
- )
- generator = pipeline(
-     "text-generation",
-     model=model,
-     tokenizer=tokenizer,
-     max_length=256,
      temperature=0.3,
  )
- llm = HuggingFacePipeline(pipeline=generator)
 
- # Example endpoint using the new llm
  @app.post("/query")
  async def post_query(query: str):
-     # Create a simple prompt structure
-     prompt = f"Answer the following query:\n\n{query}\n"
-     # Get the response from the LLM
-     response = llm(prompt)
      return {"response": response}

- # (Keep your WebSocket endpoint and other code mostly unchanged)
 
+ import uuid
+ import threading
+ import asyncio
+ import json
+ import re
+ import os  # used below to create the conversations/ directory on disconnect
+ from datetime import datetime
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect
+
+ # ------------------------ Chatbot Code (Unmodified) ------------------------
+
+ from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
+ from langgraph.graph import StateGraph, START, END
+ # from langchain_ollama import ChatOllama
+ import faiss
+ from sentence_transformers import SentenceTransformer
+ import pickle
+ import numpy as np
+ from tools import extract_json_from_response, apply_filters_partial, rule_based_extract, format_property_data, estateKeywords
+ import random
+ from langchain_core.tools import tool
+ from langchain_core.callbacks import StreamingStdOutCallbackHandler, CallbackManager
+ from langchain_core.callbacks.base import BaseCallbackHandler
+
+ # ------------------------ Custom Callback for WebSocket Streaming ------------------------
+
+ class WebSocketStreamingCallbackHandler(BaseCallbackHandler):
+     def __init__(self, connection_id: str, loop):
+         self.connection_id = connection_id
+         self.loop = loop
+
+     def on_llm_new_token(self, token: str, **kwargs):
+         # Generation runs on a worker thread; schedule the send on the
+         # event loop that owns the WebSocket connection.
+         asyncio.run_coroutine_threadsafe(
+             manager_socket.send_message(self.connection_id, token),
+             self.loop
+         )
+
+
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+
+ class ChatHuggingFace:
+     def __init__(self, model, token=None, temperature=0.3, streaming=False):
+         # Instead of using InferenceClient, load the model locally.
+         # `token` defaults to None because every call site below omits it.
+         self.temperature = temperature
+         self.streaming = streaming
+         self.tokenizer = AutoTokenizer.from_pretrained(model)
+         self.model = AutoModelForCausalLM.from_pretrained(model)
+         self.pipeline = pipeline("text-generation", model=self.model, tokenizer=self.tokenizer)
+
+     def invoke(self, messages, config=None):
+         """
+         Mimics the ChatOllama.invoke interface.
+         In streaming mode, token-by-token output is sent via callbacks.
+         Otherwise, returns a single aggregated response.
+         """
+         config = config or {}
+         callbacks = config.get("callbacks", [])
+         # Callers may pass a CallbackManager; unwrap it to its handler list.
+         if isinstance(callbacks, CallbackManager):
+             callbacks = callbacks.handlers
+         aggregated_response = ""
+
+         # Build a ChatML prompt. Accept both plain role/content dicts and
+         # LangChain message objects (SystemMessage/HumanMessage/AIMessage).
+         prompt = ""
+         for msg in messages:
+             if isinstance(msg, dict):
+                 role = msg.get("role", "")
+                 content = msg.get("content", "")
+             else:
+                 role = {"system": "system", "human": "user", "ai": "assistant"}.get(getattr(msg, "type", ""), "user")
+                 content = getattr(msg, "content", "")
+             if role == "system":
+                 prompt += f"<|im_start|>system\n{content}\n<|im_end|>\n"
+             elif role == "user":
+                 prompt += f"<|im_start|>user\n{content}\n<|im_end|>\n"
+             elif role == "assistant":
+                 prompt += f"<|im_start|>assistant\n{content}\n<|im_end|>\n"
+         # Cue the model to generate the assistant's turn.
+         prompt += "<|im_start|>assistant\n"
+
+         if self.streaming:
+             # Generate text locally.
+             full_output = self.pipeline(
+                 prompt,
+                 max_new_tokens=100,
+                 do_sample=True,
+                 temperature=self.temperature
+             )[0]['generated_text']
+             # The pipeline returns the prompt plus the generated text.
+             new_text = full_output[len(prompt):]
+             # Simulate token-by-token streaming.
+             for token in new_text.split():
+                 aggregated_response += token + " "
+                 for cb in callbacks:
+                     cb.on_llm_new_token(token=token + " ")
+             return AIMessage(content=aggregated_response.strip())
+         else:
+             # Non-streaming mode.
+             response = self.pipeline(
+                 prompt,
+                 max_new_tokens=100,
+                 do_sample=True,
+                 temperature=self.temperature
+             )[0]['generated_text']
+             new_text = response[len(prompt):]
+             return AIMessage(content=new_text.strip())
+
+
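+ # Illustrative usage (sketch, not executed by the app): the wrapper accepts
+ # either role/content dicts or LangChain message objects and returns an
+ # AIMessage, e.g.:
+ #     chat = ChatHuggingFace(model="Qwen/Qwen2.5-1.5B-Instruct", streaming=False)
+ #     reply = chat.invoke(messages=[{"role": "user", "content": "Hi"}])
+ #     print(reply.content)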
+ # ------------------------ LLM and Data Setup ------------------------
+ # model_name = "qwen2.5:1.5b"
+ model_name = "Qwen/Qwen2.5-1.5B-Instruct"
+ # llm = ChatOllama(model=model_name, temperature=0.3, streaming=True)
+ llm = ChatHuggingFace(
+     model=model_name,
+     # token=token,
      temperature=0.3,
+     streaming=True  # stream tokens to the client via callbacks
+ )
+
+ index = faiss.read_index("./faiss.index")
+ with open("./metadata.pkl", "rb") as f:
+     docs = pickle.load(f)
+ st_model = SentenceTransformer('all-MiniLM-L6-v2')
+
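+ # Sketch (assumption, not part of this commit): `faiss.index` and
+ # `metadata.pkl` are presumably built offline from the same property dicts,
+ # roughly like:
+ #     embs = st_model.encode([str(d) for d in docs]).astype(np.float32)
+ #     idx = faiss.IndexFlatL2(embs.shape[1])
+ #     idx.add(embs)
+ #     faiss.write_index(idx, "./faiss.index")
+ #     with open("./metadata.pkl", "wb") as f:
+ #         pickle.dump(docs, f)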
+ def make_system_prompt(suffix: str) -> str:
+     return (
+         "You are EstateGuru, a real estate expert created by Abhishek Pathak from SwavishTek. "
+         "Your role is to help customers buy properties using the available data. "
+         "Only use the provided data—do not make up any information. "
+         "The default currency is AED. If a query uses a different currency, convert the amount to AED "
+         "(for example, $10k becomes 36726.50 AED and $1 becomes 3.67 AED). "
+         "If a customer is interested in a property, wants to buy, or needs to contact an agent or customer care, "
+         "instruct them to call +91 8766268285."
+         f"\n{suffix}"
+     )
+
+ general_query_prompt = make_system_prompt(
+     "You are EstateGuru, a helpful real estate assistant. Answer the user's query accurately using the available data. "
+     "Do not invent any details or go beyond the real estate domain. "
+     "If the user shows interest in a property or contacting an agent, ask them to call +91 8766268285."
+ )
+
+
+ # ------------------------ Tool Definitions ------------------------
+
+ @tool
+ def extract_filters(query: str) -> dict:
+     """Extract property-search filters from a user query."""
+     # llm_local = ChatOllama(model=model_name, temperature=0.3)
+     llm_local = ChatHuggingFace(
+         model=model_name,
+         # token=token,
+         temperature=0.3,
+         streaming=False
+     )
+     system = (
+         "You are an expert in extracting filters from property-related queries. Your task is to extract and return only the keys explicitly mentioned in the query as a valid JSON object (starting with '{' and ending with '}'). Include only those keys that are directly present in the query.\n\n"
+         "The possible keys are:\n"
+         " - 'projectName': The name of the project.\n"
+         " - 'developerName': The developer's name.\n"
+         " - 'relationshipManager': The relationship manager.\n"
+         " - 'propertyAddress': The property address.\n"
+         " - 'surroundingArea': The area or nearby landmarks.\n"
+         " - 'propertyType': The type or configuration of the property.\n"
+         " - 'amenities': Any amenities mentioned.\n"
+         " - 'coveredParking': Parking availability.\n"
+         " - 'petRules': Pet policies.\n"
+         " - 'security': Security details.\n"
+         " - 'occupancyRate': Occupancy information.\n"
+         " - 'constructionImpact': Construction or its impact.\n"
+         " - 'propertySize': Size of the property.\n"
+         " - 'propertyView': View details.\n"
+         " - 'propertyCondition': Condition of the property.\n"
+         " - 'serviceCharges': Service or maintenance charges.\n"
+         " - 'ownershipType': Ownership type.\n"
+         " - 'totalCosts': A cost threshold or cost amount.\n"
+         " - 'paymentPlans': Payment or financing plans.\n"
+         " - 'expectedRentalYield': Expected rental yield.\n"
+         " - 'rentalHistory': Rental history.\n"
+         " - 'shortTermRentals': Short-term rental information.\n"
+         " - 'resalePotential': Resale potential.\n"
+         " - 'uniqueId': A unique identifier.\n\n"
+         "Important instructions regarding cost thresholds:\n"
+         " - If the query contains phrases like 'under 10k', 'below 2m', or 'less than 5k', interpret these as cost thresholds.\n"
+         " - Convert any shorthand cost values to pure numbers (for example, '10k' becomes 10000, '2m' becomes 2000000) and assign them to the key 'totalCosts'.\n"
+         " - Do not use 'propertySize' for cost thresholds.\n\n"
+         " - The default currency is AED; if the query uses a different currency symbol, convert the amount to the equivalent in AED (e.g., $10k becomes 36726.50, $1 becomes 3.67).\n\n"
+         "Example:\n"
+         " For the query: \"properties near dubai mall under 43k\"\n"
+         " The expected output should be:\n"
+         " { \"surroundingArea\": \"dubai mall\", \"totalCosts\": 43000 }\n\n"
+         "Return ONLY a valid JSON object with the extracted keys and their corresponding values, with no additional text."
+     )
+
+     human_str = f"Here is the query:\n{query}"
+     filter_prompt = [
+         {"role": "system", "content": system},
+         {"role": "user", "content": human_str},
+     ]
+     response = llm_local.invoke(messages=filter_prompt)
+     response_text = response.content if isinstance(response, AIMessage) else str(response)
+     try:
+         model_filters = extract_json_from_response(response_text)
+     except Exception as e:
+         print(f"JSON parsing error: {e}")
+         model_filters = {}
+     rule_filters = rule_based_extract(query)
+     print("Rule-based extraction:", rule_filters)
+     final_filters = {**model_filters, **rule_filters}
+     print("Final extraction:", final_filters)
+     return {"filters": final_filters}
+
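+ # Merge semantics of the dict unpacking above: rule-based values override
+ # model-extracted ones on key collisions, e.g.
+ #     {**{"totalCosts": 43000}, **{"totalCosts": 45000}} == {"totalCosts": 45000}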
+
+ @tool
+ def determine_route(query: str) -> dict:
+     """For determining route using enhanced prompt and fallback logic."""
+     # Define a set of keywords that are strong indicators of a real estate query.
+     real_estate_keywords = estateKeywords
+
+     # Check if the query includes any of the positive signals.
+     pattern = re.compile("|".join(re.escape(keyword) for keyword in real_estate_keywords), re.IGNORECASE)
+     positive_signal = bool(pattern.search(query))
+
+     # Proceed with LLM classification regardless, but use the positive signal in fallback.
+     # llm_local = ChatOllama(model=model_name, temperature=0.3)
+     llm_local = ChatHuggingFace(
+         model=model_name,
+         # token=token,
+         temperature=0.3,
+         streaming=False
+     )
+     transform_suggest_to_list = query.lower().replace("suggest ", "list ")
+     system = """
+     Classify the user query as:
+
+     - **"search"**: if it requests property listings with specific filters (e.g., location, price, property type like "2bhk", service charges, pet policies, etc.).
+     - **"suggest"**: if it asks for property suggestions without filters.
+     - **"detail"**: if it is asking for more information about a previously provided property (e.g., "tell me more about property 5" or "I want more information regarding 4BHK").
+     - **"general"**: for all other real estate-related questions.
+     - **"out_of_domain"**: if the query is not related to real estate (for example, tourist attractions, restaurants, etc.).
+
+     Keep in mind that queries mentioning terms like "service charge", "allow pets", "pet rules", etc., are considered real estate queries.
+
+     Return only the keyword: search, suggest, detail, general, or out_of_domain.
+     """
+     human_str = f"Here is the query:\n{transform_suggest_to_list}"
+     filter_prompt = [
+         {"role": "system", "content": system},
+         {"role": "user", "content": human_str},
+     ]
+     response = llm_local.invoke(messages=filter_prompt)
+     response_text = response.content if isinstance(response, AIMessage) else str(response)
+     route_value = str(response_text).strip().lower()
+
+     # Fallback: if no positive real estate signal is found, override to out_of_domain.
+     # if not positive_signal:
+     #     route_value = "out_of_domain"
+
+     # Fallback: phrases that indicate a detail request about an earlier result.
+     detail_phrases = [
+         "more information",
+         "tell me more",
+         "more details",
+         "give me more details",
+         "I need more details",
+         "can you provide more details",
+         "additional details",
+         "further information",
+         "expand on that",
+         "explain further",
+         "elaborate more",
+         "more specifics",
+         "I want to know more",
+         "could you elaborate",
+         "need more info",
+         "provide more details",
+         "detail it further",
+         "in-depth information",
+         "break it down further",
+         "further explanation"
+     ]
+
+     if any(phrase in query.lower() for phrase in detail_phrases):
+         route_value = "detail"
+
+     if route_value not in {"search", "suggest", "detail", "general", "out_of_domain"}:
+         route_value = "general"
+     # If the model said out_of_domain but the query contains a real estate
+     # keyword, treat it as "general" instead.
+     if route_value == "out_of_domain" and positive_signal:
+         route_value = "general"
+
+     return {"route": route_value}
+
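+ # Illustrative routings implied by the prompt and fallbacks above (sketch):
+ #     "2bhk near dubai mall under 43k"  -> "search"
+ #     "suggest some properties"         -> "suggest"
+ #     "tell me more about property 2"   -> "detail"  (phrase fallback)
+ #     "best restaurants in dubai"       -> "out_of_domain"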
+
+ # ------------------------ Workflow Setup ------------------------
+
+ workflow = StateGraph(state_schema=dict)
+
+ def route_query(state: dict) -> dict:
+     new_state = state.copy()
+     try:
+         new_state["route"] = determine_route.invoke(new_state.get("query", "")).get("route", "general")
+         print(new_state["route"])
+     except Exception as e:
+         print(f"Routing error: {e}")
+         new_state["route"] = "general"
+     return new_state
+
+ def hybrid_extract(state: dict) -> dict:
+     new_state = state.copy()
+     new_state["filters"] = extract_filters.invoke(new_state.get("query", "")).get("filters", {})
+     return new_state
+
+ def search_faiss(state: dict) -> dict:
+     new_state = state.copy()
+     query_embedding = st_model.encode([state["query"]])
+     # index.search returns (distances, indices); keep the top-5 documents.
+     _, indices = index.search(query_embedding.astype(np.float32), 5)
+     new_state["faiss_results"] = [docs[idx] for idx in indices[0] if idx < len(docs)]
+     return new_state
+
+ def apply_filters(state: dict) -> dict:
+     new_state = state.copy()
+     new_state["final_results"] = apply_filters_partial(state["faiss_results"], state.get("filters", {}))
+     return new_state
+
+ def suggest_properties(state: dict) -> dict:
+     new_state = state.copy()
+     new_state["suggestions"] = random.sample(docs, min(5, len(docs)))
+     return new_state
+
+ def handle_out_of_domain(state: dict) -> dict:
+     new_state = state.copy()
+     new_state["response"] = "I only handle real estate inquiries. Please ask a question related to properties."
+     return new_state
+
+
+ def generate_response(state: dict) -> dict:
+     new_state = state.copy()
+     detail_query_flag = False
+
+     # --- Disambiguate specific property requests using property number ---
+     property_match = re.search(r"(?:the\s+)?property\s*(\d+)\b", state.get("query", ""), re.IGNORECASE)
+     if property_match and new_state.get("current_properties"):
+         try:
+             index_requested = int(property_match.group(1)) - 1
+             if 0 <= index_requested < len(new_state["current_properties"]):
+                 new_state["current_properties"] = [new_state["current_properties"][index_requested]]
+                 detail_query_flag = True
+                 new_state["detail_query"] = True
+         except Exception as e:
+             print(f"Property selection error: {e}")
+
+     # Construct messages for the LLM.
+     messages = []
+
+     # Add the general query prompt.
+     messages.append(SystemMessage(content=general_query_prompt))
+     # If this is a detail query, add a system message that forces a detailed answer.
+     if detail_query_flag:
+         messages.append(SystemMessage(content=(
+             "This is a detail query. Please provide detailed information about the property below. "
+             "Do not generate a new list of properties; only use the provided property details to answer the query. "
+             "Focus on answering the specific question (for example, whether pets are allowed)."
+         )))
+
+     # Provide the current property context.
+     if new_state.get("current_properties"):
+         property_context = format_property_data(new_state["current_properties"])
+         messages.insert(0, SystemMessage(content="Available Property:\n" + property_context))
+
+     # Add the conversation history.
+     for msg in state.get("messages", []):
+         if msg["role"] == "user":
+             messages.append(HumanMessage(content=msg["content"]))
+         else:
+             messages.append(AIMessage(content=msg["content"]))
+
+     # Instruction for response.
+     messages.append(SystemMessage(content=(
+         "When responding, use only the provided property details to answer the user's specific question about the property."
+     )))
+
+     # Invoke the LLM with the constructed messages.
+     connection_id = state.get("connection_id")
+     loop = state.get("loop")
+     if connection_id and loop:
+         callback_manager = CallbackManager([WebSocketStreamingCallbackHandler(connection_id, loop)])
+         _ = llm.invoke(
+             messages=messages,
+             config={"callbacks": callback_manager}
+         )
+         # Tokens were already streamed to the client, so nothing to return here.
+         new_state["response"] = ""
+     else:
+         callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
+         response = llm.invoke(
+             messages=messages,
+             config={"callbacks": callback_manager}
+         )
+         new_state["response"] = response.content if isinstance(response, AIMessage) else str(response)
+
+     return new_state
+
+
+ def format_final_response(state: dict) -> dict:
+     new_state = state.copy()
+     # Only override the current_properties if this is NOT a detail query.
+     if not state.get("detail_query", False):
+         if state.get("route") in ["search", "suggest"]:
+             if "final_results" in state:
+                 new_state["current_properties"] = state["final_results"]
+             elif "suggestions" in state:
+                 new_state["current_properties"] = state["suggestions"]
+
+     # Then format the response based on the (possibly filtered) current_properties.
+     if new_state.get("current_properties"):
+         formatted = []
+         for idx, prop in enumerate(new_state["current_properties"], 1):
+             cost = prop.get("totalCosts", "N/A")
+             cost_str = f"{cost:,}" if isinstance(cost, (int, float)) else cost
+             formatted.append(
+                 f"{idx}. Type: {prop.get('propertyType', 'N/A')}, Cost: AED {cost_str}, "
+                 f"Size: {prop.get('propertySize', 'N/A')}, Amenities: {', '.join(map(str, prop.get('amenities', []))) if prop.get('amenities') else 'N/A'}, "
+                 f"Rental Yield: {prop.get('expectedRentalYield', 'N/A')}, "
+                 f"Ownership: {prop.get('ownershipType', 'N/A')}\n"
+             )
+         aggregated_response = "Here are the property details:\n" + "\n".join(formatted)
+         connection_id = state.get("connection_id")
+         loop = state.get("loop")
+         if connection_id and loop:
+             import time
+             # Simulate streaming by sending one whitespace-delimited token at a time.
+             tokens = aggregated_response.split(" ")
+             for token in tokens:
+                 asyncio.run_coroutine_threadsafe(
+                     manager_socket.send_message(connection_id, token + " "),
+                     loop
+                 )
+                 time.sleep(0.05)
+             new_state["response"] = ""
+         else:
+             new_state["response"] = aggregated_response
+     elif "response" in new_state:
+         new_state["response"] = str(new_state["response"])
+     return new_state
+
+
+ nodes = [
+     ("route_query", route_query),
+     ("hybrid_extract", hybrid_extract),
+     ("faiss_search", search_faiss),
+     ("apply_filters", apply_filters),
+     ("suggest_properties", suggest_properties),
+     ("handle_out_of_domain", handle_out_of_domain),
+     ("generate_response", generate_response),
+     ("format_response", format_final_response)
+ ]
+
+ for name, node in nodes:
+     workflow.add_node(name, node)
+
+ workflow.add_edge(START, "route_query")
+ workflow.add_conditional_edges(
+     "route_query",
+     lambda state: state.get("route", "general"),
+     {
+         "search": "hybrid_extract",
+         "suggest": "suggest_properties",
+         "detail": "generate_response",
+         "general": "generate_response",
+         "out_of_domain": "handle_out_of_domain"
+     }
  )
+ workflow.add_edge("hybrid_extract", "faiss_search")
+ workflow.add_edge("faiss_search", "apply_filters")
+ workflow.add_edge("apply_filters", "format_response")
+ workflow.add_edge("suggest_properties", "format_response")
+ workflow.add_edge("generate_response", "format_response")
+ workflow.add_edge("handle_out_of_domain", "format_response")
+ workflow.add_edge("format_response", END)
+
+ workflow_app = workflow.compile()
+
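+ # Minimal invocation sketch (illustrative, outside the API layer; the keys
+ # mirror the state built in process_query/stream_query below):
+ #     out = workflow_app.invoke({"messages": [], "query": "suggest some properties",
+ #                                "route": "general", "filters": {},
+ #                                "current_properties": []})
+ #     print(out.get("response", ""))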
+ # ------------------------ Conversation Manager ------------------------
+
+ class ConversationManager:
+     def __init__(self):
+         self.conversation_history = []
+         self.current_properties = []
+
+     def _add_message(self, role: str, content: str):
+         self.conversation_history.append({
+             "role": role,
+             "content": content,
+             "timestamp": datetime.now().isoformat()
+         })
+
+     def process_query(self, query: str) -> str:
+         # Reset context on greetings to avoid using off-domain history
+         if query.strip().lower() in {"hi", "hello", "hey"}:
+             self.conversation_history = []
+             self.current_properties = []
+             greeting_response = "Hello! How can I assist you today with your real estate inquiries?"
+             self._add_message("assistant", greeting_response)
+             return greeting_response
+
+         try:
+             self._add_message("user", query)
+             initial_state = {
+                 "messages": self.conversation_history.copy(),
+                 "query": query,
+                 "route": "general",
+                 "filters": {},
+                 "current_properties": self.current_properties
+             }
+             for event in workflow_app.stream(initial_state, stream_mode="values"):
+                 final_state = event
+             if 'final_results' in final_state:
+                 self.current_properties = final_state['final_results']
+             elif 'suggestions' in final_state:
+                 self.current_properties = final_state['suggestions']
+             if final_state.get("route") == "general":
+                 response_text = final_state.get("response", "")
+                 self._add_message("assistant", response_text)
+                 return response_text
+             else:
+                 response = final_state.get("response", "I couldn't process that request.")
+                 self._add_message("assistant", response)
+                 return response
+         except Exception as e:
+             print(f"Processing error: {e}")
+             return "Sorry, I encountered an error processing your request."
+
+ conversation_managers = {}
+
+ # ------------------------ FastAPI Backend with WebSockets ------------------------
+
+ app = FastAPI()
+
+ class ConnectionManager:
+     def __init__(self):
+         self.active_connections = {}
+
+     async def connect(self, websocket: WebSocket):
+         await websocket.accept()
+         connection_id = str(uuid.uuid4())
+         self.active_connections[connection_id] = websocket
+         print(f"New connection: {connection_id}")
+         return connection_id
+
+     def disconnect(self, connection_id: str):
+         if connection_id in self.active_connections:
+             del self.active_connections[connection_id]
+             print(f"Disconnected: {connection_id}")
+
+     async def send_message(self, connection_id: str, message: str):
+         websocket = self.active_connections.get(connection_id)
+         if websocket:
+             await websocket.send_text(message)
+
+ manager_socket = ConnectionManager()
+
+
+ def stream_query(query: str, connection_id: str, loop):
+     conv_manager = conversation_managers.get(connection_id)
+     if conv_manager is None:
+         print(f"No conversation manager found for connection {connection_id}")
+         return
+
+     # Check for greetings and handle them immediately
+     if query.strip().lower() in {"hi", "hello", "hey"}:
+         conv_manager.conversation_history = []
+         conv_manager.current_properties = []
+         greeting_response = "Hello! How can I assist you today with your real estate inquiries?"
+         conv_manager._add_message("assistant", greeting_response)
+         asyncio.run_coroutine_threadsafe(
+             manager_socket.send_message(connection_id, greeting_response),
+             loop
+         )
+         return
+
+     conv_manager._add_message("user", query)
+     initial_state = {
+         "messages": conv_manager.conversation_history.copy(),
+         "query": query,
+         "route": "general",
+         "filters": {},
+         "current_properties": conv_manager.current_properties,
+         "connection_id": connection_id,
+         "loop": loop
+     }
+     try:
+         workflow_app.invoke(initial_state)
+     except Exception as e:
+         error_msg = f"Error processing query: {str(e)}"
+         asyncio.run_coroutine_threadsafe(
+             manager_socket.send_message(connection_id, error_msg),
+             loop
+         )
+
+
+
+ @app.websocket("/ws")
+ async def websocket_endpoint(websocket: WebSocket):
+     connection_id = await manager_socket.connect(websocket)
+     conversation_managers[connection_id] = ConversationManager()
+     try:
+         while True:
+             query = await websocket.receive_text()
+             loop = asyncio.get_running_loop()
+             threading.Thread(
+                 target=stream_query,
+                 args=(query, connection_id, loop),
+                 daemon=True
+             ).start()
+     except WebSocketDisconnect:
+         conv_manager = conversation_managers.get(connection_id)
+         if conv_manager:
+             # Ensure the output directory exists before saving the transcript.
+             os.makedirs("conversations", exist_ok=True)
+             filename = f"conversations/conversation_{connection_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
+             with open(filename, "w") as f:
+                 json.dump(conv_manager.conversation_history, f, indent=4)
+             del conversation_managers[connection_id]
+         manager_socket.disconnect(connection_id)
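+ # Example client (sketch, assuming the third-party `websockets` package and a
+ # locally running server; not part of this app):
+ #     import asyncio, websockets
+ #     async def chat():
+ #         async with websockets.connect("ws://localhost:8000/ws") as ws:
+ #             await ws.send("2bhk near dubai mall under 43k")
+ #             while True:
+ #                 print(await ws.recv(), end="", flush=True)
+ #     asyncio.run(chat())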

  @app.post("/query")
  async def post_query(query: str):
+     conv_manager = ConversationManager()
+     response = conv_manager.process_query(query)
      return {"response": response}

+
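+ # To run locally (assumption; the launch command is not part of this commit):
+ #     uvicorn app:app --host 0.0.0.0 --port 8000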