Spaces:
Runtime error
Runtime error
File size: 3,082 Bytes
58d33f0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
"""Tools for interacting with vectorstores."""
import json
from typing import Any, Dict
from pydantic import BaseModel, Field
from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChain
from langchain.chains.retrieval_qa.base import VectorDBQA
from langchain.llms.base import BaseLLM
from langchain.llms.openai import OpenAI
from langchain.tools.base import BaseTool
from langchain.vectorstores.base import VectorStore
def _default_llm() -> BaseLLM:
    """Build the LLM used when a vector-store tool is created without ``llm``."""
    return OpenAI(temperature=0)


class BaseVectorStoreTool(BaseModel):
    """Common fields for tools backed by a VectorStore.

    Concrete tools mix this model with ``BaseTool``. It contributes the
    vector store to query and the LLM used by the question-answering chain.
    """

    # Store the tool queries. Required, and excluded from pydantic export.
    vectorstore: VectorStore = Field(exclude=True)
    # LLM driving the chain. Each instance gets its own
    # OpenAI(temperature=0) unless one is passed in.
    llm: BaseLLM = Field(default_factory=_default_llm)

    class Config(BaseTool.Config):
        """Pydantic configuration, extending the one from ``BaseTool``."""

        # Permit field types pydantic cannot validate natively
        # (presumably VectorStore) — TODO confirm which types require it.
        arbitrary_types_allowed = True
def _create_description_from_template(values: Dict[str, Any]) -> Dict[str, Any]:
values["description"] = values["template"].format(name=values["name"])
return values
class VectorStoreQATool(BaseVectorStoreTool, BaseTool):
    """Tool that answers questions over a vector store via a ``VectorDBQA`` chain.

    Presumably initialized with the ``BaseTool`` fields (``name``,
    ``description`` — see :meth:`get_description`) plus the inherited
    ``vectorstore`` and optional ``llm``. No chain is passed in; one is
    built from ``llm`` and ``vectorstore`` on every call.
    """

    @staticmethod
    def get_description(name: str, description: str) -> str:
        """Return the standard description text for this tool.

        Args:
            name: Subject the vector store covers.
            description: The kind of information the store holds.

        Returns:
            The description string, telling the agent to pass a fully formed
            question as input.
        """
        template: str = (
            "Useful for when you need to answer questions about {name}. "
            "Whenever you need information about {description} "
            "you should ALWAYS use this. "
            "Input should be a fully formed question."
        )
        return template.format(name=name, description=description)

    def _run(self, query: str) -> str:
        """Answer ``query`` with a VectorDBQA chain over ``self.vectorstore``."""
        # The chain is rebuilt on every call from the current llm/vectorstore.
        chain = VectorDBQA.from_chain_type(self.llm, vectorstore=self.vectorstore)
        return chain.run(query)

    async def _arun(self, query: str) -> str:
        """Async execution is not supported; always raises NotImplementedError."""
        # Report the real class name (the old message named a non-existent
        # "VectorDBQATool").
        raise NotImplementedError(f"{type(self).__name__} does not support async")
class VectorStoreQAWithSourcesTool(BaseVectorStoreTool, BaseTool):
    """Tool that answers questions and cites sources via ``VectorDBQAWithSourcesChain``.

    Configured like the plain QA tool: the inherited ``vectorstore`` and
    optional ``llm`` are used to build the chain on every call.
    """

    @staticmethod
    def get_description(name: str, description: str) -> str:
        """Return the standard description text for this tool.

        Args:
            name: Subject the vector store covers.
            description: The kind of information the store holds.

        Returns:
            The description string. It tells the agent that the output is a
            JSON object with ``answer`` and ``sources``, and that the tool
            should only be used when sources are explicitly requested.
        """
        template: str = (
            "Useful for when you need to answer questions about {name} and the sources "
            "used to construct the answer. "
            "Whenever you need information about {description} "
            "you should ALWAYS use this. "
            # No leading space here: the previous fragment already ends with
            # one (the original produced a double space).
            "Input should be a fully formed question. "
            "Output is a json serialized dictionary with keys `answer` and `sources`. "
            "Only use this tool if the user explicitly asks for sources."
        )
        return template.format(name=name, description=description)

    def _run(self, query: str) -> str:
        """Run the chain on ``query`` and return its outputs as a JSON string."""
        chain = VectorDBQAWithSourcesChain.from_chain_type(
            self.llm, vectorstore=self.vectorstore
        )
        # return_only_outputs=True presumably drops the echoed input key,
        # leaving just the chain's outputs to serialize.
        return json.dumps(chain({chain.question_key: query}, return_only_outputs=True))

    async def _arun(self, query: str) -> str:
        """Async execution is not supported; always raises NotImplementedError."""
        # Report the real class name (the old message named a non-existent
        # "VectorDBQATool").
        raise NotImplementedError(f"{type(self).__name__} does not support async")
|