kth-qa / magic /self_query_retriever.py
erseux's picture
new huggingface structure
history blame
4.98 kB
"""Retriever that generates and executes structured queries over its own data source.
This code is adapted from the original implementation in the LangChain repo,
but has been modified to work with the KTH QA system.
"""
import re
from typing import Any, Dict, List, Optional, Type, cast
from pydantic import BaseModel, Field, root_validator
from langchain import LLMChain
from langchain.base_language import BaseLanguageModel
from langchain.chains.query_constructor.base import load_query_constructor_chain
from langchain.chains.query_constructor.ir import StructuredQuery, Visitor
from langchain.chains.query_constructor.schema import AttributeInfo
from langchain.retrievers.self_query.pinecone import PineconeTranslator
from langchain.schema import BaseRetriever, Document
from langchain.vectorstores import Pinecone, VectorStore
# Regex used to spot a course code inside a free-text query: 2-3 word
# characters, then 3-4 digits, then an optional trailing word character.
# NOTE(review): `\w` also matches digits and underscores, so long digit runs
# (e.g. "12345") match as well -- confirm this looseness is intended.
COURSE_PATTERN = r"\w{2,3}\d{3,4}\w?" # e.g. DD1315
def _get_builtin_translator(vectorstore_cls: Type[VectorStore]) -> Visitor:
    """Get a translator instance corresponding to the vector store class.

    Args:
        vectorstore_cls: Concrete vector store class to look up. The lookup
            is an exact class match; subclasses of a supported class are not
            resolved to their parent's translator.

    Returns:
        A newly constructed translator for ``vectorstore_cls``.

    Raises:
        ValueError: If no built-in translator is registered for the class.
    """
    BUILTIN_TRANSLATORS: Dict[Type[VectorStore], Type[Visitor]] = {
        Pinecone: PineconeTranslator,
    }
    if vectorstore_cls not in BUILTIN_TRANSLATORS:
        raise ValueError(
            f"Self query retriever with Vector Store type {vectorstore_cls}"
            f" not supported."
        )
    # The mapping stores translator classes; instantiate on the way out.
    return BUILTIN_TRANSLATORS[vectorstore_cls]()
class SelfQueryRetriever(BaseRetriever, BaseModel):
    """Retriever that wraps around a vector store and uses an LLM to generate
    the vector store queries.

    The LLM-generated structured query is only used when the incoming query
    text matches ``COURSE_PATTERN``; otherwise the vector store is searched
    with the configured ``search_kwargs`` alone.
    """

    vectorstore: VectorStore
    """The underlying vector store from which documents will be retrieved."""
    llm_chain: LLMChain
    """The LLMChain for generating the vector store queries."""
    search_type: str = "similarity"
    """The search type to perform on the vector store."""
    search_kwargs: dict = Field(default_factory=dict)
    """Keyword arguments to pass in to the vector store search."""
    structured_query_translator: Visitor
    """Translator for turning internal query language into vectorstore search params."""
    # When true, a notice is printed whenever the structured-query path is taken.
    verbose: bool = False

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    # Must run before field validation: `structured_query_translator` has no
    # default, so it has to be filled in before pydantic checks for it.
    @root_validator(pre=True)
    def validate_translator(cls, values: Dict) -> Dict:
        """Fill in a built-in translator when none was supplied.

        Args:
            values: Raw field values passed to the model constructor.

        Returns:
            ``values`` with ``structured_query_translator`` populated.

        Raises:
            ValueError: If the vector store class has no built-in translator.
        """
        if "structured_query_translator" not in values:
            vectorstore_cls = values["vectorstore"].__class__
            values["structured_query_translator"] = _get_builtin_translator(
                vectorstore_cls
            )
        return values

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Get documents relevant for a query.

        Args:
            query: string to find relevant documents for

        Returns:
            List of relevant documents
        """
        if re.search(COURSE_PATTERN, query):
            # The query mentions something that looks like a course code, so
            # let the LLM build a structured query and turn it into
            # vector-store search parameters.
            inputs = self.llm_chain.prep_inputs(query)
            structured_query = cast(
                StructuredQuery,
                self.llm_chain.predict_and_parse(callbacks=None, **inputs),
            )
            if self.verbose:
                print("Found course pattern in query, using structured query:")
            # Only the translated kwargs are used; the search below still runs
            # on the original query text, not the rewritten one.
            _new_query, new_kwargs = (
                self.structured_query_translator.visit_structured_query(
                    structured_query
                )
            )
            # Translator output takes precedence over the configured defaults.
            search_kwargs = {**self.search_kwargs, **new_kwargs}
        else:
            search_kwargs = self.search_kwargs
        docs = self.vectorstore.search(query, self.search_type, **search_kwargs)
        return docs

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        """Async retrieval is not supported by this retriever.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError
def from_llm(
llm: BaseLanguageModel,
vectorstore: VectorStore,
document_contents: str,
metadata_field_info: List[AttributeInfo],
structured_query_translator: Optional[Visitor] = None,
chain_kwargs: Optional[Dict] = None,
**kwargs: Any,
) -> "SelfQueryRetriever":
if structured_query_translator is None:
structured_query_translator = _get_builtin_translator(vectorstore.__class__)
chain_kwargs = chain_kwargs or {}
if "allowed_comparators" not in chain_kwargs:
] = structured_query_translator.allowed_comparators
if "allowed_operators" not in chain_kwargs:
] = structured_query_translator.allowed_operators
llm_chain = load_query_constructor_chain(
llm, document_contents, metadata_field_info, **chain_kwargs
return cls(