pathakDev10 committed on
Commit
d54731f
·
1 Parent(s): c01f047

performance

__pycache__/app.cpython-312.pyc CHANGED
Binary files a/__pycache__/app.cpython-312.pyc and b/__pycache__/app.cpython-312.pyc differ
 
app.py CHANGED
@@ -9,21 +9,23 @@ import pickle
9
  import numpy as np
10
  import requests # For llama.cpp server calls
11
  from datetime import datetime
12
- from fastapi import FastAPI, WebSocket, WebSocketDisconnect
13
- from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
14
  from langgraph.graph import StateGraph, START, END
15
  import faiss
16
  from sentence_transformers import SentenceTransformer
17
  from tools import extract_json_from_response, apply_filters_partial, rule_based_extract, structured_property_data, estateKeywords, sendTokenViaSocket
18
- from langchain_core.prompts import ChatPromptTemplate
19
  from langchain_core.tools import tool
20
- from langchain_core.callbacks import StreamingStdOutCallbackHandler, CallbackManager
21
  from langchain_core.callbacks.base import BaseCallbackHandler
22
 
23
  import os
24
  from fastapi.responses import PlainTextResponse
25
- from fastapi import FastAPI, Request
26
  from fastapi.staticfiles import StaticFiles
27
  # ------------------------ Model Inference Wrapper ------------------------
28
 
29
  class ChatQwen:
@@ -48,6 +50,7 @@ class ChatQwen:
48
  self.max_new_tokens = max_new_tokens
49
  self.callbacks = callbacks
50
  self.use_server = use_server
 
51
 
52
  if self.use_server:
53
  # Use remote llama.cpp server – provide its URL.
@@ -57,90 +60,164 @@ class ChatQwen:
57
  if not model_path:
58
  raise ValueError("Local mode requires a valid model_path to the gguf file.")
59
  from llama_cpp import Llama # assumes llama-cpp-python is installed
60
- self.model = Llama(
61
- model_path=model_path,
62
- temperature=self.temperature,
63
- # n_ctx=512,
64
- n_ctx=8192,
65
- n_threads=4, # Adjust as needed
66
- batch_size=512,
67
- verbose=False,
68
- )
69
-
70
- def build_prompt(self, messages: list) -> str:
71
- """Build Qwen-compatible prompt with special tokens."""
72
- prompt = ""
73
- for msg in messages:
74
- role = msg["role"]
75
- content = msg["content"]
76
- if role == "system":
77
- prompt += f"<|im_start|>system\n{content}<|im_end|>\n"
78
- elif role == "user":
79
- prompt += f"<|im_start|>user\n{content}<|im_end|>\n"
80
- elif role == "assistant":
81
- prompt += f"<|im_start|>assistant\n{content}<|im_end|>\n"
82
- prompt += "<|im_start|>assistant\n"
83
- return prompt
84
-
85
- def generate_text(self, messages: list) -> str:
86
- prompt = self.build_prompt(messages)
87
- stop_tokens = ["<|im_end|>", "\n"] # Qwen's stop sequences
88
-
89
- if self.use_server:
90
- payload = {
91
- "prompt": prompt,
92
- "max_tokens": self.max_new_tokens,
93
- "temperature": self.temperature,
94
- "stream": self.streaming,
95
- "stop": stop_tokens # Add stop tokens to server request
96
- }
97
- if self.streaming:
98
- response = requests.post(f"{self.server_url}/generate", json=payload, stream=True)
99
- generated_text = ""
100
- for line in response.iter_lines():
101
- if line:
102
- token = line.decode("utf-8")
103
- # Check for stop tokens in stream
104
- if any(stop in token for stop in stop_tokens):
105
- break
106
- generated_text += token
107
- if self.callbacks:
108
- for callback in self.callbacks:
109
- callback.on_llm_new_token(token)
110
- return generated_text
111
- else:
112
- response = requests.post(f"{self.server_url}/generate", json=payload)
113
- return response.json().get("generated_text", "")
114
- else:
115
- # Local llama.cpp inference
116
- if self.streaming:
117
- stream = self.model.create_completion(
118
- prompt=prompt,
119
- max_tokens=self.max_new_tokens,
120
  temperature=self.temperature,
121
- stream=True,
122
- stop=stop_tokens
123
  )
124
- generated_text = ""
125
- for token_chunk in stream:
126
- token_text = token_chunk["choices"][0]["text"]
127
- # Stop early if we detect end token
128
- if any(stop in token_text for stop in stop_tokens):
129
- break
130
- generated_text += token_text
131
- if self.callbacks:
132
- for callback in self.callbacks:
133
- callback.on_llm_new_token(token_text)
134
- return generated_text
135
  else:
136
- result = self.model.create_completion(
137
- prompt=prompt,
138
- max_tokens=self.max_new_tokens,
139
  temperature=self.temperature,
140
- stop=stop_tokens
 
141
  )
142
- return result["choices"][0]["text"]
143
-
144
  def invoke(self, messages: list, config: dict = None) -> AIMessage:
145
  config = config or {}
146
  callbacks = config.get("callbacks", self.callbacks)
@@ -184,6 +261,13 @@ llm = ChatQwen(
184
  # server_url="http://localhost:8000" # Uncomment and set if using server mode.
185
  )
186
 
187
  # ------------------------ FAISS and Sentence Transformer Setup ------------------------
188
 
189
  index = faiss.read_index("./faiss.index")
@@ -212,9 +296,10 @@ general_query_prompt = make_system_prompt(
212
  # ------------------------ Tool Definitions ------------------------
213
 
214
  @tool
 
215
  def extract_filters(query: str) -> dict:
216
  """Extract filters from the query."""
217
- llm_local = ChatQwen(temperature=0.3, streaming=False, use_server=False, model_path=model_path)
218
  system = (
219
  "You are an expert in extracting filters from property-related queries. Your task is to extract and return only the keys explicitly mentioned in the query as a valid JSON object (starting with '{' and ending with '}'). Include only those keys that are directly present in the query.\n\n"
220
  "The possible keys are:\n"
@@ -259,7 +344,7 @@ def extract_filters(query: str) -> dict:
259
  {"role": "system", "content": system},
260
  {"role": "user", "content": human_str},
261
  ]
262
- response = llm_local.invoke(messages=filter_prompt)
263
  response_text = response.content if isinstance(response, AIMessage) else str(response)
264
  try:
265
  model_filters = extract_json_from_response(response_text)
@@ -274,13 +359,14 @@ def extract_filters(query: str) -> dict:
274
 
275
 
276
  @tool
 
277
  def determine_route(query: str) -> dict:
278
  """Determine the route (search, suggest, detail, general, out_of_domain) for the query."""
279
  real_estate_keywords = estateKeywords
280
  pattern = re.compile("|".join(re.escape(keyword) for keyword in real_estate_keywords), re.IGNORECASE)
281
  positive_signal = bool(pattern.search(query))
282
 
283
- llm_local = ChatQwen(temperature=0.3, streaming=False, use_server=False, model_path=model_path)
284
  transform_suggest_to_list = query.lower().replace("suggest ", "list ", -1)
285
  system = """
286
  Classify the user query as:
@@ -302,7 +388,7 @@ def determine_route(query: str) -> dict:
302
  {"role": "user", "content": human_str},
303
  ]
304
 
305
- response = llm_local.invoke(messages=router_prompt)
306
  response_text = response.content if isinstance(response, AIMessage) else str(response)
307
  route_value = str(response_text).strip().lower()
308
 
@@ -410,8 +496,8 @@ def generate_response(state: dict) -> dict:
410
  messages.append({"role": "system", "content": "When responding, use only the provided property details."})
411
 
412
  # Add conversation history
413
- # Truncate conversation history (last 4 exchanges)
414
- truncated_history = state.get("messages", [])[-8:] # Last 4 user+assistant pairs
415
  for msg in truncated_history:
416
  messages.append({"role": msg["role"], "content": msg["content"]})
417
 
@@ -462,23 +548,9 @@ def format_final_response(state: dict) -> dict:
462
  new_state["current_properties"] = state["current_properties"]
463
 
464
 
465
- # print("state: ", json.dumps(new_state), "\n\n")
466
- # Format the property details if available.
467
- # if new_state.get("current_properties"):
468
  if state.get("route") in ["search", "suggest"] and new_state.get("current_properties"):
469
  formatted = structured_property_data(state=new_state)
470
-
471
- # for idx, prop in enumerate(new_state["current_properties"], 1):
472
- # cost = prop.get("totalCosts", "N/A")
473
- # cost_str = f"{cost:,}" if isinstance(cost, (int, float)) else cost
474
- # formatted.append(
475
- # f"{idx}. Type: {prop['propertyType']}, Cost: AED {cost_str}, "
476
- # f"Size: {prop.get('propertySize', 'N/A')}, Amenities: {', '.join(map(str, prop.get('amenities', []))) if prop.get('amenities') else 'N/A'}, "
477
- # f"Rental Yield: {prop.get('expectedRentalYield', 'N/A')}, "
478
- # f"Ownership: {prop.get('ownershipType', 'N/A')}\n"
479
- # )
480
  aggregated_response = "Here are the property details:\n" + "\n".join(formatted)
481
- # print(aggregated_response)
482
 
483
  connection_id = state.get("connection_id")
484
  loop = state.get("loop")
@@ -727,6 +799,8 @@ async def websocket_endpoint(websocket: WebSocket):
727
  del conversation_managers[connection_id]
728
  manager_socket.disconnect(connection_id)
729
 
 
 
730
  @app.post("/query")
731
  async def post_query(query: str):
732
  conv_manager = ConversationManager()
@@ -761,3 +835,17 @@ async def check_model_middleware(request: Request, call_next):
761
  @app.get("/")
762
  async def home():
763
  return PlainTextResponse("Space is running. Model ready!")
9
  import numpy as np
10
  import requests # For llama.cpp server calls
11
  from datetime import datetime
12
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect, BackgroundTasks, Request
13
+ from langchain_core.messages import AIMessage
14
  from langgraph.graph import StateGraph, START, END
15
  import faiss
16
  from sentence_transformers import SentenceTransformer
17
  from tools import extract_json_from_response, apply_filters_partial, rule_based_extract, structured_property_data, estateKeywords, sendTokenViaSocket
 
18
  from langchain_core.tools import tool
19
+ from langchain_core.callbacks import StreamingStdOutCallbackHandler
20
  from langchain_core.callbacks.base import BaseCallbackHandler
21
 
22
  import os
23
  from fastapi.responses import PlainTextResponse
 
24
  from fastapi.staticfiles import StaticFiles
25
+ from functools import lru_cache
26
+ from contextlib import asynccontextmanager
27
+
28
+
29
  # ------------------------ Model Inference Wrapper ------------------------
30
 
31
  class ChatQwen:
 
50
  self.max_new_tokens = max_new_tokens
51
  self.callbacks = callbacks
52
  self.use_server = use_server
53
+ self.is_hf_space = os.environ.get('SPACE_ID') is not None
54
 
55
  if self.use_server:
56
  # Use remote llama.cpp server – provide its URL.
 
60
  if not model_path:
61
  raise ValueError("Local mode requires a valid model_path to the gguf file.")
62
  from llama_cpp import Llama # assumes llama-cpp-python is installed
63
+ # self.model = Llama(
64
+ # model_path=model_path,
65
+ # temperature=self.temperature,
66
+ # # n_ctx=512,
67
+ # n_ctx=8192,
68
+ # n_threads=4, # Adjust as needed
69
+ # batch_size=512,
70
+ # verbose=False,
71
+ # )
72
+ # Update Llama initialization:
73
+ if self.is_hf_space:
74
+ self.model = Llama(
75
+ model_path=model_path,
76
  temperature=self.temperature,
77
+ n_ctx=1024, # Reduced from 8192
78
+ n_threads=2, # Never exceed 2 threads on free tier
79
+ n_batch=128, # Smaller batch size for low RAM
80
+ use_mmap=True, # Essential for memory mapping
81
+ use_mlock=False, # Disable memory locking
82
+ low_vram=True, # Special low-memory mode
83
+ vocab_only=False,
84
+ n_gqa=2, # Grouped-query attention for 1.5B model
85
+ rope_freq_base=10000,
86
+ logits_all=False
87
  )
88
  else:
89
+ self.model = Llama(
90
+ model_path=model_path,
91
+ n_gpu_layers=20, # Offload 20 layers to GPU (adjust based on VRAM)
92
+ n_threads=3, # leave 1
93
+ n_threads_batch=3,
94
+ batch_size=256,
95
+ main_gpu=0, # Use first GPU
96
+ use_mmap=True,
97
+ use_mlock=False,
98
  temperature=self.temperature,
99
+ n_ctx=2048, # Reduced context for lower memory usage
100
+ verbose=False
101
  )
102
+
103
+ if not self.use_server:
104
+ self.model.tokenize(b"Warmup") # Pre-load model
105
+ self.model.create_completion("Warmup", max_tokens=1)
106
+
107
+ # def build_prompt(self, messages: list) -> str:
108
+ # """Build Qwen-compatible prompt with special tokens."""
109
+ # prompt = ""
110
+ # for msg in messages:
111
+ # role = msg["role"]
112
+ # content = msg["content"]
113
+ # if role == "system":
114
+ # prompt += f"<|im_start|>system\n{content}<|im_end|>\n"
115
+ # elif role == "user":
116
+ # prompt += f"<|im_start|>user\n{content}<|im_end|>\n"
117
+ # elif role == "assistant":
118
+ # prompt += f"<|im_start|>assistant\n{content}<|im_end|>\n"
119
+ # prompt += "<|im_start|>assistant\n"
120
+ # return prompt
121
+
122
+ @lru_cache(maxsize=2)
123
+ def build_prompt(self, messages: list) -> str:
124
+ """Optimized prompt builder with string join"""
125
+ return "".join(
126
+ f"<|im_start|>{msg['role']}\n{msg['content']}<|im_end|>\n"
127
+ for msg in messages
128
+ ) + "<|im_start|>assistant\n"
129
+
130
+ def generate_text(self, messages: list) -> str:
131
+ try:
132
+ prompt = self.build_prompt(messages)
133
+ stop_tokens = ["<|im_end|>", "\n"] # Qwen's stop sequences
134
+
135
+ if self.use_server:
136
+ payload = {
137
+ "prompt": prompt,
138
+ "max_tokens": self.max_new_tokens,
139
+ "temperature": self.temperature,
140
+ "stream": self.streaming,
141
+ "stop": stop_tokens # Add stop tokens to server request
142
+ }
143
+ if self.streaming:
144
+ response = requests.post(f"{self.server_url}/generate", json=payload, stream=True)
145
+ generated_text = ""
146
+ for line in response.iter_lines():
147
+ if line:
148
+ token = line.decode("utf-8")
149
+ # Check for stop tokens in stream
150
+ if any(stop in token for stop in stop_tokens):
151
+ break
152
+ generated_text += token
153
+ if self.callbacks:
154
+ for callback in self.callbacks:
155
+ callback.on_llm_new_token(token)
156
+ return generated_text
157
+ else:
158
+ response = requests.post(f"{self.server_url}/generate", json=payload)
159
+ return response.json().get("generated_text", "")
160
+ else:
161
+ # Local llama.cpp inference
162
+ if self.streaming:
163
+ if self.is_hf_space:
164
+ stream = self.model.create_completion(
165
+ prompt=prompt,
166
+ max_tokens=256, # Reduced from 512
167
+ temperature=0.3,
168
+ stream=True,
169
+ stop=stop_tokens,
170
+ repeat_penalty=1.15,
171
+ frequency_penalty=0.2,
172
+ mirostat_mode=2, # Better for low-resource
173
+ mirostat_tau=3.0,
174
+ mirostat_eta=0.1
175
+ )
176
+ else:
177
+ stream = self.model.create_completion(
178
+ prompt=prompt,
179
+ max_tokens=self.max_new_tokens,
180
+ temperature=self.temperature,
181
+ stream=True,
182
+ stop=stop_tokens,
183
+ repeat_penalty=1.1, # Reduce repetition for faster generation
184
+ tfs_z=0.5 # Tail-free sampling for efficiency
185
+ )
186
+
187
+
188
+ generated_text = ""
189
+ for token_chunk in stream:
190
+ token_text = token_chunk["choices"][0]["text"]
191
+ # Stop early if we detect end token
192
+ if any(stop in token_text for stop in stop_tokens):
193
+ break
194
+ generated_text += token_text
195
+ if self.callbacks:
196
+ for callback in self.callbacks:
197
+ callback.on_llm_new_token(token_text)
198
+ return generated_text
199
+ else:
200
+ result = self.model.create_completion(
201
+ prompt=prompt,
202
+ max_tokens=self.max_new_tokens,
203
+ temperature=self.temperature,
204
+ stop=stop_tokens
205
+ )
206
+ return result["choices"][0]["text"]
207
+ except Exception as e:
208
+ if "out of memory" in str(e).lower() and self.is_hf_space:
209
+ return self.fallback_generate(messages)
210
+
211
+ def fallback_generate(self, messages):
212
+ """Simpler generation for OOM situations"""
213
+ return self.model.create_completion(
214
+ prompt=self.build_prompt(messages),
215
+ max_tokens=128,
216
+ temperature=0.3,
217
+ stream=False,
218
+ stop=["<|im_end|>", "\n"]
219
+ )["choices"][0]["text"]
220
+
221
  def invoke(self, messages: list, config: dict = None) -> AIMessage:
222
  config = config or {}
223
  callbacks = config.get("callbacks", self.callbacks)
 
261
  # server_url="http://localhost:8000" # Uncomment and set if using server mode.
262
  )
263
 
264
+ llm_no_stream = ChatQwen(
265
+ temperature=0.3,
266
+ streaming=False,
267
+ use_server=False,
268
+ model_path=model_path,
269
+ )
270
+
271
  # ------------------------ FAISS and Sentence Transformer Setup ------------------------
272
 
273
  index = faiss.read_index("./faiss.index")
 
296
  # ------------------------ Tool Definitions ------------------------
297
 
298
  @tool
299
+ @lru_cache(maxsize=50,typed=False)
300
  def extract_filters(query: str) -> dict:
301
  """Extract filters from the query."""
302
+ # llm_local = ChatQwen(temperature=0.3, streaming=False, use_server=False, model_path=model_path)
303
  system = (
304
  "You are an expert in extracting filters from property-related queries. Your task is to extract and return only the keys explicitly mentioned in the query as a valid JSON object (starting with '{' and ending with '}'). Include only those keys that are directly present in the query.\n\n"
305
  "The possible keys are:\n"
 
344
  {"role": "system", "content": system},
345
  {"role": "user", "content": human_str},
346
  ]
347
+ response = llm_no_stream.invoke(messages=filter_prompt)
348
  response_text = response.content if isinstance(response, AIMessage) else str(response)
349
  try:
350
  model_filters = extract_json_from_response(response_text)
 
359
 
360
 
361
  @tool
362
+ @lru_cache(maxsize=50,typed=False)
363
  def determine_route(query: str) -> dict:
364
  """Determine the route (search, suggest, detail, general, out_of_domain) for the query."""
365
  real_estate_keywords = estateKeywords
366
  pattern = re.compile("|".join(re.escape(keyword) for keyword in real_estate_keywords), re.IGNORECASE)
367
  positive_signal = bool(pattern.search(query))
368
 
369
+ # llm_local = ChatQwen(temperature=0.3, streaming=False, use_server=False, model_path=model_path)
370
  transform_suggest_to_list = query.lower().replace("suggest ", "list ", -1)
371
  system = """
372
  Classify the user query as:
 
388
  {"role": "user", "content": human_str},
389
  ]
390
 
391
+ response = llm_no_stream.invoke(messages=router_prompt)
392
  response_text = response.content if isinstance(response, AIMessage) else str(response)
393
  route_value = str(response_text).strip().lower()
394
 
 
496
  messages.append({"role": "system", "content": "When responding, use only the provided property details."})
497
 
498
  # Add conversation history
499
+ # Truncate conversation history (last 6 exchanges)
500
+ truncated_history = state.get("messages", [])[-12:] # Last 6 user+assistant pairs
501
  for msg in truncated_history:
502
  messages.append({"role": msg["role"], "content": msg["content"]})
503
 
 
548
  new_state["current_properties"] = state["current_properties"]
549
 
550
 
551
  if state.get("route") in ["search", "suggest"] and new_state.get("current_properties"):
552
  formatted = structured_property_data(state=new_state)
553
  aggregated_response = "Here are the property details:\n" + "\n".join(formatted)
 
554
 
555
  connection_id = state.get("connection_id")
556
  loop = state.get("loop")
 
799
  del conversation_managers[connection_id]
800
  manager_socket.disconnect(connection_id)
801
 
802
+
803
+
804
  @app.post("/query")
805
  async def post_query(query: str):
806
  conv_manager = ConversationManager()
 
835
  @app.get("/")
836
  async def home():
837
  return PlainTextResponse("Space is running. Model ready!")
838
+
839
+
840
+ # async def clear_cache_periodically(seconds: int = 3600):
841
+ # while True:
842
+ # await asyncio.sleep(seconds)
843
+ # extract_filters.cache_clear()
844
+ # determine_route.cache_clear()
845
+ # ChatQwen.build_prompt.cache_clear()
846
+ # print("Cache cleared")
847
+
848
+ # @app.on_event("startup")
849
+ # async def startup_event():
850
+ # background_tasks = BackgroundTasks()
851
+ # background_tasks.add_task(clear_cache_periodically, 3600) # Clear every hour
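Note on the caching added above (a sketch, not part of the commit): functools.lru_cache only accepts hashable arguments, so the @lru_cache placed on build_prompt(self, messages: list) raises TypeError once it is called with a list of dicts; the @tool functions receive plain strings, which are hashable. The commented-out startup hook also constructs a bare BackgroundTasks() object, which FastAPI only runs as part of a response, so a periodic job is usually scheduled as an asyncio task at startup. A minimal sketch of both ideas, using illustrative names that are not in the commit:

from functools import lru_cache

@lru_cache(maxsize=2)
def _cached_prompt(message_items: tuple) -> str:
    # Cache on an immutable tuple of (role, content) pairs instead of a list.
    return "".join(
        f"<|im_start|>{role}\n{content}<|im_end|>\n" for role, content in message_items
    ) + "<|im_start|>assistant\n"

def build_prompt(messages: list) -> str:
    # Convert the mutable message list into a hashable key before caching.
    return _cached_prompt(tuple((m["role"], m["content"]) for m in messages))

# Periodic cache clearing wired up at startup (hypothetical):
# @app.on_event("startup")
# async def schedule_cache_clear():
#     asyncio.create_task(clear_cache_periodically(3600))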
backup/app-backup.py ADDED
@@ -0,0 +1,765 @@
1
+ import uuid
2
+ import threading
3
+ import asyncio
4
+ import json
5
+ import re
6
+ import random
7
+ import time
8
+ import pickle
9
+ import numpy as np
10
+ import requests # For llama.cpp server calls
11
+ from datetime import datetime
12
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect
13
+ from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
14
+ from langgraph.graph import StateGraph, START, END
15
+ import faiss
16
+ from sentence_transformers import SentenceTransformer
17
+ from tools import extract_json_from_response, apply_filters_partial, rule_based_extract, structured_property_data, estateKeywords, sendTokenViaSocket
18
+ from langchain_core.prompts import ChatPromptTemplate
19
+ from langchain_core.tools import tool
20
+ from langchain_core.callbacks import StreamingStdOutCallbackHandler, CallbackManager
21
+ from langchain_core.callbacks.base import BaseCallbackHandler
22
+
23
+ import os
24
+ from fastapi.responses import PlainTextResponse
25
+ from fastapi import FastAPI, Request
26
+ from fastapi.staticfiles import StaticFiles
27
+ # ------------------------ Model Inference Wrapper ------------------------
28
+
29
+ class ChatQwen:
30
+ """
31
+ A chat wrapper for Qwen using llama.cpp.
32
+ This class can work in two modes:
33
+ - Local: Using a llama-cpp-python binding (gguf model file loaded locally).
34
+ - Server: Calling a remote llama.cpp server endpoint.
35
+ """
36
+ def __init__(
37
+ self,
38
+ temperature=0.3,
39
+ streaming=False,
40
+ max_new_tokens=512,
41
+ callbacks=None,
42
+ use_server=False,
43
+ model_path: str = None,
44
+ server_url: str = None
45
+ ):
46
+ self.temperature = temperature
47
+ self.streaming = streaming
48
+ self.max_new_tokens = max_new_tokens
49
+ self.callbacks = callbacks
50
+ self.use_server = use_server
51
+
52
+ if self.use_server:
53
+ # Use remote llama.cpp server – provide its URL.
54
+ self.server_url = server_url or "http://localhost:8000"
55
+ else:
56
+ # For local inference, a model_path must be provided.
57
+ if not model_path:
58
+ raise ValueError("Local mode requires a valid model_path to the gguf file.")
59
+ from llama_cpp import Llama # assumes llama-cpp-python is installed
60
+ self.model = Llama(
61
+ model_path=model_path,
62
+ temperature=self.temperature,
63
+ # n_ctx=512,
64
+ n_ctx=8192,
65
+ n_threads=4, # Adjust as needed
66
+ batch_size=512,
67
+ verbose=False,
68
+ )
69
+
70
+ def build_prompt(self, messages: list) -> str:
71
+ """Build Qwen-compatible prompt with special tokens."""
72
+ prompt = ""
73
+ for msg in messages:
74
+ role = msg["role"]
75
+ content = msg["content"]
76
+ if role == "system":
77
+ prompt += f"<|im_start|>system\n{content}<|im_end|>\n"
78
+ elif role == "user":
79
+ prompt += f"<|im_start|>user\n{content}<|im_end|>\n"
80
+ elif role == "assistant":
81
+ prompt += f"<|im_start|>assistant\n{content}<|im_end|>\n"
82
+ prompt += "<|im_start|>assistant\n"
83
+ return prompt
84
+
85
+ def generate_text(self, messages: list) -> str:
86
+ prompt = self.build_prompt(messages)
87
+ stop_tokens = ["<|im_end|>", "\n"] # Qwen's stop sequences
88
+
89
+ if self.use_server:
90
+ payload = {
91
+ "prompt": prompt,
92
+ "max_tokens": self.max_new_tokens,
93
+ "temperature": self.temperature,
94
+ "stream": self.streaming,
95
+ "stop": stop_tokens # Add stop tokens to server request
96
+ }
97
+ if self.streaming:
98
+ response = requests.post(f"{self.server_url}/generate", json=payload, stream=True)
99
+ generated_text = ""
100
+ for line in response.iter_lines():
101
+ if line:
102
+ token = line.decode("utf-8")
103
+ # Check for stop tokens in stream
104
+ if any(stop in token for stop in stop_tokens):
105
+ break
106
+ generated_text += token
107
+ if self.callbacks:
108
+ for callback in self.callbacks:
109
+ callback.on_llm_new_token(token)
110
+ return generated_text
111
+ else:
112
+ response = requests.post(f"{self.server_url}/generate", json=payload)
113
+ return response.json().get("generated_text", "")
114
+ else:
115
+ # Local llama.cpp inference
116
+ if self.streaming:
117
+ stream = self.model.create_completion(
118
+ prompt=prompt,
119
+ max_tokens=self.max_new_tokens,
120
+ temperature=self.temperature,
121
+ stream=True,
122
+ stop=stop_tokens
123
+ )
124
+ generated_text = ""
125
+ for token_chunk in stream:
126
+ token_text = token_chunk["choices"][0]["text"]
127
+ # Stop early if we detect end token
128
+ if any(stop in token_text for stop in stop_tokens):
129
+ break
130
+ generated_text += token_text
131
+ if self.callbacks:
132
+ for callback in self.callbacks:
133
+ callback.on_llm_new_token(token_text)
134
+ return generated_text
135
+ else:
136
+ result = self.model.create_completion(
137
+ prompt=prompt,
138
+ max_tokens=self.max_new_tokens,
139
+ temperature=self.temperature,
140
+ stop=stop_tokens
141
+ )
142
+ return result["choices"][0]["text"]
143
+
144
+ def invoke(self, messages: list, config: dict = None) -> AIMessage:
145
+ config = config or {}
146
+ callbacks = config.get("callbacks", self.callbacks)
147
+ original_callbacks = self.callbacks
148
+ self.callbacks = callbacks
149
+
150
+ output_text = self.generate_text(messages)
151
+ self.callbacks = original_callbacks
152
+
153
+ # In streaming mode we return an empty content as tokens are being sent via callbacks.
154
+ if self.streaming:
155
+ return AIMessage(content="")
156
+ else:
157
+ return AIMessage(content=output_text)
158
+
159
+ def __call__(self, messages: list) -> AIMessage:
160
+ return self.invoke(messages)
161
+
162
+ # ------------------------ Callback for WebSocket Streaming ------------------------
163
+
164
+ class WebSocketStreamingCallbackHandler(BaseCallbackHandler):
165
+ def __init__(self, connection_id: str, loop):
166
+ self.connection_id = connection_id
167
+ self.loop = loop
168
+
169
+ def on_llm_new_token(self, token: str, **kwargs):
170
+ asyncio.run_coroutine_threadsafe(
171
+ manager_socket.send_message(self.connection_id, token),
172
+ self.loop
173
+ )
174
+
175
+ # ------------------------ Instantiate the LLM ------------------------
176
+ # Choose one mode: local (set use_server=False) or server (set use_server=True).
177
+ model_path="qwen2.5-1.5b-instruct-q4_k_m.gguf"
178
+ llm = ChatQwen(
179
+ temperature=0.3,
180
+ streaming=True,
181
+ max_new_tokens=512,
182
+ use_server=False,
183
+ model_path=model_path,
184
+ # server_url="http://localhost:8000" # Uncomment and set if using server mode.
185
+ )
186
+
187
+ # ------------------------ FAISS and Sentence Transformer Setup ------------------------
188
+
189
+ index = faiss.read_index("./faiss.index")
190
+ with open("./metadata.pkl", "rb") as f:
191
+ docs = pickle.load(f)
192
+ st_model = SentenceTransformer('all-MiniLM-L6-v2')
193
+
194
+ def make_system_prompt(suffix: str) -> str:
195
+ return (
196
+ "You are EstateGuru, a real estate expert developed by Abhishek Pathak at SwavishTek. "
197
+ "Your role is to help customers buy properties using only the provided dataβ€”do not invent any details. "
198
+ "The default currency is AED; if a query mentions another currency, convert the amount to AED "
199
+ "(for example, convert $10k to 36726.50 AED and $1 to 3.67 AED). "
200
+ "If a customer is interested in a property or needs to contact an agent, instruct them to call +91 8766268285. "
201
+ "Keep your answers short, clear, and concise."
202
+ f"\n{suffix}"
203
+ )
204
+
205
+ general_query_prompt = make_system_prompt(
206
+ "You are EstateGuru, a helpful real estate assistant. "
207
+ "Please respond only in English. "
208
+ "Convert any prices to USD before answering. "
209
+ "Provide a brief, direct answer without extra details."
210
+ )
211
+
212
+ # ------------------------ Tool Definitions ------------------------
213
+
214
+ @tool
215
+ def extract_filters(query: str) -> dict:
216
+ """Extract filters from the query."""
217
+ llm_local = ChatQwen(temperature=0.3, streaming=False, use_server=False, model_path=model_path)
218
+ system = (
219
+ "You are an expert in extracting filters from property-related queries. Your task is to extract and return only the keys explicitly mentioned in the query as a valid JSON object (starting with '{' and ending with '}'). Include only those keys that are directly present in the query.\n\n"
220
+ "The possible keys are:\n"
221
+ " - 'projectName': The name of the project.\n"
222
+ " - 'developerName': The developer's name.\n"
223
+ " - 'relationshipManager': The relationship manager.\n"
224
+ " - 'propertyAddress': The property address.\n"
225
+ " - 'surroundingArea': The area or nearby landmarks.\n"
226
+ " - 'propertyType': The type or configuration of the property.\n"
227
+ " - 'amenities': Any amenities mentioned.\n"
228
+ " - 'coveredParking': Parking availability.\n"
229
+ " - 'petRules': Pet policies.\n"
230
+ " - 'security': Security details.\n"
231
+ " - 'occupancyRate': Occupancy information.\n"
232
+ " - 'constructionImpact': Construction or its impact.\n"
233
+ " - 'propertySize': Size of the property.\n"
234
+ " - 'propertyView': View details.\n"
235
+ " - 'propertyCondition': Condition of the property.\n"
236
+ " - 'serviceCharges': Service or maintenance charges.\n"
237
+ " - 'ownershipType': Ownership type.\n"
238
+ " - 'totalCosts': A cost threshold or cost amount.\n"
239
+ " - 'paymentPlans': Payment or financing plans.\n"
240
+ " - 'expectedRentalYield': Expected rental yield.\n"
241
+ " - 'rentalHistory': Rental history.\n"
242
+ " - 'shortTermRentals': Short-term rental information.\n"
243
+ " - 'resalePotential': Resale potential.\n"
244
+ " - 'uniqueId': A unique identifier.\n\n"
245
+ "Important instructions regarding cost thresholds:\n"
246
+ " - If the query contains phrases like 'under 10k', 'below 2m', or 'less than 5k', interpret these as cost thresholds.\n"
247
+ " - Convert any shorthand cost values to pure numbers (for example, '10k' becomes 10000, '2m' becomes 2000000) and assign them to the key 'totalCosts'.\n"
248
+ " - Do not use 'propertySize' for cost thresholds.\n\n"
249
+ " - Default currency is AED, if user query have different currency symbol then convert to equivalent AED amount (eg. $10k becomes 36726.50, $1 becomes 3.67).\n\n"
250
+ "Example:\n"
251
+ " For the query: \"properties near dubai mall under 43k\"\n"
252
+ " The expected output should be:\n"
253
+ " { \"surroundingArea\": \"dubai mall\", \"totalCosts\": 43000 }\n\n"
254
+ "Return ONLY a valid JSON object with the extracted keys and their corresponding values, with no additional text."
255
+ )
256
+
257
+ human_str = f"Here is the query:\n{query}"
258
+ filter_prompt = [
259
+ {"role": "system", "content": system},
260
+ {"role": "user", "content": human_str},
261
+ ]
262
+ response = llm_local.invoke(messages=filter_prompt)
263
+ response_text = response.content if isinstance(response, AIMessage) else str(response)
264
+ try:
265
+ model_filters = extract_json_from_response(response_text)
266
+ except Exception as e:
267
+ print(f"JSON parsing error: {e}")
268
+ model_filters = {}
269
+ rule_filters = rule_based_extract(query)
270
+ print("Rule-based extraction:", rule_filters)
271
+ final_filters = {**model_filters, **rule_filters}
272
+ print("Final extraction:", final_filters)
273
+ return {"filters": final_filters}
274
+
275
+
276
+ @tool
277
+ def determine_route(query: str) -> dict:
278
+ """Determine the route (search, suggest, detail, general, out_of_domain) for the query."""
279
+ real_estate_keywords = estateKeywords
280
+ pattern = re.compile("|".join(re.escape(keyword) for keyword in real_estate_keywords), re.IGNORECASE)
281
+ positive_signal = bool(pattern.search(query))
282
+
283
+ llm_local = ChatQwen(temperature=0.3, streaming=False, use_server=False, model_path=model_path)
284
+ transform_suggest_to_list = query.lower().replace("suggest ", "list ", -1)
285
+ system = """
286
+ Classify the user query as:
287
+
288
+ - **"search"**: if it requests property listings with specific filters (e.g., location, price, property type like "2bhk", service charges, pet policies, etc.).
289
+ - **"suggest"**: if it asks for property suggestions without filters.
290
+ - **"detail"**: if it is asking for more information about a previously provided property (for example, "tell me more about property 5" or "I want more information regarding 4BHK").
291
+ - **"general"**: for all other real estate-related questions.
292
+ - **"out_of_domain"**: if the query is not related to real estate (for example, tourist attractions, restaurants, etc.).
293
+
294
+ Keep in mind that queries mentioning terms like "service charge", "allow pets", "pet rules", etc., are considered real estate queries.
295
+ When user asks about you (for example, "who you are", "who made you" etc.) consider as general.
296
+
297
+ Return only the keyword: search, suggest, detail, general, or out_of_domain.
298
+ """
299
+ human_str = f"Here is the query:\n{transform_suggest_to_list}"
300
+ router_prompt = [
301
+ {"role": "system", "content": system},
302
+ {"role": "user", "content": human_str},
303
+ ]
304
+
305
+ response = llm_local.invoke(messages=router_prompt)
306
+ response_text = response.content if isinstance(response, AIMessage) else str(response)
307
+ route_value = str(response_text).strip().lower()
308
+
309
+ # --- NEW: Force 'detail' if query explicitly mentions a specific property (e.g., "property 2") ---
310
+ property_detail_pattern = re.compile(r"property\s+\d+", re.IGNORECASE)
311
+ if property_detail_pattern.search(query):
312
+ route_value = "detail"
313
+
314
+ # Fallback override if query appears detailed.
315
+ detail_phrases = [
316
+ "more information", "tell me more", "more details", "give me more details",
317
+ "i need more details", "can you provide more details", "additional details",
318
+ "further information", "expand on that", "explain further", "elaborate more",
319
+ "more specifics", "i want to know more", "could you elaborate", "need more info",
320
+ "provide more details", "detail it further", "in-depth information", "break it down further",
321
+ "further explanation", "property 1", "property1", "first property", "about the 2nd", "regarding number 3"
322
+ ]
323
+ if any(phrase in query.lower() for phrase in detail_phrases):
324
+ route_value = "detail"
325
+
326
+ if route_value not in {"search", "suggest", "detail", "general", "out_of_domain"}:
327
+ route_value = "general"
328
+ if route_value == "out_of_domain" and positive_signal:
329
+ route_value = "general"
330
+ if route_value == "out_of_domain":
331
+ route_value = "general" if positive_signal else "out_of_domain"
332
+
333
+ return {"route": route_value}
334
+
335
+ # ------------------------ Workflow Setup ------------------------
336
+
337
+ workflow = StateGraph(state_schema=dict)
338
+
339
+ def route_query(state: dict) -> dict:
340
+ new_state = state.copy()
341
+ try:
342
+ new_state["route"] = determine_route.invoke(new_state.get("query", "")).get("route", "general")
343
+ print(new_state["route"])
344
+ except Exception as e:
345
+ print(f"Routing error: {e}")
346
+ new_state["route"] = "general"
347
+ return new_state
348
+
349
+ def hybrid_extract(state: dict) -> dict:
350
+ new_state = state.copy()
351
+ new_state["filters"] = extract_filters.invoke(new_state.get("query", "")).get("filters", {})
352
+ return new_state
353
+
354
+ def search_faiss(state: dict) -> dict:
355
+ new_state = state.copy()
356
+ # Preserve previous properties until new ones are fetched:
357
+ new_state.setdefault("current_properties", state.get("current_properties", []))
358
+ query_embedding = st_model.encode([state["query"]])
359
+ _, indices = index.search(query_embedding.astype(np.float32), 5)
360
+ new_state["faiss_results"] = [docs[idx] for idx in indices[0] if idx < len(docs)]
361
+ return new_state
362
+
363
+ def apply_filters(state: dict) -> dict:
364
+ new_state = state.copy()
365
+ new_state["final_results"] = apply_filters_partial(state["faiss_results"], state.get("filters", {}))
366
+ if(len(new_state["final_results"]) == 0):
367
+ new_state["response"] = "Sorry, There is no result found :("
368
+ new_state["route"] = "general"
369
+ return new_state
370
+
371
+ def suggest_properties(state: dict) -> dict:
372
+ new_state = state.copy()
373
+ new_state["suggestions"] = random.sample(docs, 5)
374
+ # Explicitly update current_properties only when new listings are fetched
375
+ new_state["current_properties"] = new_state["suggestions"]
376
+ if(len(new_state["suggestions"]) == 0):
377
+ new_state["response"] = "Sorry, There is no result found :("
378
+ new_state["route"] = "general"
379
+ return new_state
380
+
381
+ def handle_out_of_domain(state: dict) -> dict:
382
+ new_state = state.copy()
383
+ new_state["response"] = "I only handle real estate inquiries. Please ask a question related to properties."
384
+ return new_state
385
+
386
+
387
+
388
+ def generate_response(state: dict) -> dict:
389
+ new_state = state.copy()
390
+ messages = []
391
+
392
+ # Add the general query prompt.
393
+ messages.append({"role": "system", "content": general_query_prompt})
394
+
395
+ # For detail queries (specific property queries), add extra instructions.
396
+ if new_state.get("route", "general") == "detail":
397
+ messages.append({
398
+ "role": "system",
399
+ "content": (
400
+ "The user is asking about a specific property from the numbered list below. "
401
+ "Properties are listed as 1, 2, 3, etc. Use ONLY the corresponding property details. "
402
+ "For example, if the user says 'property 2', respond using only the details from the second entry. Never invent data."
403
+ )
404
+ })
405
+
406
+ if new_state.get("current_properties"):
407
+ # Format properties with indices starting at 1
408
+ property_context = format_property_data_with_indices(new_state["current_properties"])
409
+ messages.append({"role": "system", "content": "Available Properties:\n" + property_context})
410
+ messages.append({"role": "system", "content": "When responding, use only the provided property details."})
411
+
412
+ # Add conversation history
413
+ # Truncate conversation history (last 6 exchanges)
414
+ truncated_history = state.get("messages", [])[-12:] # Last 6 user+assistant pairs
415
+ for msg in truncated_history:
416
+ messages.append({"role": msg["role"], "content": msg["content"]})
417
+
418
+ connection_id = state.get("connection_id")
419
+ loop = state.get("loop")
420
+ if connection_id and loop:
421
+ print("Using WebSocket streaming")
422
+ callback_manager = [WebSocketStreamingCallbackHandler(connection_id, loop)]
423
+ _ = llm.invoke(
424
+ messages,
425
+ config={"callbacks": callback_manager}
426
+ )
427
+ new_state["response"] = ""
428
+ else:
429
+ callback_manager = [StreamingStdOutCallbackHandler()]
430
+ response = llm.invoke(
431
+ messages,
432
+ config={"callbacks": callback_manager}
433
+ )
434
+ new_state["response"] = response.content if isinstance(response, AIMessage) else str(response)
435
+
436
+ return new_state
437
+
438
+
439
+ def format_property_data_with_indices(properties: list) -> str:
440
+ formatted = []
441
+ for idx, prop in enumerate(properties, 1):
442
+ cost = prop.get("totalCosts", "N/A")
443
+ cost_str = f"{cost:,}" if isinstance(cost, (int, float)) else cost
444
+ formatted.append(
445
+ f"{idx}. Type: {prop['propertyType']}, Cost: AED {cost_str}, "
446
+ f"Size: {prop.get('propertySize', 'N/A')}, Amenities: {', '.join(prop.get('amenities', []))}, "
447
+ f"Rental Yield: {prop.get('expectedRentalYield', 'N/A')}, "
448
+ f"Ownership: {prop.get('ownershipType', 'N/A')}"
449
+ )
450
+ return "\n".join(formatted)
451
+
452
+
453
+ def format_final_response(state: dict) -> dict:
454
+ new_state = state.copy()
455
+
456
+ if state.get("route") in ["search", "suggest"]:
457
+ if "final_results" in state:
458
+ new_state["current_properties"] = state["final_results"]
459
+ elif "suggestions" in state:
460
+ new_state["current_properties"] = state["suggestions"]
461
+ elif "current_properties" in new_state:
462
+ new_state["current_properties"] = state["current_properties"]
463
+
464
+
465
+ # print("state: ", json.dumps(new_state), "\n\n")
466
+ # Format the property details if available.
467
+ # if new_state.get("current_properties"):
468
+ if state.get("route") in ["search", "suggest"] and new_state.get("current_properties"):
469
+ formatted = structured_property_data(state=new_state)
470
+
471
+ # for idx, prop in enumerate(new_state["current_properties"], 1):
472
+ # cost = prop.get("totalCosts", "N/A")
473
+ # cost_str = f"{cost:,}" if isinstance(cost, (int, float)) else cost
474
+ # formatted.append(
475
+ # f"{idx}. Type: {prop['propertyType']}, Cost: AED {cost_str}, "
476
+ # f"Size: {prop.get('propertySize', 'N/A')}, Amenities: {', '.join(map(str, prop.get('amenities', []))) if prop.get('amenities') else 'N/A'}, "
477
+ # f"Rental Yield: {prop.get('expectedRentalYield', 'N/A')}, "
478
+ # f"Ownership: {prop.get('ownershipType', 'N/A')}\n"
479
+ # )
480
+ aggregated_response = "Here are the property details:\n" + "\n".join(formatted)
481
+ # print(aggregated_response)
482
+
483
+ connection_id = state.get("connection_id")
484
+ loop = state.get("loop")
485
+ if connection_id and loop:
486
+ import time
487
+ tokens = aggregated_response.split(" ")
488
+ for token in tokens:
489
+ asyncio.run_coroutine_threadsafe(
490
+ manager_socket.send_message(connection_id, token + " "),
491
+ loop
492
+ )
493
+ time.sleep(0.05)
494
+ new_state["response"] = ""
495
+ else:
496
+ new_state["response"] = aggregated_response
497
+ elif "response" in new_state:
498
+ connection_id = state.get("connection_id")
499
+ loop = state.get("loop")
500
+ if connection_id and loop:
501
+ import time
502
+ tokens = str(new_state["response"]).split(" ")
503
+ for token in tokens:
504
+ asyncio.run_coroutine_threadsafe(
505
+ manager_socket.send_message(connection_id, token + " "),
506
+ loop
507
+ )
508
+ time.sleep(0.05)
509
+ new_state["response"] = str(new_state["response"])
510
+
511
+ return new_state
512
+
513
+
514
+
515
+ nodes = [
516
+ ("route_query", route_query),
517
+ ("hybrid_extract", hybrid_extract),
518
+ ("faiss_search", search_faiss),
519
+ ("apply_filters", apply_filters),
520
+ ("suggest_properties", suggest_properties),
521
+ ("handle_out_of_domain", handle_out_of_domain),
522
+ ("generate_response", generate_response),
523
+ ("format_response", format_final_response)
524
+ ]
525
+
526
+ for name, node in nodes:
527
+ workflow.add_node(name, node)
528
+
529
+ workflow.add_edge(START, "route_query")
530
+ workflow.add_conditional_edges(
531
+ "route_query",
532
+ lambda state: state.get("route", "general"),
533
+ {
534
+ "search": "hybrid_extract",
535
+ "suggest": "suggest_properties",
536
+ "detail": "generate_response",
537
+ "general": "generate_response",
538
+ "out_of_domain": "handle_out_of_domain"
539
+ }
540
+ )
541
+ workflow.add_edge("hybrid_extract", "faiss_search")
542
+ workflow.add_edge("faiss_search", "apply_filters")
543
+ workflow.add_edge("apply_filters", "format_response")
544
+ workflow.add_edge("suggest_properties", "format_response")
545
+ workflow.add_edge("generate_response", "format_response")
546
+ workflow.add_edge("handle_out_of_domain", "format_response")
547
+ workflow.add_edge("format_response", END)
548
+
549
+ workflow_app = workflow.compile()
550
+
551
+ # ------------------------ Conversation Manager ------------------------
552
+
553
+ class ConversationManager:
554
+ def __init__(self):
555
+ # Each connection gets its own conversation history and state.
556
+ self.conversation_history = []
557
+ # current_properties stores the current property listing.
558
+ self.current_properties = []
559
+
560
+ def _add_message(self, role: str, content: str):
561
+ self.conversation_history.append({
562
+ "role": role,
563
+ "content": content,
564
+ "timestamp": datetime.now().isoformat()
565
+ })
566
+
567
+ def process_query(self, query: str) -> str:
568
+ # For greeting messages, reset history/state. // post request
569
+ if query.strip().lower() in {"hi", "hello", "hey"}:
570
+ self.conversation_history = []
571
+ self.current_properties = []
572
+ greeting_response = "Hello! How can I assist you today with your real estate inquiries?"
573
+ self._add_message("assistant", greeting_response)
574
+ return greeting_response
575
+
576
+ try:
577
+ self._add_message("user", query)
578
+ initial_state = {
579
+ "messages": self.conversation_history.copy(),
580
+ "query": query,
581
+ "route": "general",
582
+ "filters": {},
583
+ "current_properties": self.current_properties
584
+ }
585
+ for event in workflow_app.stream(initial_state, stream_mode="values"):
586
+ final_state = event
587
+ # Only update property listings if a new listing is fetched
588
+ # if 'final_results' in final_state:
589
+ # self.current_properties = final_state['final_results']
590
+ # elif 'suggestions' in final_state:
591
+ # self.current_properties = final_state['suggestions']
592
+ self.current_properties = final_state.get("current_properties", [])
593
+
594
+ if final_state.get("route") == "general":
595
+ response_text = final_state.get("response", "")
596
+ self._add_message("assistant", response_text)
597
+ return response_text
598
+ else:
599
+ response = final_state.get("response", "I couldn't process that request.")
600
+ self._add_message("assistant", response)
601
+ return response
602
+ except Exception as e:
603
+ print(f"Processing error: {e}")
604
+ return "Sorry, I encountered an error processing your request."
605
+
606
+
607
+
608
+ conversation_managers = {}
609
+
610
+ # ------------------------ FastAPI Backend with WebSockets ------------------------
611
+
612
+ app = FastAPI()
613
+
614
+ class ConnectionManager:
615
+ def __init__(self):
616
+ self.active_connections = {}
617
+
618
+ async def connect(self, websocket: WebSocket):
619
+ await websocket.accept()
620
+ connection_id = str(uuid.uuid4())
621
+ self.active_connections[connection_id] = websocket
622
+ print(f"New connection: {connection_id}")
623
+ return connection_id
624
+
625
+ def disconnect(self, connection_id: str):
626
+ if connection_id in self.active_connections:
627
+ del self.active_connections[connection_id]
628
+ print(f"Disconnected: {connection_id}")
629
+
630
+ async def send_message(self, connection_id: str, message: str):
631
+ websocket = self.active_connections.get(connection_id)
632
+ if websocket:
633
+ await websocket.send_text(message)
634
+
635
+ manager_socket = ConnectionManager()
636
+
637
+ def stream_query(query: str, connection_id: str, loop):
638
+ conv_manager = conversation_managers.get(connection_id)
639
+ if conv_manager is None:
640
+ print(f"No conversation manager found for connection {connection_id}")
641
+ return
642
+
643
+ if query.strip().lower() in {"hi", "hello", "hey"}:
644
+ conv_manager.conversation_history = []
645
+ conv_manager.current_properties = []
646
+ greeting_response = "Hello! How can I assist you today with your real estate inquiries?"
647
+ conv_manager._add_message("assistant", greeting_response)
648
+ sendTokenViaSocket(
649
+ state={"connection_id": connection_id, "loop": loop},
650
+ manager_socket=manager_socket,
651
+ message=greeting_response
652
+ )
653
+ # asyncio.run_coroutine_threadsafe(
654
+ # manager_socket.send_message(connection_id, greeting_response),
655
+ # loop
656
+ # )
657
+ return
658
+
659
+ conv_manager._add_message("user", query)
660
+ initial_state = {
661
+ "messages": conv_manager.conversation_history.copy(),
662
+ "query": query,
663
+ "route": "general",
664
+ "filters": {},
665
+ "current_properties": conv_manager.current_properties,
666
+ "connection_id": connection_id,
667
+ "loop": loop
668
+ }
669
+ # try:
670
+ # workflow_app.invoke(initial_state)
671
+ # except Exception as e:
672
+ # error_msg = f"Error processing query: {str(e)}"
673
+ # asyncio.run_coroutine_threadsafe(
674
+ # manager_socket.send_message(connection_id, error_msg),
675
+ # loop
676
+ # )
677
+ try:
678
+ # Capture all states during execution
679
+ # final_state = None
680
+ # for event in workflow_app.stream(initial_state, stream_mode="values"):
681
+ # final_state = event
682
+
683
+ # # Update conversation manager with final state
684
+ # if final_state:
685
+ # conv_manager.current_properties = final_state.get("current_properties", [])
686
+ # if final_state.get("response"):
687
+ # conv_manager._add_message("assistant", final_state["response"])
688
+ final_state = None
689
+ for event in workflow_app.stream(initial_state, stream_mode="values"):
690
+ final_state = event
691
+
692
+ if final_state:
693
+ # Always update current_properties from final state
694
+ conv_manager.current_properties = final_state.get("current_properties", [])
695
+ # Keep conversation history bounded
696
+ conv_manager.conversation_history = conv_manager.conversation_history[-12:] # Last 6 exchanges
697
+
698
+ except Exception as e:
699
+ error_msg = f"Error processing query: {str(e)}"
700
+ asyncio.run_coroutine_threadsafe(
701
+ manager_socket.send_message(connection_id, error_msg),
702
+ loop
703
+ )
704
+
705
+
706
+
707
+ @app.websocket("/ws")
708
+ async def websocket_endpoint(websocket: WebSocket):
709
+ connection_id = await manager_socket.connect(websocket)
710
+ # Each connection maintains its own conversation manager.
711
+ conversation_managers[connection_id] = ConversationManager()
712
+ try:
713
+ while True:
714
+ query = await websocket.receive_text()
715
+ loop = asyncio.get_event_loop()
716
+ threading.Thread(
717
+ target=stream_query,
718
+ args=(query, connection_id, loop),
719
+ daemon=True
720
+ ).start()
721
+ except WebSocketDisconnect:
722
+ conv_manager = conversation_managers.get(connection_id)
723
+ if conv_manager:
724
+ filename = f"conversations/conversation_{connection_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
725
+ with open(filename, "w") as f:
726
+ json.dump(conv_manager.conversation_history, f, indent=4)
727
+ del conversation_managers[connection_id]
728
+ manager_socket.disconnect(connection_id)
729
+
730
+
731
+
732
+ @app.post("/query")
733
+ async def post_query(query: str):
734
+ conv_manager = ConversationManager()
735
+ response = conv_manager.process_query(query)
736
+ return {"response": response}
737
+
738
+
739
+
740
+
741
+ model_url = "https://huggingface.co/Qwen/Qwen2.5-1.5B-Instruct-GGUF/resolve/main/qwen2.5-1.5b-instruct-q4_k_m.gguf"
742
+ async def async_download():
743
+ import aiohttp
744
+ async with aiohttp.ClientSession() as session:
745
+ async with session.get(model_url) as response:
746
+ with open(model_path, "wb") as f:
747
+ while True:
748
+ chunk = await response.content.read(1024)
749
+ if not chunk:
750
+ break
751
+ f.write(chunk)
752
+
753
+ @app.middleware("http")
754
+ async def check_model_middleware(request: Request, call_next):
755
+ if not os.path.exists(model_path):
756
+ await async_download()
757
+ print("successfully downloaded")
758
+ else:
759
+ print("already downloaded")
760
+ return await call_next(request)
761
+
762
+
763
+ @app.get("/")
764
+ async def home():
765
+ return PlainTextResponse("Space is running. Model ready!")
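A note on the download middleware kept in this backup: the GGUF file is fetched inside check_model_middleware, so the first HTTP request blocks until the download completes. A minimal sketch of running the same check once at startup instead (an assumption, reusing model_path and async_download() defined above; not part of the committed file):

# Hypothetical startup hook, not in the commit.
@app.on_event("startup")
async def ensure_model_downloaded():
    if not os.path.exists(model_path):
        await async_download()
        print("model downloaded")
    else:
        print("model already present")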
test.py → backup/test.py RENAMED
File without changes
test2.py → backup/test2.py RENAMED
File without changes
test3.py → backup/test3.py RENAMED
File without changes
index.html CHANGED
@@ -49,6 +49,7 @@
49
  <script>
50
  // Create a WebSocket connection to your backend
51
  const ws = new WebSocket("ws://localhost:8000/ws");
 
52
 
53
  // This variable holds the current assistant message element for live updating.
54
  let currentAssistantMessageEl = null;
 
49
  <script>
50
  // Create a WebSocket connection to your backend
51
  const ws = new WebSocket("ws://localhost:8000/ws");
52
+ // const ws = new WebSocket("wss://pathakdev10-estateguru.hf.space/ws");
53
 
54
  // This variable holds the current assistant message element for live updating.
55
  let currentAssistantMessageEl = null;