"""
/*************************************************************************
 *
 * CONFIDENTIAL
 * __________________
 *
 *  Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
 *  All Rights Reserved
 *
 *  Author  : Theekshana Samaradiwakara
 *  Description :Python Backend API to chat with private data
 *  CreatedDate : 14/11/2023
 *  LastModifiedDate : 21/03/2024
 *************************************************************************/
"""

import asyncio
import logging
from typing import List, Optional, Sequence

import numpy as np
import pandas as pd
from langchain.chains.llm import LLMChain
from langchain_core.callbacks import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.retrievers import BaseRetriever

from prompts import MULTY_QUERY_PROMPT

logger = logging.getLogger(__name__)


class LineListOutputParser(BaseOutputParser[List[str]]):
    """Output parser that turns LLM text into a list of lines (one query each)."""

    def parse(self, text: str) -> List[str]:
        """Split *text* into lines, dropping blank lines.

        Args:
            text: raw LLM completion, one generated query per line.

        Returns:
            The non-empty lines, stripped of surrounding whitespace. Blank
            lines are dropped so they never become empty retrieval queries.
        """
        # splitlines() also handles "\r\n" endings, which split("\n") would
        # leave as a trailing "\r" on every query.
        return [line.strip() for line in text.strip().splitlines() if line.strip()]


# Default prompt
# DEFAULT_QUERY_PROMPT = PromptTemplate(
#     input_variables=["question"],
#     template="""You are an AI language model assistant. Your task is
#     to generate 3 different versions of the given user
#     question to retrieve relevant documents from a vector database.
#     By generating multiple perspectives on the user question,
#     your goal is to help the user overcome some of the limitations
#     of distance-based similarity search. Provide these alternative
#     questions separated by newlines.
#     Original question: {question}""",
# )


def _unique_documents(documents: Sequence[Document]) -> List[Document]:
    """Return *documents* without duplicates, keeping first occurrences.

    Documents are compared with ``==`` (not hashed), so this is quadratic in
    the worst case; the input order (retrieval order) is preserved.
    """
    unique: List[Document] = []
    for doc in documents:
        if doc not in unique:
            unique.append(doc)
    return unique


