File size: 5,256 Bytes
8cb8290
 
d92f48e
8cb8290
 
 
 
d92f48e
 
 
 
 
 
 
 
 
8cb8290
 
d92f48e
 
8cb8290
 
d92f48e
8cb8290
d92f48e
 
 
8cb8290
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d92f48e
8cb8290
 
d92f48e
 
8cb8290
 
d92f48e
 
 
8cb8290
 
 
 
 
 
d92f48e
 
8cb8290
ecb8959
 
 
8cb8290
 
 
 
 
 
 
 
 
 
 
 
 
d92f48e
 
8cb8290
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
"""Retriever that generates and executes structured queries over its own data source.

NOTE: This code is adapted from the original implementation in the LangChain repo,
but has been modified to work with the KTH QA system.

"""

from langchain.vectorstores import Pinecone, VectorStore
from langchain.schema import BaseRetriever, Document
from langchain.retrievers.self_query.pinecone import PineconeTranslator
from langchain.chains.query_constructor.schema import AttributeInfo
from langchain.chains.query_constructor.ir import StructuredQuery, Visitor
from langchain.chains.query_constructor.base import load_query_constructor_chain
from langchain.base_language import BaseLanguageModel
from langchain import LLMChain
from pydantic import BaseModel, Field, root_validator
import re
from typing import Any, Dict, List, Optional, Type, cast
import logging
# Module-level logger. Named after the module (instead of using the root
# logger) so its records can be filtered/configured on their own; they still
# propagate to whatever handlers are attached to the root logger.
logger = logging.getLogger(__name__)


# Course code: 2-3 letters (either case), 3-4 digits, and an optional trailing
# word character. The pattern has no word-boundary anchors, so it can also
# match inside a longer alphanumeric token.
COURSE_PATTERN = r"[a-zA-Z]{2,3}\d{3,4}\w?"  # e.g. DD1315


def make_uppercase(matchobj):
    """Return the whole text of a regex match converted to upper case.

    Used as the replacement callable for ``re.sub``.
    """
    matched_text = matchobj[0]
    return matched_text.upper()


def _get_builtin_translator(vectorstore_cls: Type[VectorStore]) -> Visitor:
    """Get a translator instance for the given vector store class.

    The class's MRO is walked, so a subclass of a supported vector store
    (e.g. a customised ``Pinecone`` subclass) is matched to the same
    translator as its base class.

    Args:
        vectorstore_cls: Class of the vector store that will be queried.

    Returns:
        A newly constructed translator for that vector store type.

    Raises:
        ValueError: If no built-in translator supports ``vectorstore_cls``.
    """
    BUILTIN_TRANSLATORS: Dict[Type[VectorStore], Type[Visitor]] = {
        Pinecone: PineconeTranslator
    }
    # Most specific class first, so an exact match wins over a base class.
    for klass in vectorstore_cls.__mro__:
        translator_cls = BUILTIN_TRANSLATORS.get(klass)
        if translator_cls is not None:
            return translator_cls()
    raise ValueError(
        f"Self query retriever with Vector Store type {vectorstore_cls}"
        f" not supported."
    )


