"""ChatData — Streamlit app for querying/asking over ArXiv and Wikipedia
knowledge bases stored in MyScale, via two retrieval styles:

* Vector SQL (LLM writes a SQL query with vector search), and
* Self-Query Retriever (LLM builds structured metadata filters).

Secrets required (st.secrets): OPENAI_API_BASE, OPENAI_API_KEY,
MYSCALE_HOST, MYSCALE_PORT, MYSCALE_USER, MYSCALE_PASSWORD.
"""
from prompts.arxiv_prompt import combine_prompt_template, _myscale_prompt
from callbacks.arxiv_callbacks import ChatDataSelfSearchCallBackHandler, \
    ChatDataSelfAskCallBackHandler, ChatDataSQLSearchCallBackHandler, \
    ChatDataSQLAskCallBackHandler
from chains.arxiv_chains import ArXivQAwithSourcesChain, ArXivStuffDocumentChain
from chains.arxiv_chains import VectorSQLRetrieveCustomOutputParser
from chains.arxiv_chains import MyScaleWithoutMetadataJson
from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain
from langchain_experimental.retrievers.vector_sql_database import \
    VectorSQLDatabaseChainRetriever
from langchain.utilities.sql_database import SQLDatabase
from langchain.chains import LLMChain
from sqlalchemy import create_engine, MetaData
# NOTE: the original also had `from langchain.prompts.prompt import PromptTemplate`,
# a duplicate binding of the same class already imported below — removed.
from langchain.prompts import PromptTemplate, ChatPromptTemplate, \
    SystemMessagePromptTemplate, HumanMessagePromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain import OpenAI
from langchain.chains.query_constructor.base import AttributeInfo, VirtualColumnName
from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain.retrievers.self_query.myscale import MyScaleTranslator
from langchain.embeddings import HuggingFaceInstructEmbeddings, \
    SentenceTransformerEmbeddings
from langchain.vectorstores import MyScaleSettings
import re
import pandas as pd
from os import environ
import streamlit as st
import datetime

environ['TOKENIZERS_PARALLELISM'] = 'true'
environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']

st.set_page_config(page_title="ChatData")
st.header("ChatData")

# Model used to write queries (SQL / structured filters)...
# query_model_name = "gpt-3.5-turbo-instruct"
query_model_name = "text-davinci-003"
# ...and model used to answer questions over the retrieved documents.
chat_model_name = "gpt-3.5-turbo-16k"


def hint_arxiv():
    """Show usage hints for the ArXiv knowledge base."""
    st.info("We provide you metadata columns below for query. "
            "Please choose a natural expression to describe filters on those columns.\n\n"
            "For example: \n\n"
            "*If you want to search papers with complex filters*:\n\n"
            "- What is a Bayesian network? Please use articles published later than Feb 2018 "
            "and with more than 2 categories and whose title like `computer` and must have "
            "`cs.CV` in its category.\n\n"
            "*If you want to ask questions based on papers in database*:\n\n"
            "- What is PageRank?\n"
            "- Did Geoffrey Hinton write a paper about Capsule Neural Networks?\n"
            "- Introduce some applications of GANs published around 2019.\n"
            "- 请根据 2019 年左右的文章介绍一下 GAN 的应用都有哪些\n"
            "- Veuillez présenter les applications du GAN sur la base des articles autour de 2019 ?\n"
            "- Is it possible to synthesize room temperature super conductive material?")


def hint_sql_arxiv():
    """Show the ArXiv table schema and button hints for the Vector SQL tab."""
    st.info("You can retrieve papers with button `Query` or ask questions based on "
            "retrieved papers with button `Ask`.", icon='💡')
    st.markdown('''```sql
CREATE TABLE default.ChatArXiv (
    `abstract` String,
    `id` String,
    `vector` Array(Float32),
    `metadata` Object('JSON'),
    `pubdate` DateTime,
    `title` String,
    `categories` Array(String),
    `authors` Array(String),
    `comment` String,
    `primary_category` String,
    VECTOR INDEX vec_idx vector TYPE MSTG('fp16_storage=1', 'metric_type=Cosine', 'disk_mode=3'),
    CONSTRAINT vec_len CHECK length(vector) = 768)
ENGINE = ReplacingMergeTree ORDER BY id
```''')


