"""Vector-search helpers for the restaurant reservation database.

Two indexes are provided:

* ``SchemaVectorDB`` - an in-memory Chroma collection holding one
  natural-language description per SQLite table, used to pick the tables
  that are relevant to a user question.
* ``FullVectorDB`` - persistent Chroma collections mirroring the rows of the
  ``restaurants``, ``tables`` and ``slots`` SQLite tables, used for semantic
  search over the actual records.
"""

import hashlib
import sqlite3
from contextlib import closing

import chromadb
import pandas as pd
from sentence_transformers import SentenceTransformer

# --- Natural-language descriptions of the SQLite schema (one per table) ---
# The text is what gets embedded, so each description is kept as a single
# space-separated string; only the source layout is split across lines.
SCHEMA_DESCRIPTIONS = {
    "restaurants": (
        "Table restaurants contains restaurant details: "
        "- id: unique identifier "
        "- name: restaurant name "
        "- cuisine: type of cuisine "
        "- location: area or neighborhood "
        "- seating_capacity: total seats "
        "- rating: average rating "
        "- address: full address "
        "- contact: phone or email "
        "- price_range: price category "
        "- special_features: amenities or highlights"
    ),
    "tables": (
        "Table tables contains table details: "
        "- id: unique identifier "
        "- restaurant_id: links to restaurants.id "
        "- capacity: number of seats (default 4)"
    ),
    "slots": (
        "Table slots contains reservation time slots: "
        "- id: unique identifier "
        "- table_id: links to tables.id "
        "- date: reservation date "
        "- hour: reservation hour "
        "- is_reserved: 0=available, 1=booked"
    ),
}


class SchemaVectorDB:
    """In-memory vector index over ``SCHEMA_DESCRIPTIONS``."""

    def __init__(self):
        """Create the ``schema`` collection and index every table description."""
        self.client = chromadb.Client()
        self.collection = self.client.get_or_create_collection("schema")
        self.model = SentenceTransformer('all-MiniLM-L6-v2')

        table_names = list(SCHEMA_DESCRIPTIONS)
        descriptions = [SCHEMA_DESCRIPTIONS[name] for name in table_names]
        # Embed with the same model that encodes queries so documents and
        # queries are guaranteed to live in the same vector space, and use
        # upsert so building a second instance against the same collection
        # does not try to re-add ids that already exist.
        self.collection.upsert(
            ids=[str(idx) for idx in range(len(table_names))],
            documents=descriptions,
            embeddings=self.model.encode(descriptions).tolist(),
            metadatas=[{"name": name} for name in table_names],
        )

    def get_relevant_schema(self, query, k=2):
        """Return the names of the ``k`` tables most relevant to ``query``.

        Args:
            query: Free-text question to match against the table descriptions.
            k: Maximum number of table names to return.

        Returns:
            A list of table names (keys of ``SCHEMA_DESCRIPTIONS``) in the
            order Chroma returned them.
        """
        query_embedding = self.model.encode(query).tolist()
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=k,
        )
        # results['metadatas'] is a list of lists: [[{...}, {...}], ...]
        # We only have one query, so grab the first list.
        metadatas = results['metadatas'][0] if results['metadatas'] else []
        return [m['name'] for m in metadatas if m and 'name' in m]


