|
import logging |
|
import inspect |
|
from typing import Dict, Any, Optional, List, Tuple |
|
|
|
|
|
from langchain.callbacks.manager import ( |
|
AsyncCallbackManagerForChainRun, |
|
CallbackManagerForChainRun, |
|
) |
|
from langchain.embeddings.base import Embeddings |
|
from langchain.callbacks.manager import Callbacks |
|
from langchain.schema.prompt_template import format_document |
|
from langchain.docstore.document import Document |
|
from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain |
|
from langchain.vectorstores.myscale import MyScale, MyScaleSettings |
|
from langchain.chains.combine_documents.stuff import StuffDocumentsChain |
|
|
|
from langchain_experimental.sql.vector_sql import VectorSQLOutputParser |
|
|
|
logger = logging.getLogger() |
|
|
|
|
|
class MyScaleWithoutMetadataJson(MyScale):
    """MyScale vector store variant that skips the metadata-JSON column.

    Instead of parsing a JSON metadata column, it SELECTs the extra columns
    listed in ``must_have_cols`` on every query and exposes them as each
    Document's metadata.
    """

    def __init__(
        self,
        embedding: Embeddings,
        config: Optional[MyScaleSettings] = None,
        must_have_cols: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the store.

        Args:
            embedding: Embedding function used by the base MyScale store.
            config: MyScale connection/table settings (base-class default
                applies when None).
            must_have_cols: Extra column names selected on every query and
                copied into each returned Document's metadata. Defaults to
                no extra columns.

        Note:
            The original signature used a mutable default (``[]``), which is
            shared across all instances; ``None`` + copy avoids that.
        """
        super().__init__(embedding, config, **kwargs)
        self.must_have_cols: List[str] = list(must_have_cols or [])

    def _build_qstr(
        self, q_emb: List[float], topk: int, where_str: Optional[str] = None
    ) -> str:
        """Build the vector-similarity SQL string for MyScale.

        Args:
            q_emb: Query embedding vector.
            topk: Maximum number of rows to return.
            where_str: Optional filter expression (without the PREWHERE
                keyword); ignored when falsy.

        Returns:
            A complete SELECT statement ordered by vector distance.
        """
        q_emb_str = ",".join(map(str, q_emb))
        # PREWHERE is MyScale/ClickHouse-specific early filtering.
        where_str = f"PREWHERE {where_str}" if where_str else ""

        q_str = f"""
            SELECT {self.config.column_map['text']}, dist, {','.join(self.must_have_cols)}
            FROM {self.config.database}.{self.config.table}
            {where_str}
            ORDER BY distance({self.config.column_map['vector']}, [{q_emb_str}])
                AS dist {self.dist_order}
            LIMIT {topk}
            """
        return q_str

    def similarity_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        where_str: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Search by a raw embedding vector.

        Args:
            embedding: Query embedding.
            k: Maximum number of documents to return.
            where_str: Optional filter expression passed to ``_build_qstr``.

        Returns:
            Up to ``k`` Documents; on any query failure the error is logged
            (with ANSI coloring) and an empty list is returned — this
            best-effort behavior is deliberate.
        """
        q_str = self._build_qstr(embedding, k, where_str)
        try:
            return [
                Document(
                    page_content=r[self.config.column_map["text"]],
                    # `col` used instead of `k` to avoid shadowing the
                    # `k` (top-k) parameter inside the comprehension.
                    metadata={col: r[col] for col in self.must_have_cols},
                )
                for r in self.client.query(q_str).named_results()
            ]
        except Exception as e:
            logger.error(
                f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
            return []
|
|
|
|
|
class VectorSQLRetrieveCustomOutputParser(VectorSQLOutputParser):
    """Based on VectorSQLOutputParser.

    It also modifies the generated SQL so the SELECT clause returns exactly
    ``must_have_columns`` instead of whatever the LLM produced.
    """

    # Column names forced into the SELECT clause of every parsed statement.
    must_have_columns: List[str]

    @property
    def _type(self) -> str:
        return "vector_sql_retrieve_custom"

    def parse(self, text: str) -> Dict[str, Any]:
        """Replace the SELECT column list, then delegate to the base parser.

        Fix over the original: when the statement had no FROM clause,
        ``find`` returned -1 and the slice ``text[start+7:-2]`` silently
        corrupted the SQL. The rewrite now only happens when FROM appears
        after SELECT; otherwise the text is passed through unchanged.

        Args:
            text: Raw (possibly LLM-generated) SQL text.

        Returns:
            Whatever the base ``VectorSQLOutputParser.parse`` returns for
            the (possibly rewritten) SQL.
        """
        text = text.strip()
        start = text.upper().find("SELECT")
        if start >= 0:
            end = text.upper().find("FROM")
            # Only rewrite when a FROM clause follows SELECT.
            if end > start:
                text = text.replace(
                    text[start + len("SELECT") + 1: end - 1],
                    ", ".join(self.must_have_columns),
                )
        return super().parse(text)
|
|
|
|
|
class ArXivStuffDocumentChain(StuffDocumentsChain):
    """Combine arXiv documents, tagging each with a PDF reference number."""

    def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
        """Construct inputs from kwargs and docs.

        Formats all documents and joins them into one input named
        ``self.document_variable_name``, then plucks any additional
        required variables from ``**kwargs``.

        Note:
            Each doc is mutated in place: ``metadata['ref_id']`` is set to
            the doc's position, and newlines in ``page_content`` are
            flattened to spaces. Downstream ``ArXivQAwithSourcesChain``
            relies on ``ref_id`` to resolve citations.

        Args:
            docs: List of documents to format and then join into a single
                input.
            **kwargs: Additional inputs to the chain; any other required
                arguments are plucked from here.

        Returns:
            Dictionary of inputs to LLMChain.
        """
        doc_strings = []
        for doc_id, doc in enumerate(docs):
            # Positional reference id used later to resolve "#<id>" citations.
            doc.metadata.update({'ref_id': doc_id})
            doc.page_content = doc.page_content.replace('\n', ' ')
            doc_strings.append(format_document(doc, self.document_prompt))

        # Forward only the kwargs the prompt actually declares.
        inputs = {
            k: v
            for k, v in kwargs.items()
            if k in self.llm_chain.prompt.input_variables
        }
        inputs[self.document_variable_name] = self.document_separator.join(
            doc_strings)
        return inputs

    def combine_docs(
        self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
    ) -> Tuple[str, dict]:
        """Stuff all documents into one prompt and pass to the LLM.

        Args:
            docs: List of documents to join together into one variable.
            callbacks: Optional callbacks to pass along.
            **kwargs: Additional parameters used to build the LLMChain
                inputs.

        Returns:
            A tuple of (single string output, dictionary of other keys to
            return — always empty here).
        """
        inputs = self._get_inputs(docs, **kwargs)
        output = self.llm_chain.predict(callbacks=callbacks, **inputs)
        return output, {}

    @property
    def _chain_type(self) -> str:
        return "referenced_stuff_documents_chain"
|
|
|
|
|
class ArXivQAwithSourcesChain(RetrievalQAWithSourcesChain):
    """QA with source chain for Chat ArXiv app with references

    This chain will automatically assign reference number to the article,
    Then parse it back to titles or anything else.
    """

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        # Run retrieval + combine, then rewrite the LLM's "#<ref_id>" citations
        # into "<title> [<n>]" and collect the cited docs as sources.
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        # Some _get_docs implementations do not accept run_manager; detect
        # support from the signature rather than failing at call time.
        accepts_run_manager = (
            "run_manager" in inspect.signature(self._get_docs).parameters
        )
        if accepts_run_manager:
            docs = self._get_docs(inputs, run_manager=_run_manager)
        else:
            docs = self._get_docs(inputs)

        answer = self.combine_documents_chain.run(
            input_documents=docs, callbacks=_run_manager.get_child(), **inputs
        )

        # Renumber citations in order of appearance in `docs`: ref_cnt counts
        # only the docs actually cited in the answer. `ref_id` was assigned by
        # the combine chain (see ArXivStuffDocumentChain._get_inputs).
        sources = []
        ref_cnt = 1
        for d in docs:
            ref_id = d.metadata['ref_id']
            # Normalize "Doc #i" to "#i" so a single replacement pass suffices.
            if f"Doc #{ref_id}" in answer:
                answer = answer.replace(f"Doc #{ref_id}", f"#{ref_id}")
            if f"#{ref_id}" in answer:
                # Titles are flattened to one line for inline display.
                title = d.metadata['title'].replace('\n', '')
                # NOTE(review): mutates the doc's ref_id to the new citation
                # number; also "#1" would match inside "#10" — presumably the
                # retrieved doc count stays small enough. TODO confirm.
                d.metadata['ref_id'] = ref_cnt
                answer = answer.replace(f"#{ref_id}", f"{title} [{ref_cnt}]")
                sources.append(d)
                ref_cnt += 1

        result: Dict[str, Any] = {
            self.answer_key: answer,
            # Cited documents (list), not a string — hence the Dict[str, Any]
            # return annotation.
            self.sources_answer_key: sources,
        }
        if self.return_source_documents:
            result["source_documents"] = docs
        return result

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        # Async execution is intentionally unsupported.
        raise NotImplementedError

    @property
    def _chain_type(self) -> str:
        return "arxiv_qa_with_sources_chain"
|
|