hf-legisqa / retriever_tools.py
gabrielaltay's picture
update
a01d550
raw
history blame
2.3 kB
"""
Helpers for exposing a retriever as a LangChain ``Tool``.

Modified from https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/tools/retriever.py
"""
from functools import partial
from typing import Callable
from typing import Iterable
from typing import Optional
from langchain.schema import Document
from langchain.tools import Tool
from langchain_core.callbacks.manager import Callbacks
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.pydantic_v1 import Field
from langchain_core.retrievers import BaseRetriever
class RetrieverInput(BaseModel):
    """Input to the retriever."""

    # The single argument the tool accepts: the text handed to the retriever.
    query: str = Field(
        description="query to look up in retriever",
    )
def _get_relevant_documents(
    query: str,
    retriever: BaseRetriever,
    format_docs: Callable[[Iterable[Document]], str],
    callbacks: Callbacks = None,
) -> str:
    """Run ``query`` through ``retriever`` and render the results as a string.

    Args:
        query: Text to look up.
        retriever: Retriever to query.
        format_docs: Callable that turns the retrieved documents into a string.
        callbacks: Optional callbacks forwarded to the retriever run.

    Returns:
        Whatever ``format_docs`` returns for the retrieved documents.
    """
    # ``invoke`` is the Runnable entry point; ``get_relevant_documents`` is
    # deprecated in langchain-core, so callbacks travel via the run config.
    docs = retriever.invoke(query, config={"callbacks": callbacks})
    return format_docs(docs)
async def _aget_relevant_documents(
    query: str,
    retriever: BaseRetriever,
    format_docs: Callable[[Iterable[Document]], str],
    callbacks: Callbacks = None,
) -> str:
    """Asynchronously run ``query`` through ``retriever`` and render the results.

    Args:
        query: Text to look up.
        retriever: Retriever to query.
        format_docs: Callable that turns the retrieved documents into a string.
        callbacks: Optional callbacks forwarded to the retriever run.

    Returns:
        Whatever ``format_docs`` returns for the retrieved documents.
    """
    # ``ainvoke`` is the async Runnable entry point; ``aget_relevant_documents``
    # is deprecated in langchain-core, so callbacks travel via the run config.
    docs = await retriever.ainvoke(query, config={"callbacks": callbacks})
    return format_docs(docs)
def get_retriever_tool(
    retriever: BaseRetriever,
    name: str,
    description: str,
    format_docs: Callable[[Iterable[Document]], str],
) -> Tool:
    """Build a ``Tool`` that performs document retrieval.

    Args:
        retriever: Retriever the tool queries.
        name: Tool name. It is passed to the language model, so it should be
            unique and reasonably descriptive.
        description: Tool description. It is passed to the language model, so
            it should explain what the tool is for.
        format_docs: Callable that converts an iterable of documents into a
            single string.

    Returns:
        A ``Tool`` suitable for handing to an agent.
    """
    # The sync and async entry points share the same pre-bound arguments,
    # leaving ``query`` (and optionally ``callbacks``) for call time.
    bound = {"retriever": retriever, "format_docs": format_docs}
    return Tool(
        name=name,
        description=description,
        func=partial(_get_relevant_documents, **bound),
        coroutine=partial(_aget_relevant_documents, **bound),
        args_schema=RetrieverInput,
    )