import uuid
import threading
import asyncio
import json
import re
import random
import time
import pickle
import numpy as np
import requests  # For llama.cpp server calls
from datetime import datetime
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, BackgroundTasks, Request
from langchain_core.messages import AIMessage
from langgraph.graph import StateGraph, START, END
import faiss
from sentence_transformers import SentenceTransformer
from tools import extract_json_from_response, apply_filters_partial, rule_based_extract, structured_property_data, estateKeywords, sendTokenViaSocket
from langchain_core.tools import tool
from langchain_core.callbacks import StreamingStdOutCallbackHandler
from langchain_core.callbacks.base import BaseCallbackHandler
import os
from fastapi.responses import PlainTextResponse
from fastapi.staticfiles import StaticFiles
from functools import lru_cache
from contextlib import asynccontextmanager
# ------------------------ Model Inference Wrapper ------------------------
class ChatQwen:
    """
    A chat wrapper for Qwen using llama.cpp.

    This class can work in two modes:
      - Local: using the llama-cpp-python binding (gguf model file loaded locally).
      - Server: calling a remote llama.cpp server endpoint.
    """
    def __init__(
        self,
        temperature=0.3,
        streaming=False,
        max_new_tokens=512,
        callbacks=None,
        use_server=False,
        model_path: str = None,
        server_url: str = None
    ):
        self.temperature = temperature
        self.streaming = streaming
        self.max_new_tokens = max_new_tokens
        self.callbacks = callbacks
        self.use_server = use_server
        self.is_hf_space = os.environ.get('SPACE_ID') is not None

        if self.use_server:
            # Use a remote llama.cpp server – provide its URL.
            self.server_url = server_url or "http://localhost:8000"
        else:
            # For local inference, a model_path must be provided.
            if not model_path:
                raise ValueError("Local mode requires a valid model_path to the gguf file.")
            from llama_cpp import Llama  # assumes llama-cpp-python is installed
            # self.model = Llama(
            #     model_path=model_path,
            #     temperature=self.temperature,
            #     # n_ctx=512,
            #     n_ctx=8192,
            #     n_threads=4,      # Adjust as needed
            #     batch_size=512,
            #     verbose=False,
            # )
            # Updated Llama initialization:
            if self.is_hf_space:
                self.model = Llama(
                    model_path=model_path,
                    temperature=self.temperature,
                    n_ctx=1024,            # Reduced from 8192
                    n_threads=2,           # Never exceed 2 threads on the free tier
                    n_batch=128,           # Smaller batch size for low RAM
                    use_mmap=True,         # Essential for memory mapping
                    use_mlock=False,       # Disable memory locking
                    low_vram=True,         # Special low-memory mode
                    vocab_only=False,
                    n_gqa=2,               # Grouped-query attention for the 1.5B model
                    rope_freq_base=10000,
                    logits_all=False,
                    verbose=False,
                )
            else:
                self.model = Llama(
                    model_path=model_path,
                    n_gpu_layers=20,       # Offload 20 layers to GPU (adjust based on VRAM)
                    n_threads=3,           # Leave one core free for the system
                    n_threads_batch=3,
                    n_batch=256,           # llama-cpp-python expects n_batch (not batch_size)
                    main_gpu=0,            # Use the first GPU
                    use_mmap=True,
                    use_mlock=False,
                    temperature=self.temperature,
                    n_ctx=2048,            # Reduced context for lower memory usage
                    verbose=False
                )

        if not self.use_server:
            self.model.tokenize(b"Warmup")  # Pre-load the model
            self.model.create_completion("Warmup", max_tokens=1)
    # def build_prompt(self, messages: list) -> str:
    #     """Build a Qwen-compatible prompt with special tokens."""
    #     prompt = ""
    #     for msg in messages:
    #         role = msg["role"]
    #         content = msg["content"]
    #         if role == "system":
    #             prompt += f"<|im_start|>system\n{content}<|im_end|>\n"
    #         elif role == "user":
    #             prompt += f"<|im_start|>user\n{content}<|im_end|>\n"
    #         elif role == "assistant":
    #             prompt += f"<|im_start|>assistant\n{content}<|im_end|>\n"
    #     prompt += "<|im_start|>assistant\n"
    #     return prompt
    def build_prompt(self, messages: list) -> str:
        """Optimized prompt builder using a single string join."""
        return "".join(
            f"<|im_start|>{msg['role']}\n{msg['content']}<|im_end|>\n"
            for msg in messages
        ) + "<|im_start|>assistant\n"
    def generate_text(self, messages: list) -> str:
        try:
            prompt = self.build_prompt(messages)
            stop_tokens = ["<|im_end|>", "\n"]  # Qwen's stop sequences

            if self.use_server:
                payload = {
                    "prompt": prompt,
                    "max_tokens": self.max_new_tokens,
                    "temperature": self.temperature,
                    "stream": self.streaming,
                    "stop": stop_tokens  # Add stop tokens to the server request
                }
                if self.streaming:
                    response = requests.post(f"{self.server_url}/generate", json=payload, stream=True)
                    generated_text = ""
                    for line in response.iter_lines():
                        if line:
                            token = line.decode("utf-8")
                            # Check for stop tokens in the stream
                            if any(stop in token for stop in stop_tokens):
                                break
                            generated_text += token
                            if self.callbacks:
                                for callback in self.callbacks:
                                    callback.on_llm_new_token(token)
                    return generated_text
                else:
                    response = requests.post(f"{self.server_url}/generate", json=payload)
                    return response.json().get("generated_text", "")
            else:
                # Local llama.cpp inference
                if self.streaming:
                    if self.is_hf_space:
                        stream = self.model.create_completion(
                            prompt=prompt,
                            max_tokens=256,          # Reduced from 512
                            temperature=0.3,
                            stream=True,
                            stop=stop_tokens,
                            repeat_penalty=1.15,
                            frequency_penalty=0.2,
                            mirostat_mode=2,         # Better for low-resource environments
                            mirostat_tau=3.0,
                            mirostat_eta=0.1
                        )
                    else:
                        stream = self.model.create_completion(
                            prompt=prompt,
                            max_tokens=self.max_new_tokens,
                            temperature=self.temperature,
                            stream=True,
                            stop=stop_tokens,
                            repeat_penalty=1.1,      # Reduce repetition for faster generation
                            tfs_z=0.5                # Tail-free sampling for efficiency
                        )
                    generated_text = ""
                    for token_chunk in stream:
                        token_text = token_chunk["choices"][0]["text"]
                        # Stop early if we detect an end token
                        if any(stop in token_text for stop in stop_tokens):
                            break
                        generated_text += token_text
                        if self.callbacks:
                            for callback in self.callbacks:
                                callback.on_llm_new_token(token_text)
                    return generated_text
                else:
                    result = self.model.create_completion(
                        prompt=prompt,
                        max_tokens=self.max_new_tokens,
                        temperature=self.temperature,
                        stop=stop_tokens
                    )
                    return result["choices"][0]["text"]
        except Exception as e:
            if "out of memory" in str(e).lower() and self.is_hf_space:
                return self.fallback_generate(messages)
            raise  # re-raise anything we cannot recover from
    def fallback_generate(self, messages):
        """Simpler generation for out-of-memory situations."""
        return self.model.create_completion(
            prompt=self.build_prompt(messages),
            max_tokens=128,
            temperature=0.3,
            stream=False,
            stop=["<|im_end|>", "\n"]
        )["choices"][0]["text"]

    def invoke(self, messages: list, config: dict = None) -> AIMessage:
        config = config or {}
        callbacks = config.get("callbacks", self.callbacks)
        original_callbacks = self.callbacks
        self.callbacks = callbacks

        output_text = self.generate_text(messages)

        self.callbacks = original_callbacks

        # In streaming mode we return empty content because tokens are sent via callbacks.
        if self.streaming:
            return AIMessage(content="")
        else:
            return AIMessage(content=output_text)

    def __call__(self, messages: list) -> AIMessage:
        return self.invoke(messages)
# ------------------------ Callback for WebSocket Streaming ------------------------
class WebSocketStreamingCallbackHandler(BaseCallbackHandler):
    def __init__(self, connection_id: str, loop):
        self.connection_id = connection_id
        self.loop = loop

    def on_llm_new_token(self, token: str, **kwargs):
        asyncio.run_coroutine_threadsafe(
            manager_socket.send_message(self.connection_id, token),
            self.loop
        )
# ------------------------ Instantiate the LLM ------------------------
# Choose one mode: local (use_server=False) or server (use_server=True).
model_path = "qwen2.5-1.5b-instruct-q4_k_m.gguf"

llm = ChatQwen(
    temperature=0.3,
    streaming=True,
    max_new_tokens=512,
    use_server=False,
    model_path=model_path,
    # server_url="http://localhost:8000"  # Uncomment and set if using server mode.
)

llm_no_stream = ChatQwen(
    temperature=0.3,
    streaming=False,
    use_server=False,
    model_path=model_path,
)
# ------------------------ FAISS and Sentence Transformer Setup ------------------------
index = faiss.read_index("./faiss.index")
with open("./metadata.pkl", "rb") as f:
    docs = pickle.load(f)
st_model = SentenceTransformer('all-MiniLM-L6-v2')
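# The faiss.index / metadata.pkl files loaded above are assumed to have been built
# offline from the same document list and embedding model. The helper below is a
# minimal reference sketch of that build step (it is not called at runtime, and the
# serialisation of each record via json.dumps is an assumption).
def build_faiss_index(documents: list, index_path: str = "./faiss.index",
                      metadata_path: str = "./metadata.pkl"):
    """Embed property records and persist a FAISS index plus metadata (reference sketch)."""
    texts = [json.dumps(doc) for doc in documents]          # one text per property record
    embeddings = st_model.encode(texts).astype(np.float32)  # same model as query-time encoding
    new_index = faiss.IndexFlatL2(embeddings.shape[1])       # plain L2 index, matching index.search() usage
    new_index.add(embeddings)
    faiss.write_index(new_index, index_path)
    with open(metadata_path, "wb") as f_out:
        pickle.dump(documents, f_out)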
def make_system_prompt(suffix: str) -> str:
    return (
        "You are EstateGuru, a real estate expert developed by Abhishek Pathak at SwavishTek. "
        "Your role is to help customers buy properties using only the provided data—do not invent any details. "
        "The default currency is AED; if a query mentions another currency, convert the amount to AED "
        "(for example, convert $10k to 36726.50 AED and $1 to 3.67 AED). "
        "If a customer is interested in a property or needs to contact an agent, instruct them to call +91 8766268285. "
        "Keep your answers short, clear, and concise."
        f"\n{suffix}"
    )
general_query_prompt = make_system_prompt(
    "You are EstateGuru, a helpful real estate assistant. "
    "Please respond only in English. "
    "Convert any prices to AED before answering. "
    "Provide a brief, direct answer without extra details."
)
# ------------------------ Tool Definitions ------------------------
# The @tool decorator registers these functions as LangChain tools so the
# .invoke() calls in the graph nodes below work as written.
@tool
def extract_filters(query: str) -> dict:
"""Extract filters from the query.""" | |
# llm_local = ChatQwen(temperature=0.3, streaming=False, use_server=False, model_path=model_path) | |
system = ( | |
"You are an expert in extracting filters from property-related queries. Your task is to extract and return only the keys explicitly mentioned in the query as a valid JSON object (starting with '{' and ending with '}'). Include only those keys that are directly present in the query.\n\n" | |
"The possible keys are:\n" | |
" - 'projectName': The name of the project.\n" | |
" - 'developerName': The developer's name.\n" | |
" - 'relationshipManager': The relationship manager.\n" | |
" - 'propertyAddress': The property address.\n" | |
" - 'surroundingArea': The area or nearby landmarks.\n" | |
" - 'propertyType': The type or configuration of the property.\n" | |
" - 'amenities': Any amenities mentioned.\n" | |
" - 'coveredParking': Parking availability.\n" | |
" - 'petRules': Pet policies.\n" | |
" - 'security': Security details.\n" | |
" - 'occupancyRate': Occupancy information.\n" | |
" - 'constructionImpact': Construction or its impact.\n" | |
" - 'propertySize': Size of the property.\n" | |
" - 'propertyView': View details.\n" | |
" - 'propertyCondition': Condition of the property.\n" | |
" - 'serviceCharges': Service or maintenance charges.\n" | |
" - 'ownershipType': Ownership type.\n" | |
" - 'totalCosts': A cost threshold or cost amount.\n" | |
" - 'paymentPlans': Payment or financing plans.\n" | |
" - 'expectedRentalYield': Expected rental yield.\n" | |
" - 'rentalHistory': Rental history.\n" | |
" - 'shortTermRentals': Short-term rental information.\n" | |
" - 'resalePotential': Resale potential.\n" | |
" - 'uniqueId': A unique identifier.\n\n" | |
"Important instructions regarding cost thresholds:\n" | |
" - If the query contains phrases like 'under 10k', 'below 2m', or 'less than 5k', interpret these as cost thresholds.\n" | |
" - Convert any shorthand cost values to pure numbers (for example, '10k' becomes 10000, '2m' becomes 2000000) and assign them to the key 'totalCosts'.\n" | |
" - Do not use 'propertySize' for cost thresholds.\n\n" | |
" - Default currency is AED, if user query have different currency symbol then convert to equivalent AED amount (eg. $10k becomes 36726.50, $1 becomes 3.67).\n\n" | |
"Example:\n" | |
" For the query: \"properties near dubai mall under 43k\"\n" | |
" The expected output should be:\n" | |
" { \"surroundingArea\": \"dubai mall\", \"totalCosts\": 43000 }\n\n" | |
"Return ONLY a valid JSON object with the extracted keys and their corresponding values, with no additional text." | |
) | |
human_str = f"Here is the query:\n{query}" | |
filter_prompt = [ | |
{"role": "system", "content": system}, | |
{"role": "user", "content": human_str}, | |
] | |
response = llm_no_stream.invoke(messages=filter_prompt) | |
response_text = response.content if isinstance(response, AIMessage) else str(response) | |
try: | |
model_filters = extract_json_from_response(response_text) | |
except Exception as e: | |
print(f"JSON parsing error: {e}") | |
model_filters = {} | |
rule_filters = rule_based_extract(query) | |
print("Rule-based extraction:", rule_filters) | |
final_filters = {**model_filters, **rule_filters} | |
print("Final extraction:", final_filters) | |
return {"filters": final_filters} | |
@tool
def determine_route(query: str) -> dict:
"""Determine the route (search, suggest, detail, general, out_of_domain) for the query.""" | |
real_estate_keywords = estateKeywords | |
pattern = re.compile("|".join(re.escape(keyword) for keyword in real_estate_keywords), re.IGNORECASE) | |
positive_signal = bool(pattern.search(query)) | |
# llm_local = ChatQwen(temperature=0.3, streaming=False, use_server=False, model_path=model_path) | |
transform_suggest_to_list = query.lower().replace("suggest ", "list ", -1) | |
system = """ | |
Classify the user query as: | |
- **"search"**: if it requests property listings with specific filters (e.g., location, price, property type like "2bhk", service charges, pet policies, etc.). | |
- **"suggest"**: if it asks for property suggestions without filters. | |
- **"detail"**: if it is asking for more information about a previously provided property (for example, "tell me more about property 5" or "I want more information regarding 4BHK"). | |
- **"general"**: for all other real estate-related questions. | |
- **"out_of_domain"**: if the query is not related to real estate (for example, tourist attractions, restaurants, etc.). | |
    Keep in mind that queries mentioning terms like "service charge", "allow pets", "pet rules", etc., are real estate queries.
    When the user asks about you (for example, "who are you" or "who made you"), classify the query as general.

    Return only the keyword: search, suggest, detail, general, or out_of_domain.
""" | |
human_str = f"Here is the query:\n{transform_suggest_to_list}" | |
router_prompt = [ | |
{"role": "system", "content": system}, | |
{"role": "user", "content": human_str}, | |
] | |
response = llm_no_stream.invoke(messages=router_prompt) | |
response_text = response.content if isinstance(response, AIMessage) else str(response) | |
route_value = str(response_text).strip().lower() | |
# --- NEW: Force 'detail' if query explicitly mentions a specific property (e.g., "property 2") --- | |
property_detail_pattern = re.compile(r"property\s+\d+", re.IGNORECASE) | |
if property_detail_pattern.search(query): | |
route_value = "detail" | |
# Fallback override if query appears detailed. | |
detail_phrases = [ | |
"more information", "tell me more", "more details", "give me more details", | |
"i need more details", "can you provide more details", "additional details", | |
"further information", "expand on that", "explain further", "elaborate more", | |
"more specifics", "i want to know more", "could you elaborate", "need more info", | |
"provide more details", "detail it further", "in-depth information", "break it down further", | |
"further explanation", "property 1", "property1", "first property", "about the 2nd", "regarding number 3" | |
] | |
if any(phrase in query.lower() for phrase in detail_phrases): | |
route_value = "detail" | |
if route_value not in {"search", "suggest", "detail", "general", "out_of_domain"}: | |
route_value = "general" | |
if route_value == "out_of_domain" and positive_signal: | |
route_value = "general" | |
if route_value == "out_of_domain": | |
route_value = "general" if positive_signal else "out_of_domain" | |
return {"route": route_value} | |
# ------------------------ Workflow Setup ------------------------
workflow = StateGraph(state_schema=dict)

def route_query(state: dict) -> dict:
    new_state = state.copy()
    try:
        new_state["route"] = determine_route.invoke(new_state.get("query", "")).get("route", "general")
        print(new_state["route"])
    except Exception as e:
        print(f"Routing error: {e}")
        new_state["route"] = "general"
    return new_state

def hybrid_extract(state: dict) -> dict:
    new_state = state.copy()
    new_state["filters"] = extract_filters.invoke(new_state.get("query", "")).get("filters", {})
    return new_state

def search_faiss(state: dict) -> dict:
    new_state = state.copy()
    # Preserve previous properties until new ones are fetched.
    new_state.setdefault("current_properties", state.get("current_properties", []))
    query_embedding = st_model.encode([state["query"]])
    _, indices = index.search(query_embedding.astype(np.float32), 5)
    new_state["faiss_results"] = [docs[idx] for idx in indices[0] if idx < len(docs)]
    return new_state

def apply_filters(state: dict) -> dict:
    new_state = state.copy()
    new_state["final_results"] = apply_filters_partial(state["faiss_results"], state.get("filters", {}))
if(len(new_state["final_results"]) == 0): | |
new_state["response"] = "Sorry, There is no result found :(" | |
new_state["route"] = "general" | |
return new_state | |
def suggest_properties(state: dict) -> dict:
    new_state = state.copy()
    # Guard against fewer than five documents being available.
    new_state["suggestions"] = random.sample(docs, min(5, len(docs)))
    # Explicitly update current_properties only when new listings are fetched.
    new_state["current_properties"] = new_state["suggestions"]
    if not new_state["suggestions"]:
        new_state["response"] = "Sorry, no matching properties were found :("
        new_state["route"] = "general"
    return new_state
def handle_out_of_domain(state: dict) -> dict:
    new_state = state.copy()
    new_state["response"] = "I only handle real estate inquiries. Please ask a question related to properties."
    return new_state
def generate_response(state: dict) -> dict:
    new_state = state.copy()
    messages = []

    # Add the general query prompt.
    messages.append({"role": "system", "content": general_query_prompt})

    # For detail queries (specific property queries), add extra instructions.
    if new_state.get("route", "general") == "detail":
        messages.append({
            "role": "system",
            "content": (
                "The user is asking about a specific property from the numbered list below. "
                "Properties are listed as 1, 2, 3, etc. Use ONLY the corresponding property details. "
                "For example, if the user says 'property 2', respond using only the details from the second entry. Never invent data."
            )
        })

    if new_state.get("current_properties"):
        # Format properties with indices starting at 1.
        property_context = format_property_data_with_indices(new_state["current_properties"])
        messages.append({"role": "system", "content": "Available Properties:\n" + property_context})
        messages.append({"role": "system", "content": "When responding, use only the provided property details."})

    # Add truncated conversation history (last 6 user+assistant pairs).
    truncated_history = state.get("messages", [])[-12:]
    for msg in truncated_history:
        messages.append({"role": msg["role"], "content": msg["content"]})

    connection_id = state.get("connection_id")
    loop = state.get("loop")
    if connection_id and loop:
        print("Using WebSocket streaming")
        callback_manager = [WebSocketStreamingCallbackHandler(connection_id, loop)]
        _ = llm.invoke(
            messages,
            config={"callbacks": callback_manager}
        )
        new_state["response"] = ""
    else:
        callback_manager = [StreamingStdOutCallbackHandler()]
        response = llm.invoke(
            messages,
            config={"callbacks": callback_manager}
        )
        new_state["response"] = response.content if isinstance(response, AIMessage) else str(response)
    return new_state
def format_property_data_with_indices(properties: list) -> str:
    formatted = []
    for idx, prop in enumerate(properties, 1):
        cost = prop.get("totalCosts", "N/A")
        cost_str = f"{cost:,}" if isinstance(cost, (int, float)) else cost
        formatted.append(
            f"{idx}. Type: {prop.get('propertyType', 'N/A')}, Cost: AED {cost_str}, "
            f"Size: {prop.get('propertySize', 'N/A')}, Amenities: {', '.join(prop.get('amenities', []))}, "
            f"Rental Yield: {prop.get('expectedRentalYield', 'N/A')}, "
            f"Ownership: {prop.get('ownershipType', 'N/A')}"
        )
    return "\n".join(formatted)
def format_final_response(state: dict) -> dict:
    new_state = state.copy()

    if state.get("route") in ["search", "suggest"]:
        if "final_results" in state:
            new_state["current_properties"] = state["final_results"]
        elif "suggestions" in state:
            new_state["current_properties"] = state["suggestions"]
        elif "current_properties" in new_state:
            new_state["current_properties"] = state["current_properties"]

    if state.get("route") in ["search", "suggest"] and new_state.get("current_properties"):
        formatted = structured_property_data(state=new_state)
        aggregated_response = "Here are the property details:\n" + "\n".join(formatted)

        connection_id = state.get("connection_id")
        loop = state.get("loop")
        if connection_id and loop:
            # Stream the aggregated text word by word over the socket.
            tokens = aggregated_response.split(" ")
            for token in tokens:
                asyncio.run_coroutine_threadsafe(
                    manager_socket.send_message(connection_id, token + " "),
                    loop
                )
                time.sleep(0.05)
            new_state["response"] = ""
        else:
            new_state["response"] = aggregated_response
    elif "response" in new_state:
        connection_id = state.get("connection_id")
        loop = state.get("loop")
        if connection_id and loop:
            tokens = str(new_state["response"]).split(" ")
            for token in tokens:
                asyncio.run_coroutine_threadsafe(
                    manager_socket.send_message(connection_id, token + " "),
                    loop
                )
                time.sleep(0.05)
        new_state["response"] = str(new_state["response"])
    return new_state
nodes = [
    ("route_query", route_query),
    ("hybrid_extract", hybrid_extract),
    ("faiss_search", search_faiss),
    ("apply_filters", apply_filters),
    ("suggest_properties", suggest_properties),
    ("handle_out_of_domain", handle_out_of_domain),
    ("generate_response", generate_response),
    ("format_response", format_final_response)
]
for name, node in nodes:
    workflow.add_node(name, node)

workflow.add_edge(START, "route_query")
workflow.add_conditional_edges(
    "route_query",
    lambda state: state.get("route", "general"),
    {
        "search": "hybrid_extract",
        "suggest": "suggest_properties",
        "detail": "generate_response",
        "general": "generate_response",
        "out_of_domain": "handle_out_of_domain"
    }
)
workflow.add_edge("hybrid_extract", "faiss_search")
workflow.add_edge("faiss_search", "apply_filters")
workflow.add_edge("apply_filters", "format_response")
workflow.add_edge("suggest_properties", "format_response")
workflow.add_edge("generate_response", "format_response")
workflow.add_edge("handle_out_of_domain", "format_response")
workflow.add_edge("format_response", END)

workflow_app = workflow.compile()
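# Resulting flow, for reference (derived from the edges above):
#   START -> route_query
#     search        -> hybrid_extract -> faiss_search -> apply_filters -> format_response
#     suggest       -> suggest_properties -> format_response
#     detail/general-> generate_response -> format_response
#     out_of_domain -> handle_out_of_domain -> format_response
#   format_response -> END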
# ------------------------ Conversation Manager ------------------------
class ConversationManager:
    def __init__(self):
        # Each connection gets its own conversation history and state.
        self.conversation_history = []
        # current_properties stores the current property listing.
        self.current_properties = []

    def _add_message(self, role: str, content: str):
        self.conversation_history.append({
            "role": role,
            "content": content,
            "timestamp": datetime.now().isoformat()
        })
    def process_query(self, query: str) -> str:
        # Synchronous path used by the POST endpoint. For greeting messages, reset history/state.
        if query.strip().lower() in {"hi", "hello", "hey"}:
            self.conversation_history = []
            self.current_properties = []
            greeting_response = "Hello! How can I assist you today with your real estate inquiries?"
            self._add_message("assistant", greeting_response)
            return greeting_response
        try:
            self._add_message("user", query)
            initial_state = {
                "messages": self.conversation_history.copy(),
                "query": query,
                "route": "general",
                "filters": {},
                "current_properties": self.current_properties
            }
            for event in workflow_app.stream(initial_state, stream_mode="values"):
                final_state = event

            # Only update property listings when a new listing is fetched.
            # if 'final_results' in final_state:
            #     self.current_properties = final_state['final_results']
            # elif 'suggestions' in final_state:
            #     self.current_properties = final_state['suggestions']
            self.current_properties = final_state.get("current_properties", [])

            if final_state.get("route") == "general":
                response_text = final_state.get("response", "")
                self._add_message("assistant", response_text)
                return response_text
            else:
                response = final_state.get("response", "I couldn't process that request.")
                self._add_message("assistant", response)
                return response
        except Exception as e:
            print(f"Processing error: {e}")
            return "Sorry, I encountered an error processing your request."

conversation_managers = {}
# ------------------------ FastAPI Backend with WebSockets ------------------------
app = FastAPI()

class ConnectionManager:
    def __init__(self):
        self.active_connections = {}

    async def connect(self, websocket: WebSocket):
        await websocket.accept()
        connection_id = str(uuid.uuid4())
        self.active_connections[connection_id] = websocket
        print(f"New connection: {connection_id}")
        return connection_id

    def disconnect(self, connection_id: str):
        if connection_id in self.active_connections:
            del self.active_connections[connection_id]
            print(f"Disconnected: {connection_id}")

    async def send_message(self, connection_id: str, message: str):
        websocket = self.active_connections.get(connection_id)
        if websocket:
            await websocket.send_text(message)

manager_socket = ConnectionManager()
def stream_query(query: str, connection_id: str, loop):
    conv_manager = conversation_managers.get(connection_id)
    if conv_manager is None:
        print(f"No conversation manager found for connection {connection_id}")
        return

    if query.strip().lower() in {"hi", "hello", "hey"}:
        conv_manager.conversation_history = []
        conv_manager.current_properties = []
        greeting_response = "Hello! How can I assist you today with your real estate inquiries?"
        conv_manager._add_message("assistant", greeting_response)
        sendTokenViaSocket(
            state={"connection_id": connection_id, "loop": loop},
            manager_socket=manager_socket,
            message=greeting_response
        )
        # asyncio.run_coroutine_threadsafe(
        #     manager_socket.send_message(connection_id, greeting_response),
        #     loop
        # )
        return

    conv_manager._add_message("user", query)
    initial_state = {
        "messages": conv_manager.conversation_history.copy(),
        "query": query,
        "route": "general",
        "filters": {},
        "current_properties": conv_manager.current_properties,
        "connection_id": connection_id,
        "loop": loop
    }
    # try:
    #     workflow_app.invoke(initial_state)
    # except Exception as e:
    #     error_msg = f"Error processing query: {str(e)}"
    #     asyncio.run_coroutine_threadsafe(
    #         manager_socket.send_message(connection_id, error_msg),
    #         loop
    #     )
    try:
        # Capture all states during execution.
        # final_state = None
        # for event in workflow_app.stream(initial_state, stream_mode="values"):
        #     final_state = event
        # # Update the conversation manager with the final state.
        # if final_state:
        #     conv_manager.current_properties = final_state.get("current_properties", [])
        #     if final_state.get("response"):
        #         conv_manager._add_message("assistant", final_state["response"])
        final_state = None
        for event in workflow_app.stream(initial_state, stream_mode="values"):
            final_state = event
        if final_state:
            # Always update current_properties from the final state.
            conv_manager.current_properties = final_state.get("current_properties", [])
            # Keep the conversation history bounded (last 6 exchanges).
            conv_manager.conversation_history = conv_manager.conversation_history[-12:]
    except Exception as e:
        error_msg = f"Error processing query: {str(e)}"
        asyncio.run_coroutine_threadsafe(
            manager_socket.send_message(connection_id, error_msg),
            loop
        )
@app.websocket("/ws")  # route path assumed; adjust to match the frontend client
async def websocket_endpoint(websocket: WebSocket):
    connection_id = await manager_socket.connect(websocket)
    # Each connection maintains its own conversation manager.
    conversation_managers[connection_id] = ConversationManager()
    try:
        while True:
            query = await websocket.receive_text()
            loop = asyncio.get_event_loop()
            threading.Thread(
                target=stream_query,
                args=(query, connection_id, loop),
                daemon=True
            ).start()
    except WebSocketDisconnect:
        conv_manager = conversation_managers.get(connection_id)
        if conv_manager:
            os.makedirs("conversations", exist_ok=True)  # ensure the output directory exists
            filename = f"conversations/conversation_{connection_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
            with open(filename, "w") as f:
                json.dump(conv_manager.conversation_history, f, indent=4)
        del conversation_managers[connection_id]
        manager_socket.disconnect(connection_id)
@app.post("/query")  # route path assumed
async def post_query(query: str):
    conv_manager = ConversationManager()
    response = conv_manager.process_query(query)
    return {"response": response}
model_url = "https://huggingface.co/Qwen/Qwen2.5-1.5B-Instruct-GGUF/resolve/main/qwen2.5-1.5b-instruct-q4_k_m.gguf"

async def async_download():
    import aiohttp
    async with aiohttp.ClientSession() as session:
        async with session.get(model_url) as response:
            with open(model_path, "wb") as f:
                while True:
                    chunk = await response.content.read(1024)
                    if not chunk:
                        break
                    f.write(chunk)
@app.middleware("http")
async def check_model_middleware(request: Request, call_next):
    if not os.path.exists(model_path):
        await async_download()
        print("Model downloaded successfully")
    else:
        print("Model already downloaded")
    return await call_next(request)

@app.get("/")
async def home():
    return PlainTextResponse("Space is running. Model ready!")
# async def clear_cache_periodically(seconds: int = 3600):
#     while True:
#         await asyncio.sleep(seconds)
#         extract_filters.cache_clear()
#         determine_route.cache_clear()
#         ChatQwen.build_prompt.cache_clear()
#         print("Cache cleared")

# @app.on_event("startup")
# async def startup_event():
#     background_tasks = BackgroundTasks()
#     background_tasks.add_task(clear_cache_periodically, 3600)  # Clear every hour
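
# How the server is launched is not shown in this file. A minimal sketch, assuming
# uvicorn and the port Hugging Face Spaces conventionally exposes (7860):
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)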