class SelfQueryRetriever(BaseRetriever, BaseModel):
    """Retriever that wraps around a vector store and uses an LLM to generate
    the vector store queries.

    A structured query is only generated when the user query contains a
    substring matching ``COURSE_PATTERN``. Any other query is sent straight to
    the vector store with the default ``search_kwargs``.
    """

    vectorstore: VectorStore
    """The underlying vector store from which documents will be retrieved."""
    llm_chain: LLMChain
    """The LLMChain for generating the vector store queries."""
    search_type: str = "similarity"
    """The search type to perform on the vector store."""
    search_kwargs: dict = Field(default_factory=dict)
    """Keyword arguments to pass in to the vector store search."""
    structured_query_translator: Visitor
    """Translator for turning internal query language into vectorstore search params."""
    # When True, the structured query generated for course-code queries is logged.
    verbose: bool = False

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def validate_translator(cls, values: Dict) -> Dict:
        """Fill in ``structured_query_translator`` from the vector store type.

        Runs before field validation. If no translator was supplied, one is
        looked up for the class of ``vectorstore``. If ``vectorstore`` itself
        is missing, nothing is done here. Pydantic then reports the missing
        field instead of this validator failing with a bare ``KeyError``.
        """
        if "structured_query_translator" not in values and "vectorstore" in values:
            vectorstore_cls = values["vectorstore"].__class__
            values["structured_query_translator"] = _get_builtin_translator(
                vectorstore_cls
            )
        return values

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Get documents relevant for a query.

        If ``query`` contains course codes, they are upper-cased and the LLM
        chain builds a structured query. The translator's search kwargs for
        that query are merged over ``self.search_kwargs``. Otherwise
        ``self.search_kwargs`` is used unchanged.

        Args:
            query: string to find relevant documents for

        Returns:
            List of relevant documents
        """
        search_kwargs = self.search_kwargs
        if re.search(COURSE_PATTERN, query):
            # Upper-case every course code, e.g. "dd1315" -> "DD1315".
            query = re.sub(COURSE_PATTERN, make_uppercase, query)
            inputs = self.llm_chain.prep_inputs(query)
            structured_query = cast(
                StructuredQuery,
                self.llm_chain.predict_and_parse(callbacks=None, **inputs),
            )
            if self.verbose:
                logger.info(
                    "Found course pattern in query, using structured query:")
                logger.info(structured_query)
            # NOTE(review): the translated query text is discarded; the search
            # below still uses the (upper-cased) user query and only applies
            # the translator's kwargs. Upstream LangChain searches with the
            # translated query -- confirm this divergence is intentional.
            _, new_kwargs = self.structured_query_translator.visit_structured_query(
                structured_query
            )
            search_kwargs = {**self.search_kwargs, **new_kwargs}
        return self.vectorstore.search(query, self.search_type, **search_kwargs)

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        """Async retrieval is not supported by this retriever.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "SelfQueryRetriever does not support async retrieval; "
            "use get_relevant_documents instead."
        )

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        vectorstore: VectorStore,
        document_contents: str,
        metadata_field_info: List[AttributeInfo],
        structured_query_translator: Optional[Visitor] = None,
        chain_kwargs: Optional[Dict] = None,
        **kwargs: Any,
    ) -> "SelfQueryRetriever":
        """Build a retriever whose query-constructor chain is created from ``llm``.

        Args:
            llm: Language model used by the query-constructor chain.
            vectorstore: Vector store to search.
            document_contents: Description of the documents, passed to
                ``load_query_constructor_chain``.
            metadata_field_info: Metadata attribute descriptions, passed to
                ``load_query_constructor_chain``.
            structured_query_translator: Translator to use. Defaults to the
                built-in translator for ``vectorstore``'s class.
            chain_kwargs: Extra keyword arguments for
                ``load_query_constructor_chain``. If ``allowed_comparators``
                or ``allowed_operators`` are absent, they default to the
                translator's values. The caller's dict is not modified.
            **kwargs: Additional fields forwarded to the retriever constructor
                (e.g. ``search_kwargs``, ``verbose``).

        Returns:
            The configured retriever.

        Raises:
            ValueError: If no translator is given and ``vectorstore``'s type
                has no built-in translator.
        """
        if structured_query_translator is None:
            structured_query_translator = _get_builtin_translator(
                vectorstore.__class__)
        # Work on a copy so filling in defaults never mutates the caller's dict.
        chain_kwargs = dict(chain_kwargs or {})
        if "allowed_comparators" not in chain_kwargs:
            chain_kwargs[
                "allowed_comparators"
            ] = structured_query_translator.allowed_comparators
        if "allowed_operators" not in chain_kwargs:
            chain_kwargs[
                "allowed_operators"
            ] = structured_query_translator.allowed_operators
        llm_chain = load_query_constructor_chain(
            llm, document_contents, metadata_field_info, **chain_kwargs
        )
        return cls(
            llm_chain=llm_chain,
            vectorstore=vectorstore,
            structured_query_translator=structured_query_translator,
            **kwargs,
        )