class MultiQueryRetriever(BaseRetriever):
    """Given a query, use an LLM to write a set of queries.

    Retrieve docs for each query and take the unique union of all retrieved
    docs. Both the sync and async paths then order the union newest-first by
    the ``date_key`` metadata field (documents without a usable date go
    last, in retrieval order) and return at most ``top_k`` documents.
    """

    retriever: BaseRetriever
    # Chain that rewrites the user question into alternative queries.
    llm_chain: LLMChain
    # When True, the generated queries are written to the module logger.
    verbose: bool = True
    parser_key: str = "lines"
    """DEPRECATED. parser_key is no longer used and should not be specified."""
    include_original: bool = False
    """Whether to include the original query in the list of generated queries."""
    # Metadata key read from each Document to order results newest-first.
    date_key: str = "year"
    # Maximum number of documents returned per call.
    top_k: int = 4

    @classmethod
    def from_llm(
        cls,
        retriever: BaseRetriever,
        llm: BaseLanguageModel,
        prompt: PromptTemplate = MULTY_QUERY_PROMPT,
        parser_key: Optional[str] = None,
        include_original: bool = False,
    ) -> "MultiQueryRetriever":
        """Initialize from an LLM and a query-generation prompt.

        Args:
            retriever: retriever to query documents from
            llm: llm used for query generation
            prompt: prompt used to generate alternative queries (defaults to
                ``MULTY_QUERY_PROMPT``); it receives a ``question`` input and
                its output is parsed as one query per line
            parser_key: DEPRECATED; accepted for compatibility and ignored
            include_original: Whether to include the original query in the
                list of generated queries.

        Returns:
            MultiQueryRetriever
        """
        output_parser = LineListOutputParser()
        llm_chain = LLMChain(llm=llm, prompt=prompt, output_parser=output_parser)
        return cls(
            retriever=retriever,
            llm_chain=llm_chain,
            include_original=include_original,
        )

    def _sort_and_truncate(self, documents: List[Document]) -> List[Document]:
        """Order *documents* newest-first by ``date_key`` and keep ``top_k``.

        Documents whose ``date_key`` is missing, ``None`` or unparseable by
        ``pd.to_datetime`` are placed after the dated ones, in their original
        order. Documents sharing the same date also keep their original
        (retrieval) order because ``list.sort`` is stable even with
        ``reverse=True``.
        """
        dated = []
        undated = []
        for doc in documents:
            raw = doc.metadata.get(self.date_key)
            stamp = pd.NaT if raw is None else pd.to_datetime(raw, errors="coerce")
            if pd.isna(stamp):
                undated.append(doc)
            else:
                dated.append((stamp, doc))

        if dated:
            # Sort on the timestamp only; Documents themselves are never compared.
            dated.sort(key=lambda pair: pair[0], reverse=True)
            documents = [doc for _, doc in dated] + undated
            logger.info("Documents sorted by year")

        return documents[: self.top_k]

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Get relevant documents given a user query.

        Args:
            query: user query
            run_manager: callback manager for this retriever run

        Returns:
            Unique union of relevant documents from all generated queries,
            ordered newest-first by ``date_key`` and truncated to ``top_k``.
        """
        queries = await self.agenerate_queries(query, run_manager)
        if self.include_original:
            queries.append(query)
        documents = await self.aretrieve_documents(queries, run_manager)
        return self._sort_and_truncate(self.unique_union(documents))

    async def agenerate_queries(
        self, question: str, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[str]:
        """Generate queries based upon user input.

        Args:
            question: user query
            run_manager: callback manager whose child receives the LLM run

        Returns:
            List of LLM generated queries that are similar to the user input
        """
        # Chain.ainvoke takes the input mapping as its first positional
        # argument and reads callbacks from ``config``.
        response = await self.llm_chain.ainvoke(
            {"question": question}, config={"callbacks": run_manager.get_child()}
        )
        lines = response["text"]
        if self.verbose:
            logger.info("Generated queries: %s", lines)
        return lines

    async def aretrieve_documents(
        self, queries: List[str], run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Run all LLM generated queries concurrently.

        Args:
            queries: query list
            run_manager: callback manager whose child receives each retrieval

        Returns:
            List of retrieved Documents, concatenated in query order
        """
        document_lists = await asyncio.gather(
            *(
                self.retriever.aget_relevant_documents(
                    query, callbacks=run_manager.get_child()
                )
                for query in queries
            )
        )
        return [doc for docs in document_lists for doc in docs]

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Get relevant documents given a user query.

        Args:
            query: user query
            run_manager: callback manager for this retriever run

        Returns:
            Unique union of relevant documents from all generated queries,
            ordered newest-first by ``date_key`` and truncated to ``top_k``.
            Returns an empty list when nothing is retrieved.
        """
        queries = self.generate_queries(query, run_manager)
        if self.include_original:
            queries.append(query)
        documents = self.retrieve_documents(queries, run_manager)
        return self._sort_and_truncate(self.unique_union(documents))

    def generate_queries(
        self, question: str, run_manager: CallbackManagerForRetrieverRun
    ) -> List[str]:
        """Generate queries based upon user input.

        Args:
            question: user query
            run_manager: callback manager whose child receives the LLM run

        Returns:
            List of LLM generated queries that are similar to the user input
        """
        # Chain.invoke reads callbacks from ``config``; a bare ``callbacks=``
        # keyword is not forwarded to the chain run.
        response = self.llm_chain.invoke(
            {"question": question}, config={"callbacks": run_manager.get_child()}
        )
        lines = response["text"]
        if self.verbose:
            logger.info("Generated queries: %s", lines)
        return lines

    def retrieve_documents(
        self, queries: List[str], run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Run all LLM generated queries sequentially.

        Args:
            queries: query list
            run_manager: callback manager whose child receives each retrieval

        Returns:
            List of retrieved Documents, concatenated in query order
        """
        documents: List[Document] = []
        for query in queries:
            logger.info("MQ Retriever question: %s", query)
            docs = self.retriever.get_relevant_documents(
                query, callbacks=run_manager.get_child()
            )
            documents.extend(docs)
        return documents

    def unique_union(self, documents: List[Document]) -> List[Document]:
        """Get unique Documents.

        Args:
            documents: List of retrieved Documents

        Returns:
            List of unique retrieved Documents, first occurrences kept in order
        """
        return _unique_documents(documents)