def hint_wiki():
    """Show usage hints for the Wikipedia knowledge base."""
    st.info("We provide you metadata columns below for query. "
            "Please choose a natural expression to describe filters on those columns.\n\n"
            "For example: \n\n"
            "- Which company did Elon Musk found?\n"
            "- What is Iron Gwazi?\n"
            "- What is a Ring in mathematics?\n"
            "- 苹果的发源地是那里?\n")


def hint_sql_wiki():
    """Show the Wikipedia table schema and button hints for the Vector SQL tab."""
    st.info("You can retrieve papers with button `Query` or ask questions based on "
            "retrieved papers with button `Ask`.", icon='💡')
    st.markdown('''```sql
CREATE TABLE wiki.Wikipedia (
    `id` String,
    `title` String,
    `text` String,
    `url` String,
    `wiki_id` UInt64,
    `views` Float32,
    `paragraph_id` UInt64,
    `langs` UInt32,
    `emb` Array(Float32),
    VECTOR INDEX vec_idx emb TYPE MSTG('fp16_storage=1', 'metric_type=Cosine', 'disk_mode=3'),
    CONSTRAINT emb_len CHECK length(emb) = 768)
ENGINE = ReplacingMergeTree ORDER BY id
```''')


# Per-knowledge-base configuration: DB location, hint renderers, document
# prompt, self-query metadata columns, required columns and embedding model.
sel_map = {
    'Wikipedia': {
        "database": "wiki",
        "table": "Wikipedia",
        "hint": hint_wiki,
        "hint_sql": hint_sql_wiki,
        "doc_prompt": PromptTemplate(
            input_variables=["page_content", "url", "title", "ref_id", "views"],
            template="Title for Doc #{ref_id}: {title}\n\tviews: {views}"
                     "\n\tcontent: {page_content}\nSOURCE: {url}"),
        "metadata_cols": [
            AttributeInfo(
                name="title",
                description="title of the wikipedia page",
                type="string",
            ),
            AttributeInfo(
                name="text",
                description="paragraph from this wiki page",
                type="string",
            ),
            AttributeInfo(
                name="views",
                description="number of views",
                type="float"
            ),
        ],
        "must_have_cols": ['id', 'title', 'url', 'text', 'views'],
        "vector_col": "emb",
        "text_col": "text",
        "metadata_col": "metadata",
        "emb_model": lambda: SentenceTransformerEmbeddings(
            model_name='sentence-transformers/paraphrase-multilingual-mpnet-base-v2',)
    },
    'ArXiv Papers': {
        "database": "default",
        "table": "ChatArXiv",
        "hint": hint_arxiv,
        "hint_sql": hint_sql_arxiv,
        "doc_prompt": PromptTemplate(
            input_variables=["page_content", "id", "title", "ref_id",
                             "authors", "pubdate", "categories"],
            template="Title for Doc #{ref_id}: {title}\n\tAbstract: {page_content}"
                     "\n\tAuthors: {authors}\n\tDate of Publication: {pubdate}"
                     "\n\tCategories: {categories}\nSOURCE: {id}"),
        "metadata_cols": [
            AttributeInfo(
                name=VirtualColumnName(name="pubdate"),
                description="The year the paper is published",
                type="timestamp",
            ),
            AttributeInfo(
                name="authors",
                description="List of author names",
                type="list[string]",
            ),
            AttributeInfo(
                name="title",
                description="Title of the paper",
                type="string",
            ),
            AttributeInfo(
                name="categories",
                description="arxiv categories to this paper",
                type="list[string]"
            ),
            AttributeInfo(
                name="length(categories)",
                description="length of arxiv categories to this paper",
                type="int"
            ),
        ],
        "must_have_cols": ['title', 'id', 'categories', 'abstract', 'authors', 'pubdate'],
        "vector_col": "vector",
        "text_col": "abstract",
        "metadata_col": "metadata",
        "emb_model": lambda: HuggingFaceInstructEmbeddings(
            model_name='hkunlp/instructor-xl',
            embed_instruction="Represent the question for retrieving supporting scientific papers: ")
    }
}


