"""
/*************************************************************************
 *
 * CONFIDENTIAL
 * __________________
 *
 *  Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
 *  All Rights Reserved
 *
 *  Author  : Theekshana Samaradiwakara
 *  Description :Python Backend API to chat with private data
 *  CreatedDate : 14/11/2023
 *  LastModifiedDate : 18/03/2024
 *************************************************************************/
"""

import os
import time
import logging

from dotenv import load_dotenv
from fastapi import HTTPException

from llmChain import get_qa_chain, get_general_qa_chain, get_router_chain
from output_parser import general_qa_chain_output_parser, qa_chain_output_parser, out_of_domain_chain_parser
from config import QA_MODEL_TYPE, GENERAL_QA_MODEL_TYPE, ROUTER_MODEL_TYPE, Multi_Query_MODEL_TYPE
from retriever import load_faiss_retriever, load_ensemble_retriever, load_multi_query_retriever

logger = logging.getLogger(__name__)

load_dotenv()

# Not referenced anywhere in this module.
verbose = os.environ.get('VERBOSE')

qa_model_type = QA_MODEL_TYPE
general_qa_model_type = GENERAL_QA_MODEL_TYPE
router_model_type = ROUTER_MODEL_TYPE  # "google/flan-t5-xxl"
multi_query_model_type = Multi_Query_MODEL_TYPE  # "google/flan-t5-xxl"

# The retriever and chains are built once, at import time, and reused by every
# call to the run_* functions below. Alternative retrievers that were tried:
# load_faiss_retriever() and load_multi_query_retriever(multi_query_model_type).
retriever = load_ensemble_retriever()
logger.info("retriever loaded:")

qa_chain = get_qa_chain(qa_model_type, retriever)
general_qa_chain = get_general_qa_chain(general_qa_model_type)
router_chain = get_router_chain(router_model_type)


def _elapsed(start):
    """Return seconds since the ``time.perf_counter()`` reading *start*, rounded to 2 places."""
    return round(time.perf_counter() - start, 2)


def chain_selector(chain_type, query):
    """Dispatch *query* to the chain matching the router's *chain_type* label.

    The label is lower-cased and stripped, then matched by substring in this
    order: ``"greeting"`` -> general QA chain, ``"other"`` -> out-of-domain
    response, ``"relevant"`` or ``"not sure"`` -> retrieval QA chain.

    Args:
        chain_type: Label text produced by the router chain.
        query: The user's question.

    Returns:
        The parsed output of whichever chain was selected.

    Raises:
        ValueError: If *chain_type* is not a string or matches no known label.
    """
    # Guard against a router that returns no text (e.g. None): fail with the
    # same ValueError as an unknown label instead of an AttributeError.
    if not isinstance(chain_type, str):
        raise ValueError(f"Received invalid type '{chain_type}'")

    chain_type = chain_type.lower().strip()
    logger.info("chain_selector : chain_type: %s Question: %s", chain_type, query)

    # NOTE(review): substring matching means labels such as "irrelevant" or
    # "not relevant" would also reach the QA chain; confirm against the
    # router prompt's label set.
    if "greeting" in chain_type:
        return run_general_qa_chain(query)
    if "other" in chain_type:
        return run_out_of_domain_chain(query)
    if "relevant" in chain_type or "not sure" in chain_type:
        return run_qa_chain(query)
    raise ValueError(f"Received invalid type '{chain_type}'")


def run_agent(query):
    """Route *query* with the router chain, then answer it with the selected chain.

    Args:
        query: The user's question.

    Returns:
        Whatever the selected chain's output parser returns.

    Raises:
        Any exception from routing or answering (including FastAPI's
        ``HTTPException``) is logged with its traceback and re-raised unchanged.
    """
    try:
        logger.info("run_agent : Question: %s", query)
        print(f"---------------- run_agent : Question: {query} ----------------")

        start = time.perf_counter()
        chain_type = run_router_chain(query)
        res = chain_selector(chain_type, query)
        elapsed = _elapsed(start)

        # A successful answer is informational; it was previously logged at ERROR.
        logger.info("---------------- Answer (took %s s.) \n: %s", elapsed, res)
        print(f" \n ---------------- Answer (took {elapsed} s.): -------------- \n")
        return res
    except Exception:
        # HTTPException subclasses Exception, so this single handler covers both
        # cases; bare `raise` re-raises the same exception with its traceback.
        logger.exception("run_agent failed for question: %s", query)
        raise


def run_router_chain(query):
    """Classify *query* with the router chain.

    Args:
        query: The user's question.

    Returns:
        The ``'text'`` field of the router chain's output (the route label).

    Raises:
        Any exception from the router chain, after logging it.
    """
    try:
        logger.info("run_router_chain : Question: %s", query)
        start = time.perf_counter()
        chain_type = router_chain.invoke(query)['text']
        logger.info("Answer (took %s s.) chain_type: %s", _elapsed(start), chain_type)
        return chain_type
    except Exception:
        logger.exception("run_router_chain failed for question: %s", query)
        raise


def run_qa_chain(query):
    """Answer *query* with the retrieval QA chain and parse the result.

    The chain is invoked with an empty ``chat_history``, so each question is
    answered without conversational context.

    Args:
        query: The user's question.

    Returns:
        The value produced by ``qa_chain_output_parser``.

    Raises:
        Any exception from the chain or the parser, after logging it.
    """
    try:
        logger.info("run_qa_chain : Question: %s", query)
        start = time.perf_counter()
        res = qa_chain.invoke({"question": query, "chat_history": ""})
        logger.info("Answer (took %s s.) \n: %s", _elapsed(start), res)
        return qa_chain_output_parser(res)
    except Exception:
        logger.exception("run_qa_chain failed for question: %s", query)
        raise


def run_general_qa_chain(query):
    """Answer *query* with the general (non-retrieval) QA chain and parse the result.

    Args:
        query: The user's question.

    Returns:
        The value produced by ``general_qa_chain_output_parser``.

    Raises:
        Any exception from the chain or the parser, after logging it.
    """
    try:
        logger.info("run_general_qa_chain : Question: %s", query)
        start = time.perf_counter()
        res = general_qa_chain.invoke(query)
        logger.info("Answer (took %s s.) \n: %s", _elapsed(start), res)
        return general_qa_chain_output_parser(res)
    except Exception:
        logger.exception("run_general_qa_chain failed for question: %s", query)
        raise


def run_out_of_domain_chain(query):
    """Return the out-of-domain response for *query*; no LLM chain is invoked here."""
    return out_of_domain_chain_parser(query)