|
"""Retriever that generates and executes structured queries over its own data source. |
|
|
|
NOTE: This code is adapted from the original implementation in the LangChain repo, |
|
but has been modified to work with the KTH QA system. |
|
|
|
""" |
|
|
|
from langchain.vectorstores import Pinecone, VectorStore |
|
from langchain.schema import BaseRetriever, Document |
|
from langchain.retrievers.self_query.pinecone import PineconeTranslator |
|
from langchain.chains.query_constructor.schema import AttributeInfo |
|
from langchain.chains.query_constructor.ir import StructuredQuery, Visitor |
|
from langchain.chains.query_constructor.base import load_query_constructor_chain |
|
from langchain.base_language import BaseLanguageModel |
|
from langchain import LLMChain |
|
from pydantic import BaseModel, Field, root_validator |
|
import re |
|
from typing import Any, Dict, List, Optional, Type, cast |
|
import logging |
|
# Module-level logger named after this module (rather than the root logger),
# so its records can be configured/filtered per module while still
# propagating to whatever handlers are attached to the root logger.
logger = logging.getLogger(__name__)

# Regex for course-code-like tokens: 2-3 ASCII letters, 3-4 digits and an
# optional trailing word character (matches e.g. "DD1310" or "sf1625").
# Presumably intended for KTH course codes — confirm against the course list.
# It is not anchored with word boundaries, so it can also match inside a
# longer token.
COURSE_PATTERN = r"[a-zA-Z]{2,3}\d{3,4}\w?"
|
|
|
|
|
def make_uppercase(matchobj: "re.Match") -> str:
    """Return the whole matched text of *matchobj* in upper case.

    Meant to be passed as the ``repl`` callable of ``re.sub``/``re.subn`` so
    each match is replaced by its upper-cased form.
    """
    matched_text = matchobj[0]
    return matched_text.upper()
|
|
|
|
|
def _get_builtin_translator(vectorstore_cls: Type[VectorStore]) -> Visitor:
    """Instantiate the structured-query translator registered for a vector store class.

    The lookup matches the exact class only (subclasses are not resolved).
    At present the only registered store is ``Pinecone``.

    Args:
        vectorstore_cls: class of the vector store to translate queries for.

    Returns:
        A new translator instance for ``vectorstore_cls``.

    Raises:
        ValueError: if no translator is registered for ``vectorstore_cls``.
    """
    known_translators: Dict[Type[VectorStore], Type[Visitor]] = {
        Pinecone: PineconeTranslator,
    }
    translator_cls = known_translators.get(vectorstore_cls)
    if translator_cls is None:
        raise ValueError(
            f"Self query retriever with Vector Store type {vectorstore_cls}"
            " not supported."
        )
    return translator_cls()