def try_eval(x):
    """Best-effort eval of *x* (with `datetime` in scope); return *x* unchanged on failure.

    NOTE(security): `eval` executes arbitrary expressions. Globals are limited
    to `datetime`, but the input should still be trusted — consider
    `ast.literal_eval` if the values are plain literals.
    """
    try:
        return eval(x, {'datetime': datetime})
    except Exception:  # was a bare `except:` — narrowed so SystemExit etc. propagate
        return x


def display(dataframe, columns_=None, index=None):
    """Render *dataframe* in Streamlit, optionally re-indexed and column-filtered.

    Shows an apologetic message when the frame is empty.
    """
    if len(dataframe) > 0:
        if index:
            # FIX: `set_index` returns a new DataFrame; the original discarded it,
            # so the index was never applied.
            dataframe = dataframe.set_index(index)
        if columns_:
            st.dataframe(dataframe[columns_])
        else:
            st.dataframe(dataframe)
    else:
        st.write("Sorry 😵 we didn't find any articles related to your query.\n\n"
                 "Maybe the LLM is too naughty that does not follow our instruction... \n\n"
                 "Please try again and use verbs that may match the datatype.",
                 unsafe_allow_html=True)


def build_embedding_model(_sel):
    """Instantiate the embedding model configured for knowledge base *_sel*."""
    with st.spinner("Loading Model..."):
        embeddings = sel_map[_sel]["emb_model"]()
    return embeddings


