Spaces:
Running
Running
""" | |
modified from https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/tools/retriever.py | |
""" | |
from functools import partial | |
from typing import Callable | |
from typing import Iterable | |
from typing import Optional | |
from langchain.schema import Document | |
from langchain.tools import Tool | |
from langchain_core.callbacks.manager import Callbacks | |
from langchain_core.pydantic_v1 import BaseModel | |
from langchain_core.pydantic_v1 import Field | |
from langchain_core.retrievers import BaseRetriever | |
class RetrieverInput(BaseModel):
    """Input to the retriever."""

    # Sole argument of the retriever tool. The docstring above and the field
    # description below end up in the generated JSON schema that is shown to
    # the model, so their text is part of the tool's runtime contract.
    query: str = Field(
        description="query to look up in retriever",
    )
def _get_relevant_documents(
    query: str,
    retriever: BaseRetriever,
    format_docs: Callable[[Iterable[Document]], str],
    callbacks: Callbacks = None,
) -> str:
    """Synchronously retrieve documents for ``query`` and format them.

    Args:
        query: Free-text query forwarded to the retriever.
        retriever: Retriever that produces the candidate documents.
        format_docs: Turns the retrieved documents into a single string.
        callbacks: Optional callbacks attached to this retrieval run.

    Returns:
        The string produced by ``format_docs`` for the retrieved documents.
    """
    # ``get_relevant_documents`` is deprecated in langchain-core in favour of
    # the Runnable ``invoke`` API; callbacks are passed via the run config.
    docs = retriever.invoke(query, config={"callbacks": callbacks})
    return format_docs(docs)
async def _aget_relevant_documents(
    query: str,
    retriever: BaseRetriever,
    format_docs: Callable[[Iterable[Document]], str],
    callbacks: Callbacks = None,
) -> str:
    """Asynchronously retrieve documents for ``query`` and format them.

    Args:
        query: Free-text query forwarded to the retriever.
        retriever: Retriever that produces the candidate documents.
        format_docs: Turns the retrieved documents into a single string.
        callbacks: Optional callbacks attached to this retrieval run.

    Returns:
        The string produced by ``format_docs`` for the retrieved documents.
    """
    # ``aget_relevant_documents`` is deprecated in langchain-core in favour of
    # the Runnable ``ainvoke`` API; callbacks are passed via the run config.
    docs = await retriever.ainvoke(query, config={"callbacks": callbacks})
    return format_docs(docs)
def get_retriever_tool(
    retriever: BaseRetriever,
    name: str,
    description: str,
    format_docs: Callable[[Iterable[Document]], str],
) -> Tool:
    """Wrap ``retriever`` as a ``Tool`` that returns formatted documents.

    The returned tool takes a single ``query`` argument (see
    ``RetrieverInput``). It exposes both a synchronous and an asynchronous
    entry point, and both pipe the retrieved documents through ``format_docs``.

    Args:
        retriever: Retriever queried whenever the tool is called.
        name: Tool name shown to the language model; keep it unique and
            reasonably descriptive.
        description: Tool description shown to the language model; it should
            explain when the tool is useful.
        format_docs: Callable that renders an iterable of documents as one
            string.

    Returns:
        A ``Tool`` instance ready to hand to an agent.
    """
    # Both the sync and async runners are bound to the same retriever and
    # formatter; only ``query`` (and optional ``callbacks``) remain unbound.
    bound_kwargs = {"retriever": retriever, "format_docs": format_docs}
    return Tool(
        name=name,
        description=description,
        func=partial(_get_relevant_documents, **bound_kwargs),
        coroutine=partial(_aget_relevant_documents, **bound_kwargs),
        args_schema=RetrieverInput,
    )