|
|
|
|
|
class SelfQueryRetriever(BaseRetriever, BaseModel):
    """Retriever that wraps around a vector store and uses an LLM to generate
    the vector store queries.

    The LLM is only consulted when the user query contains a token matching
    ``COURSE_PATTERN``; other queries go straight to the vector store with
    ``search_kwargs`` unchanged.
    """

    vectorstore: VectorStore
    """The underlying vector store from which documents will be retrieved."""
    llm_chain: LLMChain
    """The LLMChain for generating the vector store queries."""
    search_type: str = "similarity"
    """The search type to perform on the vector store."""
    search_kwargs: dict = Field(default_factory=dict)
    """Keyword arguments to pass in to the vector store search."""
    structured_query_translator: Visitor
    """Translator for turning internal query language into vectorstore search params."""
    # When True, the structured query produced by the LLM is logged at INFO level.
    verbose: bool = False

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def validate_translator(cls, values: Dict) -> Dict:
        """Derive ``structured_query_translator`` from the vector store type.

        Runs before field validation on the raw constructor arguments. A
        translator is filled in only when none (or ``None``) was supplied and
        a ``vectorstore`` is present. If ``vectorstore`` is missing, nothing is
        done here and the omission is left to pydantic's field validation
        instead of surfacing as a bare ``KeyError``.

        Raises:
            ValueError: (via ``_get_builtin_translator``) if no built-in
                translator exists for the vector store's class.
        """
        needs_translator = values.get("structured_query_translator") is None
        if needs_translator and "vectorstore" in values:
            values["structured_query_translator"] = _get_builtin_translator(
                values["vectorstore"].__class__
            )
        return values

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Get documents relevant for a query.

        Every substring matching ``COURSE_PATTERN`` is upper-cased. If at least
        one was found, the LLM chain builds a structured query. The translator
        turns it into search kwargs, which are merged over (and take precedence
        over) ``self.search_kwargs``. Otherwise ``self.search_kwargs`` is used
        unchanged.

        Args:
            query: string to find relevant documents for

        Returns:
            List of relevant documents
        """
        # A single pass both detects course codes and upper-cases them.
        query, n_course_codes = re.subn(COURSE_PATTERN, make_uppercase, query)
        if n_course_codes:
            inputs = self.llm_chain.prep_inputs(query)
            structured_query = cast(
                StructuredQuery,
                self.llm_chain.predict_and_parse(callbacks=None, **inputs),
            )
            if self.verbose:
                logger.info(
                    "Found course pattern in query, using structured query:")
                logger.info(structured_query)
            # The translator's rewritten query text is discarded: the search
            # below uses the (upper-cased) user query and only the translated
            # kwargs. NOTE(review): confirm discarding it is intentional.
            _, new_kwargs = self.structured_query_translator.visit_structured_query(
                structured_query
            )
            search_kwargs = {**self.search_kwargs, **new_kwargs}
        else:
            search_kwargs = self.search_kwargs
        return self.vectorstore.search(query, self.search_type, **search_kwargs)

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        """Async retrieval is not implemented for this retriever.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "SelfQueryRetriever does not support async retrieval; "
            "use get_relevant_documents instead."
        )

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        vectorstore: VectorStore,
        document_contents: str,
        metadata_field_info: List[AttributeInfo],
        structured_query_translator: Optional[Visitor] = None,
        chain_kwargs: Optional[Dict] = None,
        **kwargs: Any,
    ) -> "SelfQueryRetriever":
        """Build a retriever whose query-constructor chain is loaded for ``llm``.

        Args:
            llm: language model used by the query-constructor chain.
            vectorstore: vector store to search.
            document_contents: passed to ``load_query_constructor_chain``.
            metadata_field_info: passed to ``load_query_constructor_chain``.
            structured_query_translator: translator to use; when ``None`` it
                is derived from the class of ``vectorstore``.
            chain_kwargs: extra keyword arguments for
                ``load_query_constructor_chain``. ``allowed_comparators`` and
                ``allowed_operators`` default to the translator's values when
                absent. The caller's dict is never modified.
            **kwargs: forwarded to the class constructor.

        Returns:
            A configured ``SelfQueryRetriever``.

        Raises:
            ValueError: if no translator is given and none is built in for
                the vector store's class.
        """
        if structured_query_translator is None:
            structured_query_translator = _get_builtin_translator(
                vectorstore.__class__)
        # Copy so the defaults filled in below never leak into the caller's dict.
        chain_kwargs = dict(chain_kwargs or {})
        # Explicit membership checks (not setdefault) so the translator
        # attribute is only read when a default is actually needed.
        if "allowed_comparators" not in chain_kwargs:
            chain_kwargs["allowed_comparators"] = (
                structured_query_translator.allowed_comparators
            )
        if "allowed_operators" not in chain_kwargs:
            chain_kwargs["allowed_operators"] = (
                structured_query_translator.allowed_operators
            )
        llm_chain = load_query_constructor_chain(
            llm, document_contents, metadata_field_info, **chain_kwargs
        )
        return cls(
            llm_chain=llm_chain,
            vectorstore=vectorstore,
            structured_query_translator=structured_query_translator,
            **kwargs,
        )
|
|