from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
from langchain.schema import BaseRetriever
import streamlit as st

from backend.chains.retrieval_qa_with_sources import CustomRetrievalQAWithSourcesChain
from backend.chains.stuff_documents import CustomStuffDocumentChain
from backend.constants.myscale_tables import MYSCALE_TABLES
from backend.constants.prompts import COMBINE_PROMPT
from backend.constants.variables import GLOBAL_CONFIG


def build_retrieval_qa_with_sources_chain(
        table_name: str,
        retriever: BaseRetriever,
        chain_name: str = ""
) -> CustomRetrievalQAWithSourcesChain:
    """Assemble a retrieval-QA-with-sources chain for one MyScale table.

    A Streamlit spinner is displayed while the chain objects are constructed.

    Args:
        table_name: Key looked up in ``MYSCALE_TABLES``; the entry's
            ``doc_prompt`` is used as the per-document prompt. Also shown in
            the spinner message.
        retriever: Retriever handed to the resulting chain unchanged.
        chain_name: Label that only appears in the spinner message.

    Returns:
        A ``CustomRetrievalQAWithSourcesChain`` built with
        ``return_source_documents=True`` and ``max_tokens_limit=12000``.

    Note:
        Any error raised by the ``MYSCALE_TABLES[table_name]`` lookup is not
        handled here and propagates to the caller.
    """
    spinner_text = f'Building QA source chain named `{chain_name}` for MyScaleDB/{table_name} ...'
    with st.spinner(spinner_text):
        # Chat model configured from the global settings; the sampling
        # temperature is fixed at 0.6.
        chat_llm = ChatOpenAI(
            model_name=GLOBAL_CONFIG.chat_model,
            openai_api_key=GLOBAL_CONFIG.openai_api_key,
            temperature=0.6,
        )
        combine_llm_chain = LLMChain(prompt=COMBINE_PROMPT, llm=chat_llm)

        # Per-table prompt used to render each retrieved document.
        table_doc_prompt = MYSCALE_TABLES[table_name].doc_prompt

        # NOTE(review): per the original comment, this custom stuff chain is
        # what assigns ref_id to documents -- confirm in
        # backend.chains.stuff_documents.
        stuff_chain = CustomStuffDocumentChain(
            llm_chain=combine_llm_chain,
            document_prompt=table_doc_prompt,
            document_variable_name="summaries",
        )

        qa_chain = CustomRetrievalQAWithSourcesChain(
            retriever=retriever,
            combine_documents_chain=stuff_chain,
            return_source_documents=True,
            max_tokens_limit=12000,
        )

    return qa_chain