def build_retriever(_sel):
    """Build all retrievers and QA chains for knowledge base *_sel*.

    Returns a dict with keys: metadata_columns, retriever, chain,
    sql_retriever, sql_chain.
    """
    with st.spinner(f"Connecting DB for {_sel}..."):
        myscale_connection = {
            "host": st.secrets['MYSCALE_HOST'],
            "port": st.secrets['MYSCALE_PORT'],
            "username": st.secrets['MYSCALE_USER'],
            "password": st.secrets['MYSCALE_PASSWORD'],
        }

        config = MyScaleSettings(**myscale_connection,
                                 database=sel_map[_sel]["database"],
                                 table=sel_map[_sel]["table"],
                                 column_map={
                                     "id": "id",
                                     "text": sel_map[_sel]["text_col"],
                                     "vector": sel_map[_sel]["vector_col"],
                                     "metadata": sel_map[_sel]["metadata_col"]
                                 })
        doc_search = MyScaleWithoutMetadataJson(
            st.session_state[f"emb_model_{_sel}"], config,
            must_have_cols=sel_map[_sel]['must_have_cols'])

    with st.spinner(f"Building Self Query Retriever for {_sel}..."):
        metadata_field_info = sel_map[_sel]["metadata_cols"]
        retriever = SelfQueryRetriever.from_llm(
            OpenAI(model_name=query_model_name,
                   openai_api_key=st.secrets['OPENAI_API_KEY'], temperature=0),
            doc_search,
            "Scientific papers indexes with abstracts. All in English.",
            metadata_field_info,
            use_original_query=False,
            structured_query_translator=MyScaleTranslator())

    # Shared chat prompt: system instructions + the user's question.
    COMBINE_PROMPT = ChatPromptTemplate.from_strings(
        string_messages=[(SystemMessagePromptTemplate, combine_prompt_template),
                         (HumanMessagePromptTemplate, '{question}')])
    OPENAI_API_KEY = st.secrets['OPENAI_API_KEY']

    with st.spinner(f'Building QA Chain with Self-query for {_sel}...'):
        chain = ArXivQAwithSourcesChain(
            retriever=retriever,
            combine_documents_chain=ArXivStuffDocumentChain(
                llm_chain=LLMChain(
                    prompt=COMBINE_PROMPT,
                    llm=ChatOpenAI(model_name=chat_model_name,
                                   openai_api_key=OPENAI_API_KEY, temperature=0.6),
                ),
                document_prompt=sel_map[_sel]["doc_prompt"],
                document_variable_name="summaries",
            ),
            return_source_documents=True,
            max_tokens_limit=12000,
        )

    with st.spinner(f'Building Vector SQL Database Retriever for {_sel}...'):
        MYSCALE_USER = st.secrets['MYSCALE_USER']
        MYSCALE_PASSWORD = st.secrets['MYSCALE_PASSWORD']
        MYSCALE_HOST = st.secrets['MYSCALE_HOST']
        MYSCALE_PORT = st.secrets['MYSCALE_PORT']
        engine = create_engine(
            f'clickhouse://{MYSCALE_USER}:{MYSCALE_PASSWORD}@{MYSCALE_HOST}:{MYSCALE_PORT}'
            f'/{sel_map[_sel]["database"]}?protocol=https')
        metadata = MetaData(bind=engine)
        PROMPT = PromptTemplate(
            input_variables=["input", "table_info", "top_k"],
            template=_myscale_prompt,
        )
        # Parser turns the LLM's SQL into a vector query with required columns.
        output_parser = VectorSQLRetrieveCustomOutputParser.from_embeddings(
            model=st.session_state[f'emb_model_{_sel}'],
            must_have_columns=sel_map[_sel]["must_have_cols"])
        sql_query_chain = VectorSQLDatabaseChain.from_llm(
            llm=OpenAI(model_name=query_model_name,
                       openai_api_key=OPENAI_API_KEY, temperature=0),
            prompt=PROMPT,
            top_k=10,
            return_direct=True,
            db=SQLDatabase(engine, None, metadata, max_string_length=1024),
            sql_cmd_parser=output_parser,
            native_format=True
        )
        sql_retriever = VectorSQLDatabaseChainRetriever(
            sql_db_chain=sql_query_chain,
            page_content_key=sel_map[_sel]["text_col"])

    with st.spinner(f'Building QA Chain with Vector SQL for {_sel}...'):
        sql_chain = ArXivQAwithSourcesChain(
            retriever=sql_retriever,
            combine_documents_chain=ArXivStuffDocumentChain(
                llm_chain=LLMChain(
                    prompt=COMBINE_PROMPT,
                    llm=ChatOpenAI(model_name=chat_model_name,
                                   openai_api_key=OPENAI_API_KEY, temperature=0.6),
                ),
                document_prompt=sel_map[_sel]["doc_prompt"],
                document_variable_name="summaries",
            ),
            return_source_documents=True,
            max_tokens_limit=12000,
        )

    return {
        "metadata_columns": [
            {'name': m.name.name if type(m.name) is VirtualColumnName else m.name,
             'desc': m.description,
             'type': m.type} for m in metadata_field_info],
        "retriever": retriever,
        "chain": chain,
        "sql_retriever": sql_retriever,
        "sql_chain": sql_chain
    }


@st.cache_resource
def build_all():
    """Build embedding models and retrievers/chains for every knowledge base."""
    sel_map_obj = {}
    for k in sel_map:
        st.session_state[f'emb_model_{k}'] = build_embedding_model(k)
        sel_map_obj[k] = build_retriever(k)
    return sel_map_obj


# FIX: the original guarded on 'retriever', a key that is never written, so the
# guard never fired; check the key that is actually set.
if 'sel_map_obj' not in st.session_state:
    st.session_state["sel_map_obj"] = build_all()

sel = st.selectbox('Choose the knowledge base you want to ask with:',
                   options=['ArXiv Papers', 'Wikipedia'])
sel_map[sel]['hint']()
tab_sql, tab_self_query = st.tabs(['Vector SQL', 'Self-Query Retrievers'])

