import uuid
import threading
import asyncio
import json
import re
from datetime import datetime
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
# ------------------------ Chatbot Code (Unmodified) ------------------------
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langgraph.graph import StateGraph, START, END
# from langchain_ollama import ChatOllama
import faiss
from sentence_transformers import SentenceTransformer
import pickle
import numpy as np
from tools import extract_json_from_response, apply_filters_partial, rule_based_extract, format_property_data, estateKeywords
import random
from langchain_core.tools import tool
from langchain_core.callbacks import StreamingStdOutCallbackHandler, CallbackManager
from langchain_core.callbacks.base import BaseCallbackHandler

# ------------------------ Custom Callback for WebSocket Streaming ------------------------
class WebSocketStreamingCallbackHandler(BaseCallbackHandler):
    def __init__(self, connection_id: str, loop):
        self.connection_id = connection_id
        self.loop = loop

    def on_llm_new_token(self, token: str, **kwargs):
        asyncio.run_coroutine_threadsafe(
            manager_socket.send_message(self.connection_id, token),
            self.loop
        )
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

class ChatHuggingFace:
    def __init__(self, model, token=None, temperature=0.3, streaming=False):
        # Instead of using InferenceClient, load the model locally.
        # `token` is kept for interface compatibility but unused for local loading.
        self.temperature = temperature
        self.streaming = streaming
        self.tokenizer = AutoTokenizer.from_pretrained(model)
        self.model = AutoModelForCausalLM.from_pretrained(model)
        self.pipeline = pipeline("text-generation", model=self.model, tokenizer=self.tokenizer)
    def invoke(self, messages, config=None):
        """
        Mimics the ChatOllama.invoke interface.
        In streaming mode, token-by-token output is sent via callbacks.
        Otherwise, returns a single aggregated response.
        """
        config = config or {}
        callbacks = config.get("callbacks", [])
        # A CallbackManager wraps its handlers; fall back to treating it as a plain list.
        callbacks = getattr(callbacks, "handlers", callbacks)
        aggregated_response = ""
        # Build the prompt by concatenating messages in the expected ChatML format.
        prompt = ""
        for msg in messages:
            # Accept both plain dicts and LangChain message objects.
            if isinstance(msg, dict):
                role = msg.get("role", "")
                content = msg.get("content", "")
            else:
                role_map = {"system": "system", "human": "user", "ai": "assistant"}
                role = role_map.get(getattr(msg, "type", ""), "user")
                content = getattr(msg, "content", "")
            if role == "system":
                prompt += f"<|im_start|>system\n{content}\n<|im_end|>\n"
            elif role == "user":
                prompt += f"<|im_start|>user\n{content}\n<|im_end|>\n"
            elif role == "assistant":
                prompt += f"<|im_start|>assistant\n{content}\n<|im_end|>\n"
        if self.streaming:
            # Generate text locally.
            full_output = self.pipeline(
                prompt,
                max_new_tokens=100,
                do_sample=True,
                temperature=self.temperature
            )[0]['generated_text']
            # The pipeline returns the prompt followed by the generated text.
            new_text = full_output[len(prompt):]
            # Simulate token-by-token streaming.
            for token in new_text.split():
                aggregated_response += token + " "
                for cb in callbacks:
                    cb.on_llm_new_token(token=token + " ")
            return AIMessage(content=aggregated_response.strip())
        else:
            # Non-streaming mode.
            response = self.pipeline(
                prompt,
                max_new_tokens=100,
                do_sample=True,
                temperature=self.temperature
            )[0]['generated_text']
            new_text = response[len(prompt):]
            return AIMessage(content=new_text.strip())
# ------------------------ LLM and Data Setup ------------------------
# model_name = "qwen2.5:1.5b"
model_name = "Qwen/Qwen2.5-1.5B-Instruct"
# llm = ChatOllama(model=model_name, temperature=0.3, streaming=True)
llm = ChatHuggingFace(
    model=model_name,
    # token=token,
    temperature=0.3,
    streaming=True  # or False, based on your needs
)
index = faiss.read_index("./faiss.index")
with open("./metadata.pkl", "rb") as f:
    docs = pickle.load(f)
st_model = SentenceTransformer('all-MiniLM-L6-v2')
def make_system_prompt(suffix: str) -> str:
    return (
        "You are EstateGuru, a real estate expert created by Abhishek Pathak from SwavishTek. "
        "Your role is to help customers buy properties using the available data. "
        "Only use the provided data—do not make up any information. "
        "The default currency is AED. If a query uses a different currency, convert the amount to AED "
        "(for example, $10k becomes 36726.50 AED and $1 becomes 3.67 AED). "
        "If a customer is interested in a property, wants to buy, or needs to contact an agent or customer care, "
        "instruct them to call +91 8766268285."
        f"\n{suffix}"
    )
general_query_prompt = make_system_prompt(
    "You are EstateGuru, a helpful real estate assistant. Answer the user's query accurately using the available data. "
    "Do not invent any details or go beyond the real estate domain. "
    "If the user shows interest in a property or contacting an agent, ask them to call +91 8766268285."
)

# ------------------------ Tool Definitions ------------------------
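# NOTE: the @tool decorator below is an assumption — `tool` is imported above and
# extract_filters is called via .invoke() in hybrid_extract, which requires it to
# be registered as a LangChain tool.
@tool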
def extract_filters(query: str) -> dict:
    """For extracting filters"""
    # llm_local = ChatOllama(model=model_name, temperature=0.3)
    llm_local = ChatHuggingFace(
        model=model_name,
        # token=token,
        temperature=0.3,
        streaming=False
    )
    system = (
        "You are an expert in extracting filters from property-related queries. Your task is to extract and return only the keys explicitly mentioned in the query as a valid JSON object (starting with '{' and ending with '}'). Include only those keys that are directly present in the query.\n\n"
        "The possible keys are:\n"
        "  - 'projectName': The name of the project.\n"
        "  - 'developerName': The developer's name.\n"
        "  - 'relationshipManager': The relationship manager.\n"
        "  - 'propertyAddress': The property address.\n"
        "  - 'surroundingArea': The area or nearby landmarks.\n"
        "  - 'propertyType': The type or configuration of the property.\n"
        "  - 'amenities': Any amenities mentioned.\n"
        "  - 'coveredParking': Parking availability.\n"
        "  - 'petRules': Pet policies.\n"
        "  - 'security': Security details.\n"
        "  - 'occupancyRate': Occupancy information.\n"
        "  - 'constructionImpact': Construction or its impact.\n"
        "  - 'propertySize': Size of the property.\n"
        "  - 'propertyView': View details.\n"
        "  - 'propertyCondition': Condition of the property.\n"
        "  - 'serviceCharges': Service or maintenance charges.\n"
        "  - 'ownershipType': Ownership type.\n"
        "  - 'totalCosts': A cost threshold or cost amount.\n"
        "  - 'paymentPlans': Payment or financing plans.\n"
        "  - 'expectedRentalYield': Expected rental yield.\n"
        "  - 'rentalHistory': Rental history.\n"
        "  - 'shortTermRentals': Short-term rental information.\n"
        "  - 'resalePotential': Resale potential.\n"
        "  - 'uniqueId': A unique identifier.\n\n"
        "Important instructions regarding cost thresholds:\n"
        "  - If the query contains phrases like 'under 10k', 'below 2m', or 'less than 5k', interpret these as cost thresholds.\n"
        "  - Convert any shorthand cost values to pure numbers (for example, '10k' becomes 10000, '2m' becomes 2000000) and assign them to the key 'totalCosts'.\n"
        "  - Do not use 'propertySize' for cost thresholds.\n"
        "  - The default currency is AED; if the query uses a different currency symbol, convert the amount to the equivalent in AED (e.g. $10k becomes 36726.50, $1 becomes 3.67).\n\n"
        "Example:\n"
        "  For the query: \"properties near dubai mall under 43k\"\n"
        "  The expected output should be:\n"
        "  { \"surroundingArea\": \"dubai mall\", \"totalCosts\": 43000 }\n\n"
        "Return ONLY a valid JSON object with the extracted keys and their corresponding values, with no additional text."
    )
human_str = f"Here is the query:\n{query}" | |
filter_prompt = [ | |
{"role": "system", "content": system}, | |
{"role": "user", "content": human_str}, | |
] | |
response = llm_local.invoke(messages=filter_prompt) | |
response_text = response.content if isinstance(response, AIMessage) else str(response) | |
try: | |
model_filters = extract_json_from_response(response_text) | |
except Exception as e: | |
print(f"JSON parsing error: {e}") | |
model_filters = {} | |
rule_filters = rule_based_extract(query) | |
print("Rule-based extraction:", rule_filters) | |
final_filters = {**model_filters, **rule_filters} | |
print("Final extraction:", final_filters) | |
return {"filters": final_filters} | |
def determine_route(query: str) -> dict:
    """For determining route using enhanced prompt and fallback logic."""
    # Define a set of keywords that are strong indicators of a real estate query.
    real_estate_keywords = estateKeywords
    # Check if the query includes any of the positive signals.
    pattern = re.compile("|".join(re.escape(keyword) for keyword in real_estate_keywords), re.IGNORECASE)
    positive_signal = bool(pattern.search(query))
    # Proceed with LLM classification regardless, but use the positive signal in the fallback.
    # llm_local = ChatOllama(model=model_name, temperature=0.3)
    llm_local = ChatHuggingFace(
        model=model_name,
        # token=token,
        temperature=0.3,
        streaming=False
    )
    transform_suggest_to_list = query.lower().replace("suggest ", "list ", -1)
    system = """
    Classify the user query as:
    - **"search"**: if it requests property listings with specific filters (e.g., location, price, property type like "2bhk", service charges, pet policies, etc.).
    - **"suggest"**: if it asks for property suggestions without filters.
    - **"detail"**: if it is asking for more information about a previously provided property (e.g., "tell me more about property 5" or "I want more information regarding 4BHK").
    - **"general"**: for all other real estate-related questions.
    - **"out_of_domain"**: if the query is not related to real estate (for example, tourist attractions, restaurants, etc.).
    Keep in mind that queries mentioning terms like "service charge", "allow pets", "pet rules", etc., are considered real estate queries.
    Return only the keyword: search, suggest, detail, general, or out_of_domain.
    """
    human_str = f"Here is the query:\n{transform_suggest_to_list}"
    filter_prompt = [
        {"role": "system", "content": system},
        {"role": "user", "content": human_str},
    ]
    response = llm_local.invoke(messages=filter_prompt)
    response_text = response.content if isinstance(response, AIMessage) else str(response)
    route_value = str(response_text).strip().lower()
    # Fallback: if no positive real estate signal is found, override to out_of_domain.
    # if not positive_signal:
    #     route_value = "out_of_domain"

    # Fallback for detail-style queries.
    detail_phrases = [
        "more information",
        "tell me more",
        "more details",
        "give me more details",
        "I need more details",
        "can you provide more details",
        "additional details",
        "further information",
        "expand on that",
        "explain further",
        "elaborate more",
        "more specifics",
        "I want to know more",
        "could you elaborate",
        "need more info",
        "provide more details",
        "detail it further",
        "in-depth information",
        "break it down further",
        "further explanation"
    ]
    if any(phrase in query.lower() for phrase in detail_phrases):
        route_value = "detail"
    if route_value not in {"search", "suggest", "detail", "general", "out_of_domain"}:
        route_value = "general"
    # If the model says out_of_domain but a positive real estate signal exists, treat it as "general".
    if route_value == "out_of_domain" and positive_signal:
        route_value = "general"
    return {"route": route_value}
# ------------------------ Workflow Setup ------------------------
workflow = StateGraph(state_schema=dict)

def route_query(state: dict) -> dict:
    new_state = state.copy()
    try:
        new_state["route"] = determine_route.invoke(new_state.get("query", "")).get("route", "general")
        print(new_state["route"])
    except Exception as e:
        print(f"Routing error: {e}")
        new_state["route"] = "general"
    return new_state
def hybrid_extract(state: dict) -> dict:
    new_state = state.copy()
    new_state["filters"] = extract_filters.invoke(new_state.get("query", "")).get("filters", {})
    return new_state

def search_faiss(state: dict) -> dict:
    new_state = state.copy()
    query_embedding = st_model.encode([state["query"]])
    _, indices = index.search(query_embedding.astype(np.float32), 5)
    new_state["faiss_results"] = [docs[idx] for idx in indices[0] if idx < len(docs)]
    return new_state

def apply_filters(state: dict) -> dict:
    new_state = state.copy()
    new_state["final_results"] = apply_filters_partial(state["faiss_results"], state.get("filters", {}))
    return new_state
def suggest_properties(state: dict) -> dict:
    new_state = state.copy()
    # Guard against fewer than 5 documents being available.
    new_state["suggestions"] = random.sample(docs, min(5, len(docs)))
    return new_state
def handle_out_of_domain(state: dict) -> dict:
    new_state = state.copy()
    new_state["response"] = "I only handle real estate inquiries. Please ask a question related to properties."
    return new_state
def generate_response(state: dict) -> dict:
    new_state = state.copy()
    detail_query_flag = False

    # --- Disambiguate specific property requests using the property number ---
    property_match = re.search(r"(?:the\s+)?property\s*(\d+)\b", state.get("query", ""), re.IGNORECASE)
    if property_match and new_state.get("current_properties"):
        try:
            index_requested = int(property_match.group(1)) - 1
            if 0 <= index_requested < len(new_state["current_properties"]):
                new_state["current_properties"] = [new_state["current_properties"][index_requested]]
                detail_query_flag = True
                new_state["detail_query"] = True
        except Exception as e:
            print(f"Property selection error: {e}")

    # Construct messages for the LLM.
    messages = []
    # Add the general query prompt.
    messages.append(SystemMessage(content=general_query_prompt))
    # If this is a detail query, add a system message that forces a detailed answer.
    if detail_query_flag:
        messages.append(SystemMessage(content=(
            "This is a detail query. Please provide detailed information about the property below. "
            "Do not generate a new list of properties; only use the provided property details to answer the query. "
            "Focus on answering the specific question (for example, whether pets are allowed)."
        )))

    # Provide the current property context.
    if new_state.get("current_properties"):
        property_context = format_property_data(new_state["current_properties"])
        messages.insert(0, SystemMessage(content="Available Property:\n" + property_context))

    # Add the conversation history.
    for msg in state.get("messages", []):
        if msg["role"] == "user":
            messages.append(HumanMessage(content=msg["content"]))
        else:
            messages.append(AIMessage(content=msg["content"]))

    # Instruction for the response.
    messages.append(SystemMessage(content=(
        "When responding, use only the provided property details to answer the user's specific question about the property."
    )))

    # Invoke the LLM with the constructed messages.
    connection_id = state.get("connection_id")
    loop = state.get("loop")
    if connection_id and loop:
        callback_manager = CallbackManager([WebSocketStreamingCallbackHandler(connection_id, loop)])
        _ = llm.invoke(
            messages=messages,
            config={"callbacks": callback_manager}
        )
        new_state["response"] = ""
    else:
        callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
        response = llm.invoke(
            messages=messages,
            config={"callbacks": callback_manager}
        )
        new_state["response"] = response.content if isinstance(response, AIMessage) else str(response)
    return new_state
def format_final_response(state: dict) -> dict:
    new_state = state.copy()
    # Only override the current_properties if this is NOT a detail query.
    if not state.get("detail_query", False):
        if state.get("route") in ["search", "suggest"]:
            if "final_results" in state:
                new_state["current_properties"] = state["final_results"]
            elif "suggestions" in state:
                new_state["current_properties"] = state["suggestions"]

    # Then format the response based on the (possibly filtered) current_properties.
    if new_state.get("current_properties"):
        formatted = []
        for idx, prop in enumerate(new_state["current_properties"], 1):
            cost = prop.get("totalCosts", "N/A")
            cost_str = f"{cost:,}" if isinstance(cost, (int, float)) else cost
            formatted.append(
                f"{idx}. Type: {prop['propertyType']}, Cost: AED {cost_str}, "
                f"Size: {prop.get('propertySize', 'N/A')}, Amenities: {', '.join(map(str, prop.get('amenities', []))) if prop.get('amenities') else 'N/A'}, "
                f"Rental Yield: {prop.get('expectedRentalYield', 'N/A')}, "
                f"Ownership: {prop.get('ownershipType', 'N/A')}\n"
            )
        aggregated_response = "Here are the property details:\n" + "\n".join(formatted)

        connection_id = state.get("connection_id")
        loop = state.get("loop")
        if connection_id and loop:
            import time
            tokens = aggregated_response.split(" ")
            for token in tokens:
                asyncio.run_coroutine_threadsafe(
                    manager_socket.send_message(connection_id, token + " "),
                    loop
                )
                time.sleep(0.05)
            new_state["response"] = ""
        else:
            new_state["response"] = aggregated_response
    elif "response" in new_state:
        new_state["response"] = str(new_state["response"])
    return new_state
nodes = [
    ("route_query", route_query),
    ("hybrid_extract", hybrid_extract),
    ("faiss_search", search_faiss),
    ("apply_filters", apply_filters),
    ("suggest_properties", suggest_properties),
    ("handle_out_of_domain", handle_out_of_domain),
    ("generate_response", generate_response),
    ("format_response", format_final_response)
]
for name, node in nodes:
    workflow.add_node(name, node)

workflow.add_edge(START, "route_query")
workflow.add_conditional_edges(
    "route_query",
    lambda state: state.get("route", "general"),
    {
        "search": "hybrid_extract",
        "suggest": "suggest_properties",
        "detail": "generate_response",
        "general": "generate_response",
        "out_of_domain": "handle_out_of_domain"
    }
)
workflow.add_edge("hybrid_extract", "faiss_search")
workflow.add_edge("faiss_search", "apply_filters")
workflow.add_edge("apply_filters", "format_response")
workflow.add_edge("suggest_properties", "format_response")
workflow.add_edge("generate_response", "format_response")
workflow.add_edge("handle_out_of_domain", "format_response")
workflow.add_edge("format_response", END)

workflow_app = workflow.compile()
# ------------------------ Conversation Manager ------------------------
class ConversationManager:
    def __init__(self):
        self.conversation_history = []
        self.current_properties = []

    def _add_message(self, role: str, content: str):
        self.conversation_history.append({
            "role": role,
            "content": content,
            "timestamp": datetime.now().isoformat()
        })

    def process_query(self, query: str) -> str:
        # Reset context on greetings to avoid using off-domain history.
        if query.strip().lower() in {"hi", "hello", "hey"}:
            self.conversation_history = []
            self.current_properties = []
            greeting_response = "Hello! How can I assist you today with your real estate inquiries?"
            self._add_message("assistant", greeting_response)
            return greeting_response
        try:
            self._add_message("user", query)
            initial_state = {
                "messages": self.conversation_history.copy(),
                "query": query,
                "route": "general",
                "filters": {},
                "current_properties": self.current_properties
            }
            for event in workflow_app.stream(initial_state, stream_mode="values"):
                final_state = event
            if 'final_results' in final_state:
                self.current_properties = final_state['final_results']
            elif 'suggestions' in final_state:
                self.current_properties = final_state['suggestions']
            if final_state.get("route") == "general":
                response_text = final_state.get("response", "")
                self._add_message("assistant", response_text)
                return response_text
            else:
                response = final_state.get("response", "I couldn't process that request.")
                self._add_message("assistant", response)
                return response
        except Exception as e:
            print(f"Processing error: {e}")
            return "Sorry, I encountered an error processing your request."
conversation_managers = {}

# ------------------------ FastAPI Backend with WebSockets ------------------------
app = FastAPI()

class ConnectionManager:
    def __init__(self):
        self.active_connections = {}

    async def connect(self, websocket: WebSocket):
        await websocket.accept()
        connection_id = str(uuid.uuid4())
        self.active_connections[connection_id] = websocket
        print(f"New connection: {connection_id}")
        return connection_id

    def disconnect(self, connection_id: str):
        if connection_id in self.active_connections:
            del self.active_connections[connection_id]
            print(f"Disconnected: {connection_id}")

    async def send_message(self, connection_id: str, message: str):
        websocket = self.active_connections.get(connection_id)
        if websocket:
            await websocket.send_text(message)

manager_socket = ConnectionManager()
def stream_query(query: str, connection_id: str, loop):
    conv_manager = conversation_managers.get(connection_id)
    if conv_manager is None:
        print(f"No conversation manager found for connection {connection_id}")
        return

    # Check for greetings and handle them immediately.
    if query.strip().lower() in {"hi", "hello", "hey"}:
        conv_manager.conversation_history = []
        conv_manager.current_properties = []
        greeting_response = "Hello! How can I assist you today with your real estate inquiries?"
        conv_manager._add_message("assistant", greeting_response)
        asyncio.run_coroutine_threadsafe(
            manager_socket.send_message(connection_id, greeting_response),
            loop
        )
        return

    conv_manager._add_message("user", query)
    initial_state = {
        "messages": conv_manager.conversation_history.copy(),
        "query": query,
        "route": "general",
        "filters": {},
        "current_properties": conv_manager.current_properties,
        "connection_id": connection_id,
        "loop": loop
    }
    try:
        workflow_app.invoke(initial_state)
    except Exception as e:
        error_msg = f"Error processing query: {str(e)}"
        asyncio.run_coroutine_threadsafe(
            manager_socket.send_message(connection_id, error_msg),
            loop
        )
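
# NOTE: route registration assumed — the snippet defines this endpoint without a
# decorator; "/ws" is an assumed path for the WebSocket route.
@app.websocket("/ws")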
async def websocket_endpoint(websocket: WebSocket):
    connection_id = await manager_socket.connect(websocket)
    conversation_managers[connection_id] = ConversationManager()
    try:
        while True:
            query = await websocket.receive_text()
            loop = asyncio.get_event_loop()
            threading.Thread(
                target=stream_query,
                args=(query, connection_id, loop),
                daemon=True
            ).start()
    except WebSocketDisconnect:
        conv_manager = conversation_managers.get(connection_id)
        if conv_manager:
            import os
            # Make sure the conversations directory exists before saving the transcript.
            os.makedirs("conversations", exist_ok=True)
            filename = f"conversations/conversation_{connection_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
            with open(filename, "w") as f:
                json.dump(conv_manager.conversation_history, f, indent=4)
        del conversation_managers[connection_id]
        manager_socket.disconnect(connection_id)
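
# NOTE: route registration assumed — "/query" is an assumed path for the HTTP endpoint.
@app.post("/query")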
async def post_query(query: str):
    conv_manager = ConversationManager()
    response = conv_manager.process_query(query)
    return {"response": response}