File size: 2,301 Bytes
a01d550
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
"""
modified from https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/tools/retriever.py
"""

from functools import partial
from typing import Callable
from typing import Iterable
from typing import Optional

from langchain.schema import Document
from langchain.tools import Tool
from langchain_core.callbacks.manager import Callbacks
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.pydantic_v1 import Field
from langchain_core.retrievers import BaseRetriever


class RetrieverInput(BaseModel):
    """Input to the retriever."""
    # Argument schema for the tool built by ``get_retriever_tool`` (passed as
    # ``args_schema``). NOTE(review): the class docstring and the field
    # ``description`` are presumably surfaced in the tool's argument schema
    # shown to the model — confirm before rewording either string.
    # Search text handed to the retriever; no default value is given here.
    query: str = Field(description="query to look up in retriever")


def _get_relevant_documents(
    query: str,
    retriever: BaseRetriever,
    format_docs: Callable[[Iterable[Document]], str],
    callbacks: Callbacks = None,
) -> str:
    """Retrieve documents for ``query`` and render them as a single string.

    Args:
        query: Text to look up in the retriever.
        retriever: Retriever that supplies the documents.
        format_docs: Callable that turns the retrieved documents into a string.
        callbacks: Optional callbacks forwarded to the retriever run.

    Returns:
        The string produced by ``format_docs`` for the retrieved documents.
    """
    # ``get_relevant_documents`` is deprecated in langchain-core in favor of
    # the Runnable ``invoke`` API; callbacks now travel via the run config.
    docs = retriever.invoke(query, config={"callbacks": callbacks})
    return format_docs(docs)


async def _aget_relevant_documents(
    query: str,
    retriever: BaseRetriever,
    format_docs: Callable[[Iterable[Document]], str],
    callbacks: Callbacks = None,
) -> str:
    """Asynchronously retrieve documents for ``query`` and render them.

    Args:
        query: Text to look up in the retriever.
        retriever: Retriever that supplies the documents.
        format_docs: Callable that turns the retrieved documents into a string.
        callbacks: Optional callbacks forwarded to the retriever run.

    Returns:
        The string produced by ``format_docs`` for the retrieved documents.
    """
    # ``aget_relevant_documents`` is deprecated in langchain-core in favor of
    # the Runnable ``ainvoke`` API; callbacks now travel via the run config.
    docs = await retriever.ainvoke(query, config={"callbacks": callbacks})
    return format_docs(docs)


def get_retriever_tool(
    retriever: BaseRetriever,
    name: str,
    description: str,
    format_docs: Callable[[Iterable[Document]], str],
) -> Tool:
    """Wrap a retriever as a ``Tool`` whose output is formatted text.

    The tool's arguments are described by ``RetrieverInput``. When run, it
    queries ``retriever`` and renders the resulting documents with
    ``format_docs``. Both a synchronous (``func``) and an asynchronous
    (``coroutine``) implementation are attached.

    Args:
        retriever: The retriever to use for the retrieval.
        name: The tool name. It is passed to the language model, so it
            should be unique and somewhat descriptive.
        description: The tool description. It is passed to the language
            model, so it should be descriptive.
        format_docs: Callable that turns an iterable of documents into a
            single string.

    Returns:
        A ``Tool`` instance to pass to an agent.
    """
    # The sync and async entry points are bound to the same retriever and
    # formatter; only ``query`` (and optional ``callbacks``) remain open.
    bound_kwargs = {"retriever": retriever, "format_docs": format_docs}
    sync_impl = partial(_get_relevant_documents, **bound_kwargs)
    async_impl = partial(_aget_relevant_documents, **bound_kwargs)

    return Tool(
        name=name,
        description=description,
        func=sync_impl,
        coroutine=async_impl,
        args_schema=RetrieverInput,
    )