with tab_sql:
    sel_map[sel]['hint_sql']()
    st.text_input("Ask a question:", key='query_sql')
    cols = st.columns([1, 1, 7])
    cols[0].button("Query", key='search_sql')
    cols[1].button("Ask", key='ask_sql')
    plc_hldr = st.empty()
    if st.session_state.search_sql:
        plc_hldr = st.empty()
        print(st.session_state.query_sql)
        with plc_hldr.expander('Query Log', expanded=True):
            callback = ChatDataSQLSearchCallBackHandler()
            try:
                docs = st.session_state.sel_map_obj[sel]["sql_retriever"].get_relevant_documents(
                    st.session_state.query_sql, callbacks=[callback])
                callback.progress_bar.progress(value=1.0, text="Done!")
                docs = pd.DataFrame(
                    [{**d.metadata, 'abstract': d.page_content} for d in docs])
                display(docs)
            except Exception as e:
                st.write('Oops 😵 Something bad happened...')
                raise e
    if st.session_state.ask_sql:
        plc_hldr = st.empty()
        print(st.session_state.query_sql)
        with plc_hldr.expander('Chat Log', expanded=True):
            callback = ChatDataSQLAskCallBackHandler()
            try:
                ret = st.session_state.sel_map_obj[sel]["sql_chain"](
                    st.session_state.query_sql, callbacks=[callback])
                callback.progress_bar.progress(value=1.0, text="Done!")
                st.markdown(
                    f"### Answer from LLM\n{ret['answer']}\n### References")
                docs = ret['sources']
                docs = pd.DataFrame(
                    [{**d.metadata, 'abstract': d.page_content} for d in docs])
                display(
                    docs, ['ref_id'] + sel_map[sel]["must_have_cols"], index='ref_id')
            except Exception as e:
                st.write('Oops 😵 Something bad happened...')
                raise e

with tab_self_query:
    st.info("You can retrieve papers with button `Query` or ask questions based on "
            "retrieved papers with button `Ask`.", icon='💡')
    st.dataframe(st.session_state.sel_map_obj[sel]["metadata_columns"])
    st.text_input("Ask a question:", key='query_self')
    cols = st.columns([1, 1, 7])
    cols[0].button("Query", key='search_self')
    cols[1].button("Ask", key='ask_self')
    plc_hldr = st.empty()
    if st.session_state.search_self:
        plc_hldr = st.empty()
        print(st.session_state.query_self)
        with plc_hldr.expander('Query Log', expanded=True):
            # (dead `call_back = None` assignment removed)
            callback = ChatDataSelfSearchCallBackHandler()
            try:
                docs = st.session_state.sel_map_obj[sel]["retriever"].get_relevant_documents(
                    st.session_state.query_self, callbacks=[callback])
                print(docs)
                callback.progress_bar.progress(value=1.0, text="Done!")
                docs = pd.DataFrame(
                    [{**d.metadata, 'abstract': d.page_content} for d in docs])
                display(docs, sel_map[sel]["must_have_cols"])
            except Exception as e:
                st.write('Oops 😵 Something bad happened...')
                raise e
    if st.session_state.ask_self:
        plc_hldr = st.empty()
        print(st.session_state.query_self)
        with plc_hldr.expander('Chat Log', expanded=True):
            callback = ChatDataSelfAskCallBackHandler()
            try:
                ret = st.session_state.sel_map_obj[sel]["chain"](
                    st.session_state.query_self, callbacks=[callback])
                callback.progress_bar.progress(value=1.0, text="Done!")
                st.markdown(
                    f"### Answer from LLM\n{ret['answer']}\n### References")
                docs = ret['sources']
                docs = pd.DataFrame(
                    [{**d.metadata, 'abstract': d.page_content} for d in docs])
                display(
                    docs, ['ref_id'] + sel_map[sel]["must_have_cols"], index='ref_id')
            except Exception as e:
                st.write('Oops 😵 Something bad happened...')
                raise e