import uuid
import threading
from threading import Thread
import asyncio
import json
import os
import re
import time
import random
import pickle
from datetime import datetime

import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from langchain_core.messages import AIMessage
from langchain_core.tools import tool
from langchain_core.callbacks import StreamingStdOutCallbackHandler
from langchain_core.callbacks.base import BaseCallbackHandler
from langgraph.graph import StateGraph, START, END
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer, TextIteratorStreamer

from tools import extract_json_from_response, apply_filters_partial, rule_based_extract, format_property_data, estateKeywords
class CallbackTextStreamer(TextStreamer):
    """TextStreamer that forwards decoded text chunks to LangChain-style callbacks."""

    def __init__(self, tokenizer, callbacks, skip_prompt=True, skip_special_tokens=True):
        super().__init__(tokenizer, skip_prompt=skip_prompt, skip_special_tokens=skip_special_tokens)
        self.callbacks = callbacks

    def on_finalized_text(self, text: str, stream_end: bool = False):
        # TextStreamer calls on_finalized_text as text is decoded; overriding it
        # (rather than defining an on_new_token method, which the base class
        # never calls) routes each chunk to the callbacks.
        for callback in self.callbacks:
            callback.on_llm_new_token(text)
class ChatQwen:
    def __init__(self, temperature=0.3, streaming=False, max_new_tokens=512, callbacks=None):
        self.temperature = temperature
        self.streaming = streaming
        self.max_new_tokens = max_new_tokens
        self.callbacks = callbacks
        self.model_name = "Qwen/Qwen2.5-1.5B-Instruct"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype="auto",
            device_map="auto"
        )
    def generate_text(self, messages: list) -> str:
        """
        Build a prompt from the message list and generate text with the Qwen model.
        In streaming mode, a TextIteratorStreamer yields tokens as they are
        produced so each one can be forwarded to the registered callbacks.
        """
        # Create the prompt from the messages using the tokenizer's chat template.
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        model_inputs = self.tokenizer([prompt], return_tensors="pt").to(self.model.device)

        if self.streaming:
            # The streamer collects tokens as they are generated.
            streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
            generation_kwargs = dict(
                **model_inputs,
                max_new_tokens=self.max_new_tokens,
                streamer=streamer,
                temperature=self.temperature,
                do_sample=True
            )
            # Run generation in a separate thread so we can iterate over tokens here.
            thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
            thread.start()
            generated_text = ""
            # Iterate over tokens as they arrive and forward each to the callbacks.
            for token in streamer:
                generated_text += token
                if self.callbacks:
                    for callback in self.callbacks:
                        callback.on_llm_new_token(token)
            # The callbacks have already delivered the text token by token; the
            # full string is returned as well in case the caller needs it.
            return generated_text
        else:
            outputs = self.model.generate(
                **model_inputs,
                max_new_tokens=self.max_new_tokens,
                temperature=self.temperature,
                do_sample=True
            )
            # Strip the prompt tokens from the output before decoding.
            prompt_length = model_inputs.input_ids.shape[-1]
            generated_ids = outputs[0][prompt_length:]
            return self.tokenizer.decode(generated_ids, skip_special_tokens=True)
    def invoke(self, messages: list, config: dict = None) -> AIMessage:
        config = config or {}
        # Prefer callbacks passed in the config; fall back to the instance callbacks.
        callbacks = config.get("callbacks", self.callbacks)
        original_callbacks = self.callbacks
        self.callbacks = callbacks
        output_text = self.generate_text(messages)
        self.callbacks = original_callbacks
        if self.streaming:
            # The text has already been streamed via the callbacks.
            return AIMessage(content="")
        return AIMessage(content=output_text)

    def __call__(self, messages: list) -> AIMessage:
        return self.invoke(messages)
class WebSocketStreamingCallbackHandler(BaseCallbackHandler):
    """Forwards each generated token to the client over its WebSocket connection."""

    def __init__(self, connection_id: str, loop):
        self.connection_id = connection_id
        self.loop = loop

    def on_llm_new_token(self, token: str, **kwargs):
        # Generation runs in a worker thread, so schedule the send on the event loop.
        asyncio.run_coroutine_threadsafe(
            manager_socket.send_message(self.connection_id, token),
            self.loop
        )
llm = ChatQwen(temperature=0.3, streaming=True, max_new_tokens=512)

index = faiss.read_index("./faiss.index")
with open("./metadata.pkl", "rb") as f:
    docs = pickle.load(f)
st_model = SentenceTransformer('all-MiniLM-L6-v2')
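# For reference, a minimal sketch of how ./faiss.index and ./metadata.pkl could be
# built offline. This is an assumption about the data pipeline, not part of this
# service; `all_property_docs` is a hypothetical list of property dicts.
#
#   texts = [str(doc) for doc in all_property_docs]
#   embeddings = st_model.encode(texts).astype(np.float32)
#   idx = faiss.IndexFlatL2(embeddings.shape[1])
#   idx.add(embeddings)
#   faiss.write_index(idx, "./faiss.index")
#   with open("./metadata.pkl", "wb") as f:
#       pickle.dump(all_property_docs, f)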
def make_system_prompt(suffix: str) -> str:
    return (
        "You are EstateGuru, a real estate expert created by Abhishek Pathak from SwavishTek. "
        "Your role is to help customers buy properties using the available data. "
        "Only use the provided data—do not make up any information. "
        "The default currency is AED. If a query uses a different currency, convert the amount to AED "
        "(for example, $10k becomes 36726.50 AED and $1 becomes 3.67 AED). "
        "If a customer is interested in a property, wants to buy, or needs to contact an agent or customer care, "
        "instruct them to call +91 8766268285."
        f"\n{suffix}"
    )

general_query_prompt = make_system_prompt(
    "You are EstateGuru, a helpful real estate assistant. Answer the user's query accurately using the available data. "
    "Do not invent any details or go beyond the real estate domain. "
    "If the user shows interest in a property or contacting an agent, ask them to call +91 8766268285."
)
# ------------------------ Tool Definitions ------------------------

@tool
def extract_filters(query: str) -> dict:
    """Extract search filters from a property-related query."""
    # Use a non-streaming ChatQwen instance for tool calls.
    llm_local = ChatQwen(temperature=0.3, streaming=False)
    system = (
        "You are an expert in extracting filters from property-related queries. Your task is to extract and return only the keys explicitly mentioned in the query as a valid JSON object (starting with '{' and ending with '}'). Include only those keys that are directly present in the query.\n\n"
        "The possible keys are:\n"
        "  - 'projectName': The name of the project.\n"
        "  - 'developerName': The developer's name.\n"
        "  - 'relationshipManager': The relationship manager.\n"
        "  - 'propertyAddress': The property address.\n"
        "  - 'surroundingArea': The area or nearby landmarks.\n"
        "  - 'propertyType': The type or configuration of the property.\n"
        "  - 'amenities': Any amenities mentioned.\n"
        "  - 'coveredParking': Parking availability.\n"
        "  - 'petRules': Pet policies.\n"
        "  - 'security': Security details.\n"
        "  - 'occupancyRate': Occupancy information.\n"
        "  - 'constructionImpact': Construction or its impact.\n"
        "  - 'propertySize': Size of the property.\n"
        "  - 'propertyView': View details.\n"
        "  - 'propertyCondition': Condition of the property.\n"
        "  - 'serviceCharges': Service or maintenance charges.\n"
        "  - 'ownershipType': Ownership type.\n"
        "  - 'totalCosts': A cost threshold or cost amount.\n"
        "  - 'paymentPlans': Payment or financing plans.\n"
        "  - 'expectedRentalYield': Expected rental yield.\n"
        "  - 'rentalHistory': Rental history.\n"
        "  - 'shortTermRentals': Short-term rental information.\n"
        "  - 'resalePotential': Resale potential.\n"
        "  - 'uniqueId': A unique identifier.\n\n"
        "Important instructions regarding cost thresholds:\n"
        "  - If the query contains phrases like 'under 10k', 'below 2m', or 'less than 5k', interpret these as cost thresholds.\n"
        "  - Convert any shorthand cost values to pure numbers (for example, '10k' becomes 10000, '2m' becomes 2000000) and assign them to the key 'totalCosts'.\n"
        "  - Do not use 'propertySize' for cost thresholds.\n"
        "  - The default currency is AED. If the query uses a different currency, convert the amount to the equivalent in AED (e.g., $10k becomes 36726.50, $1 becomes 3.67).\n\n"
        "Example:\n"
        "  For the query: \"properties near dubai mall under 43k\"\n"
        "  The expected output should be:\n"
        "  { \"surroundingArea\": \"dubai mall\", \"totalCosts\": 43000 }\n\n"
        "Return ONLY a valid JSON object with the extracted keys and their corresponding values, with no additional text."
    )
    human_str = f"Here is the query:\n{query}"
    filter_prompt = [
        {"role": "system", "content": system},
        {"role": "user", "content": human_str},
    ]
    response = llm_local.invoke(messages=filter_prompt)
    response_text = response.content if isinstance(response, AIMessage) else str(response)
    try:
        model_filters = extract_json_from_response(response_text)
    except Exception as e:
        print(f"JSON parsing error: {e}")
        model_filters = {}
    # Merge model output with rule-based extraction; rule-based values win on conflict.
    rule_filters = rule_based_extract(query)
    print("Rule-based extraction:", rule_filters)
    final_filters = {**model_filters, **rule_filters}
    print("Final extraction:", final_filters)
    return {"filters": final_filters}
@tool
def determine_route(query: str) -> dict:
    """Determine the query route using an LLM classifier plus keyword fallbacks."""
    # Keywords that are strong positive signals of a real estate query.
    real_estate_keywords = estateKeywords
    pattern = re.compile("|".join(re.escape(keyword) for keyword in real_estate_keywords), re.IGNORECASE)
    positive_signal = bool(pattern.search(query))

    # Classify with the LLM regardless; the positive signal is used as a fallback.
    llm_local = ChatQwen(temperature=0.3, streaming=False)
    transform_suggest_to_list = query.lower().replace("suggest ", "list ")
    system = """
    Classify the user query as:
      - **"search"**: if it requests property listings with specific filters (e.g., location, price, property type like "2bhk", service charges, pet policies, etc.).
      - **"suggest"**: if it asks for property suggestions without filters.
      - **"detail"**: if it is asking for more information about a previously provided property (for example, "tell me more about property 5" or "I want more information regarding 4BHK").
      - **"general"**: for all other real estate-related questions.
      - **"out_of_domain"**: if the query is not related to real estate (for example, tourist attractions, restaurants, etc.).
    Keep in mind that queries mentioning terms like "service charge", "allow pets", "pet rules", etc., are considered real estate queries.
    Return only the keyword: search, suggest, detail, general, or out_of_domain.
    """
    human_str = f"Here is the query:\n{transform_suggest_to_list}"
    router_prompt = [
        {"role": "system", "content": system},
        {"role": "user", "content": human_str},
    ]
    response = llm_local.invoke(messages=router_prompt)
    response_text = response.content if isinstance(response, AIMessage) else str(response)
    route_value = str(response_text).strip().lower()

    # Fallback: phrases that indicate a detail request override the classifier.
    detail_phrases = [
        "more information",
        "tell me more",
        "more details",
        "give me more details",
        "i need more details",
        "can you provide more details",
        "additional details",
        "further information",
        "expand on that",
        "explain further",
        "elaborate more",
        "more specifics",
        "i want to know more",
        "could you elaborate",
        "need more info",
        "provide more details",
        "detail it further",
        "in-depth information",
        "break it down further",
        "further explanation"
    ]
    if any(phrase in query.lower() for phrase in detail_phrases):
        route_value = "detail"

    if route_value not in {"search", "suggest", "detail", "general", "out_of_domain"}:
        route_value = "general"
    # If the classifier said out_of_domain but the query contains real estate
    # keywords, treat it as a general real estate question.
    if route_value == "out_of_domain" and positive_signal:
        route_value = "general"
    return {"route": route_value}
# ------------------------ Workflow Setup ------------------------
workflow = StateGraph(state_schema=dict)

def route_query(state: dict) -> dict:
    new_state = state.copy()
    try:
        new_state["route"] = determine_route.invoke(new_state.get("query", "")).get("route", "general")
        print(new_state["route"])
    except Exception as e:
        print(f"Routing error: {e}")
        new_state["route"] = "general"
    return new_state

def hybrid_extract(state: dict) -> dict:
    new_state = state.copy()
    new_state["filters"] = extract_filters.invoke(new_state.get("query", "")).get("filters", {})
    return new_state

def search_faiss(state: dict) -> dict:
    new_state = state.copy()
    query_embedding = st_model.encode([state["query"]])
    _, indices = index.search(query_embedding.astype(np.float32), 5)
    new_state["faiss_results"] = [docs[idx] for idx in indices[0] if idx < len(docs)]
    return new_state

def apply_filters(state: dict) -> dict:
    new_state = state.copy()
    new_state["final_results"] = apply_filters_partial(state["faiss_results"], state.get("filters", {}))
    return new_state
def suggest_properties(state: dict) -> dict:
    new_state = state.copy()
    # Guard against the corpus holding fewer than 5 documents.
    new_state["suggestions"] = random.sample(docs, min(5, len(docs)))
    return new_state
def handle_out_of_domain(state: dict) -> dict:
    new_state = state.copy()
    new_state["response"] = "I only handle real estate inquiries. Please ask a question related to properties."
    return new_state
def generate_response(state: dict) -> dict:
    new_state = state.copy()
    messages = []

    # Add the general query prompt.
    messages.append({"role": "system", "content": general_query_prompt})

    # For a detail query, add a system message that forces a detailed answer.
    if new_state.get("route", "general") == "detail":
        messages.append({
            "role": "system",
            "content": (
                "This is a detail query. Please provide detailed information about the property below. "
                "Do not generate a new list of properties; only use the provided property details to answer the query. "
                "Focus on answering the specific question (for example, whether pets are allowed)."
            )
        })

    # If property details are available, add them without clearing context.
    if new_state.get("current_properties"):
        property_context = format_property_data(new_state["current_properties"])
        messages.append({"role": "system", "content": "Available Property:\n" + property_context})
        # Do NOT clear current_properties here; follow-up detail queries rely on it.
        messages.append({"role": "system", "content": "When responding, use only the provided property details to answer the user's specific question about the property."})

    # Add the conversation history.
    for msg in state.get("messages", []):
        if msg["role"] == "user":
            messages.append({"role": "user", "content": msg["content"]})
        else:
            messages.append({"role": "assistant", "content": msg["content"]})

    # Invoke the LLM. Over a WebSocket, tokens are streamed through the callback
    # handler and the response field is left empty; otherwise stream to stdout
    # and keep the full text.
    connection_id = state.get("connection_id")
    loop = state.get("loop")
    if connection_id and loop:
        callbacks = [WebSocketStreamingCallbackHandler(connection_id, loop)]
        _ = llm.invoke(messages, config={"callbacks": callbacks})
        new_state["response"] = ""
    else:
        callbacks = [StreamingStdOutCallbackHandler()]
        response = llm.invoke(messages, config={"callbacks": callbacks})
        new_state["response"] = response.content if isinstance(response, AIMessage) else str(response)
    return new_state
def format_final_response(state: dict) -> dict:
    new_state = state.copy()

    # Only override current_properties if this is NOT a detail query.
    if state.get("route", "general") != "detail":
        if state.get("route") in ["search", "suggest"]:
            if "final_results" in state:
                new_state["current_properties"] = state["final_results"]
            elif "suggestions" in state:
                new_state["current_properties"] = state["suggestions"]

    # Format the response based on the (possibly filtered) current_properties.
    if new_state.get("current_properties"):
        formatted = []
        for idx, prop in enumerate(new_state["current_properties"], 1):
            cost = prop.get("totalCosts", "N/A")
            cost_str = f"{cost:,}" if isinstance(cost, (int, float)) else cost
            formatted.append(
                f"{idx}. Type: {prop['propertyType']}, Cost: AED {cost_str}, "
                f"Size: {prop.get('propertySize', 'N/A')}, "
                f"Amenities: {', '.join(map(str, prop.get('amenities', []))) if prop.get('amenities') else 'N/A'}, "
                f"Rental Yield: {prop.get('expectedRentalYield', 'N/A')}, "
                f"Ownership: {prop.get('ownershipType', 'N/A')}\n"
            )
        aggregated_response = "Here are the property details:\n" + "\n".join(formatted)

        connection_id = state.get("connection_id")
        loop = state.get("loop")
        if connection_id and loop:
            # Simulate streaming for pre-formatted lists by sending word-sized chunks.
            tokens = aggregated_response.split(" ")
            for token in tokens:
                asyncio.run_coroutine_threadsafe(
                    manager_socket.send_message(connection_id, token + " "),
                    loop
                )
                time.sleep(0.05)
            new_state["response"] = ""
        else:
            new_state["response"] = aggregated_response
    elif "response" in new_state:
        new_state["response"] = str(new_state["response"])
    return new_state
nodes = [
    ("route_query", route_query),
    ("hybrid_extract", hybrid_extract),
    ("faiss_search", search_faiss),
    ("apply_filters", apply_filters),
    ("suggest_properties", suggest_properties),
    ("handle_out_of_domain", handle_out_of_domain),
    ("generate_response", generate_response),
    ("format_response", format_final_response)
]
for name, node in nodes:
    workflow.add_node(name, node)

workflow.add_edge(START, "route_query")
workflow.add_conditional_edges(
    "route_query",
    lambda state: state.get("route", "general"),
    {
        "search": "hybrid_extract",
        "suggest": "suggest_properties",
        "detail": "generate_response",
        "general": "generate_response",
        "out_of_domain": "handle_out_of_domain"
    }
)
workflow.add_edge("hybrid_extract", "faiss_search")
workflow.add_edge("faiss_search", "apply_filters")
workflow.add_edge("apply_filters", "format_response")
workflow.add_edge("suggest_properties", "format_response")
workflow.add_edge("generate_response", "format_response")
workflow.add_edge("handle_out_of_domain", "format_response")
workflow.add_edge("format_response", END)

workflow_app = workflow.compile()
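# A minimal sketch of driving the compiled graph directly (kept commented out so
# the module imports without side effects; the query text is illustrative):
#
#   state = {"messages": [], "query": "2bhk near dubai mall under 2m",
#            "route": "general", "filters": {}, "current_properties": []}
#   final_state = workflow_app.invoke(state)
#   print(final_state.get("response", ""))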
# ------------------------ Conversation Manager ------------------------
class ConversationManager:
    def __init__(self):
        self.conversation_history = []
        self.current_properties = []

    def _add_message(self, role: str, content: str):
        self.conversation_history.append({
            "role": role,
            "content": content,
            "timestamp": datetime.now().isoformat()
        })

    def process_query(self, query: str) -> str:
        # Reset context on greetings to avoid using off-domain history.
        if query.strip().lower() in {"hi", "hello", "hey"}:
            self.conversation_history = []
            self.current_properties = []
            greeting_response = "Hello! How can I assist you today with your real estate inquiries?"
            self._add_message("assistant", greeting_response)
            return greeting_response

        try:
            self._add_message("user", query)
            initial_state = {
                "messages": self.conversation_history.copy(),
                "query": query,
                "route": "general",
                "filters": {},
                "current_properties": self.current_properties
            }
            # Stream the workflow; the last event holds the final state.
            for event in workflow_app.stream(initial_state, stream_mode="values"):
                final_state = event
            if "final_results" in final_state:
                self.current_properties = final_state["final_results"]
            elif "suggestions" in final_state:
                self.current_properties = final_state["suggestions"]

            response = final_state.get("response", "I couldn't process that request.")
            self._add_message("assistant", response)
            return response
        except Exception as e:
            print(f"Processing error: {e}")
            return "Sorry, I encountered an error processing your request."
conversation_managers = {}

# ------------------------ FastAPI Backend with WebSockets ------------------------
app = FastAPI()

class ConnectionManager:
    def __init__(self):
        self.active_connections = {}

    async def connect(self, websocket: WebSocket):
        await websocket.accept()
        connection_id = str(uuid.uuid4())
        self.active_connections[connection_id] = websocket
        print(f"New connection: {connection_id}")
        return connection_id

    def disconnect(self, connection_id: str):
        if connection_id in self.active_connections:
            del self.active_connections[connection_id]
            print(f"Disconnected: {connection_id}")

    async def send_message(self, connection_id: str, message: str):
        websocket = self.active_connections.get(connection_id)
        if websocket:
            await websocket.send_text(message)

manager_socket = ConnectionManager()
def stream_query(query: str, connection_id: str, loop):
    conv_manager = conversation_managers.get(connection_id)
    if conv_manager is None:
        print(f"No conversation manager found for connection {connection_id}")
        return

    # Handle greetings immediately, resetting any prior context.
    if query.strip().lower() in {"hi", "hello", "hey"}:
        conv_manager.conversation_history = []
        conv_manager.current_properties = []
        greeting_response = "Hello! How can I assist you today with your real estate inquiries?"
        conv_manager._add_message("assistant", greeting_response)
        asyncio.run_coroutine_threadsafe(
            manager_socket.send_message(connection_id, greeting_response),
            loop
        )
        return

    conv_manager._add_message("user", query)
    initial_state = {
        "messages": conv_manager.conversation_history.copy(),
        "query": query,
        "route": "general",
        "filters": {},
        "current_properties": conv_manager.current_properties,
        "connection_id": connection_id,
        "loop": loop
    }
    try:
        workflow_app.invoke(initial_state)
    except Exception as e:
        error_msg = f"Error processing query: {str(e)}"
        asyncio.run_coroutine_threadsafe(
            manager_socket.send_message(connection_id, error_msg),
            loop
        )
@app.websocket("/ws")  # route path assumed; adjust to match the client
async def websocket_endpoint(websocket: WebSocket):
    connection_id = await manager_socket.connect(websocket)
    conversation_managers[connection_id] = ConversationManager()
    try:
        while True:
            query = await websocket.receive_text()
            loop = asyncio.get_running_loop()
            # Run the workflow in a worker thread so the event loop stays free
            # to deliver streamed tokens.
            threading.Thread(
                target=stream_query,
                args=(query, connection_id, loop),
                daemon=True
            ).start()
    except WebSocketDisconnect:
        conv_manager = conversation_managers.get(connection_id)
        if conv_manager:
            # Persist the transcript before dropping the manager.
            os.makedirs("conversations", exist_ok=True)
            filename = f"conversations/conversation_{connection_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
            with open(filename, "w") as f:
                json.dump(conv_manager.conversation_history, f, indent=4)
            del conversation_managers[connection_id]
        manager_socket.disconnect(connection_id)
@app.post("/query")  # route path assumed
async def post_query(query: str):
    # Stateless HTTP fallback: each request gets a fresh ConversationManager.
    conv_manager = ConversationManager()
    response = conv_manager.process_query(query)
    return {"response": response}