|
|
|
from typing import List |
|
import numpy as np |
|
from chromadb import Client |
|
import openai |
|
from config.constants import DEEPINFRA_MODEL_TAG, DEEPINFRA_ENDPOINT_URL |
|
import os |
|
# API key for the DeepInfra endpoint, read from the environment once at import
# time; it is None when the variable is not set.
DEEPINFRA_API_KEY = os.environ.get("DEEPINFRA_API_KEY")
|
|
|
class SearchService:
    """Embedding-based search over listings stored in a Chroma collection.

    Listing vectors are supplied precomputed through ``ingest_data``. A text
    query is embedded through the OpenAI-compatible endpoint configured by
    ``DEEPINFRA_ENDPOINT_URL`` / ``DEEPINFRA_MODEL_TAG`` and matched against
    the stored vectors by ``search``.
    """

    def __init__(self):
        """Create the Chroma client and open the listing collection."""
        self.client = Client()
        self.collection_name = "listing_collection"
        # get_or_create_collection instead of create_collection: the latter
        # raises when a collection with this name already exists, which makes
        # a second SearchService() against the same client state fail.
        self.collection = self.client.get_or_create_collection(
            name=self.collection_name,
            metadata={
                'description': 'real_estate_listing',
                "hnsw:construction_ef": 64,
                "hnsw:M": 32,
                "hnsw:search_ef": 32,
            },
            # Vectors are always passed in explicitly (see ingest_data/search),
            # so no embedding function is attached to the collection.
            embedding_function=None,
        )

    def ingest_data(self, embd_id):
        """Add a batch of listing embeddings to the collection.

        Args:
            embd_id: 2-D array-like. Column 0 holds the numeric listing id;
                the remaining columns hold that listing's embedding vector.
                An empty batch is accepted and ignored.

        Each row is stored with the metadata ``{"original_id": "PTFS<id>"}``,
        which is the value ``search`` returns.
        """
        rows = np.asarray(embd_id)
        if rows.size == 0:
            # Nothing to store; avoid handing Chroma an empty batch.
            return

        # Plain nested lists are the form the Chroma documentation shows for
        # embeddings, so convert rather than pass the ndarray through.
        vectors = rows[:, 1:].astype(float).tolist()
        listing_ids = [f"PTFS{num}" for num in rows[:, 0].astype('int64')]

        # NOTE(review): record ids are positions within this batch, so a
        # second call generates "0", "1", ... again. Confirm how the caller
        # expects repeated ingests to behave before relying on that.
        record_ids = [str(position) for position in range(len(listing_ids))]

        self.collection.add(
            ids=record_ids,
            embeddings=vectors,
            metadatas=[{"original_id": listing_id} for listing_id in listing_ids],
        )

    def search(self, query: str, n_results: int = 10) -> List[str]:
        """Return the ``original_id`` of the listings closest to ``query``.

        Args:
            query: Free-text search query; it is embedded remotely.
            n_results: Maximum number of listings to return (default 10).

        Returns:
            The ``original_id`` metadata values of the matches, in the order
            the collection query returns them. Empty when the collection
            holds no records.

        Raises:
            ValueError: If ``n_results`` is smaller than 1.
        """
        if n_results < 1:
            raise ValueError("n_results must be at least 1")

        # Checked before embedding the query so an empty collection does not
        # cost a remote API call.
        stored = self.collection.count()
        if stored == 0:
            return []

        openai.api_key = DEEPINFRA_API_KEY
        openai.api_base = DEEPINFRA_ENDPOINT_URL

        response = openai.Embedding.create(
            input=query,
            model=DEEPINFRA_MODEL_TAG,
            encoding_format="float",
        )
        query_embedding = response.data[0].embedding

        results = self.collection.query(
            query_embeddings=[query_embedding],
            # Never ask for more neighbours than there are stored records.
            n_results=min(n_results, stored),
            # Only the metadata is read below.
            include=["metadatas"],
        )

        # One query embedding was sent, so the matches are in the first row.
        return [metadata["original_id"] for metadata in results["metadatas"][0]]
|
|