class FullVectorDB:
    """Persistent vector index mirroring the rows of the reservation database."""

    def __init__(self, chroma_path="db/chroma",
                 sqlite_path="db/restaurant_reservation.db"):
        """Open (or create) the collections and populate them on first use.

        Args:
            chroma_path: Directory of the persistent Chroma store.
            sqlite_path: Path of the SQLite database the rows are read from.
        """
        self.sqlite_path = sqlite_path
        self.client = chromadb.PersistentClient(path=chroma_path)
        self.model = SentenceTransformer(
            "all-MiniLM-L6-v2",
            cache_folder="/data/hf_cache/transformers",
        )

        # Get existing collections or create them if they do not exist.
        self.restaurants_col = self.client.get_or_create_collection("restaurants")
        self.tables_col = self.client.get_or_create_collection("tables")
        self.slots_col = self.client.get_or_create_collection("slots")

        # Initialize only if the restaurants collection is empty; count()
        # avoids fetching every stored record just to test for emptiness.
        # NOTE(review): because of this guard the changelog-based incremental
        # sync below only ever runs against an empty restaurants collection.
        if self.restaurants_col.count() == 0:
            self._initialize_collections()

    def _row_to_text(self, row):
        """Join all non-null values of ``row`` into one space-separated string."""
        return ' '.join(str(v) for v in row.values if pd.notnull(v))

    def _row_hash(self, row):
        """Return the SHA-256 hex digest of the string form of ``row.values``.

        The exact input string is left unchanged so that hashes already
        stored in ``chroma_changelog`` keep matching.
        """
        return hashlib.sha256(str(row.values).encode()).hexdigest()

    @staticmethod
    def _row_metadata(row):
        """Convert ``row`` into a plain dict suitable for Chroma metadata.

        Null values (``None``/``NaN``) are dropped and NumPy scalars are
        unwrapped to native Python values.
        """
        metadata = {}
        for key, value in row.items():
            if pd.isnull(value):
                continue
            # NumPy scalars expose .item(), which returns the Python value.
            if hasattr(value, "item"):
                value = value.item()
            metadata[str(key)] = value
        return metadata

    def _initialize_collections(self):
        """Create the changelog table and sync every table into Chroma."""
        # closing() guarantees the connection is released even when a sync
        # step raises part-way through.
        with closing(sqlite3.connect(self.sqlite_path)) as conn:
            # Changelog kept alongside the data: one content hash per
            # (table, record) that has been written to Chroma.
            conn.execute("""
                CREATE TABLE IF NOT EXISTS chroma_changelog (
                    id INTEGER PRIMARY KEY,
                    table_name TEXT,
                    record_id INTEGER,
                    content_hash TEXT,
                    UNIQUE(table_name, record_id)
                )
            """)
            conn.commit()

            self._process_table(conn, "restaurants", self.restaurants_col)
            self._process_table(conn, "tables", self.tables_col)
            self._process_table(conn, "slots", self.slots_col)

    def _process_table(self, conn, table_name, collection):
        """Embed new or changed rows of ``table_name`` into ``collection``.

        Args:
            conn: Open ``sqlite3`` connection holding both the data table and
                ``chroma_changelog``.
            table_name: Name of the SQLite table to read. It is interpolated
                into the SQL text, so it must be a trusted identifier (the
                callers in this class pass literals only).
            collection: Chroma collection receiving the embeddings.
        """
        df = pd.read_sql(f"SELECT * FROM {table_name}", conn)

        # Load the recorded hashes for this table once instead of issuing
        # one SELECT per row.
        known_hashes = dict(
            conn.execute(
                "SELECT record_id, content_hash FROM chroma_changelog "
                "WHERE table_name = ?",
                (table_name,),
            )
        )

        for _, row in df.iterrows():
            # For all-integer tables row['id'] is a numpy.int64, which
            # sqlite3 does not bind as an INTEGER; use a plain int.
            record_id = int(row['id'])
            chroma_id = f"{table_name}_{row['id']}"
            current_hash = self._row_hash(row)

            # Skip rows whose content has not changed since the last sync.
            if known_hashes.get(record_id) == current_hash:
                continue

            embedding = self.model.encode(self._row_to_text(row))

            collection.upsert(
                ids=[chroma_id],
                embeddings=[embedding.tolist()],
                metadatas=[self._row_metadata(row)],
            )

            conn.execute("""
                INSERT OR REPLACE INTO chroma_changelog
                (table_name, record_id, content_hash)
                VALUES (?, ?, ?)
            """, (table_name, record_id, current_hash))

        conn.commit()

    def semantic_search(self, query, collection_name, k=5):
        """Return the metadata of the ``k`` records closest to ``query``.

        Args:
            query: Free-text search string.
            collection_name: ``"restaurants"``, ``"tables"`` or ``"slots"``.
            k: Maximum number of results.

        Returns:
            A list of metadata dicts, or an empty list when the query
            produced no metadata.

        Raises:
            AttributeError: If ``collection_name`` does not name one of the
                collections held by this instance.
        """
        query_embedding = self.model.encode(query).tolist()
        collection = getattr(self, f"{collection_name}_col")
        results = collection.query(
            query_embeddings=[query_embedding],
            n_results=k,
            include=["metadatas"],
        )
        # One query was issued, so the matches are in the first inner list.
        metadatas = results['metadatas']
        return metadatas[0] if metadatas else []