import math
import os
import pickle
import uuid
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple

import numpy as np

from langchain.docstore.base import AddableMixin, Docstore
from langchain.docstore.document import Document
from langchain.docstore.in_memory import InMemoryDocstore
from langchain.embeddings.base import Embeddings
from langchain.vectorstores import FAISS
from langchain.vectorstores.base import VectorStore
from langchain.vectorstores.utils import maximal_marginal_relevance


class MyFAISS(FAISS):
    def max_marginal_relevance_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filter: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND
        diversity among selected documents.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch before filtering to pass
                to the MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results, with 0 corresponding to
                maximum diversity and 1 to minimum diversity.
                Defaults to 0.5.
            filter: Optional mapping of metadata keys to whitespace-separated
                filter words. A document is kept if any filter word occurs as
                a substring of the corresponding metadata value.

        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        # Over-fetch when a filter is given, since some hits will be discarded.
        _, indices = self.index.search(
            np.array([embedding], dtype=np.float32),
            fetch_k if filter is None else fetch_k * 2,
        )
        if filter is not None:
            filtered_indices = []
            for i in indices[0]:
                if i == -1:
                    # This happens when not enough docs are returned.
                    continue
                _id = self.index_to_docstore_id[i]
                doc = self.docstore.search(_id)
                if not isinstance(doc, Document):
                    raise ValueError(f"Could not find document for id {_id}, got {doc}")
                # Keep the hit if any whitespace-separated filter word is a
                # substring of the corresponding metadata value.
                if any(
                    filter_word in doc.metadata.get(key, "")
                    for key, value in filter.items()
                    for filter_word in value.split()
                ):
                    filtered_indices.append(i)
            indices = np.array([filtered_indices])
        # -1 happens when not enough docs are returned.
        embeddings = [self.index.reconstruct(int(i)) for i in indices[0] if i != -1]
        if not embeddings:
            # Nothing survived the filter (or the index is empty); MMR has
            # no candidates to rank.
            return []
        mmr_selected = maximal_marginal_relevance(
            np.array([embedding], dtype=np.float32),
            embeddings,
            k=k,
            lambda_mult=lambda_mult,
        )
        selected_indices = [indices[0][i] for i in mmr_selected]
        docs = []
        for i in selected_indices:
            if i == -1:
                # This happens when not enough docs are returned.
                continue
            _id = self.index_to_docstore_id[i]
            doc = self.docstore.search(_id)
            if not isinstance(doc, Document):
                raise ValueError(f"Could not find document for id {_id}, got {doc}")
            docs.append(doc)
        return docs

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filter: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND
        diversity among selected documents.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch before filtering (if needed)
                to pass to the MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results, with 0 corresponding to
                maximum diversity and 1 to minimum diversity.
                Defaults to 0.5.
            filter: Optional mapping of metadata keys to whitespace-separated
                filter words; see max_marginal_relevance_search_by_vector.

        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        embedding = self.embedding_function(query)
        return self.max_marginal_relevance_search_by_vector(
            embedding,
            k,
            fetch_k,
            lambda_mult=lambda_mult,
            filter=filter,
            **kwargs,
        )