"""Streamlit callback handlers that visualise ChatData chain/agent progress."""

import json
import textwrap
from typing import Any, Dict, List

import streamlit as st
from langchain.callbacks.streamlit.streamlit_callback_handler import (
    LLMThought,
    StreamlitCallbackHandler,
)
from langchain.schema.output import LLMResult
from sql_formatter.core import format_sql


def _clamp_progress(value: float) -> float:
    """Clamp *value* into ``[0.0, 1.0]``, the range ``st.progress`` accepts.

    Several handlers below advance their progress by a fixed increment per
    callback event. The number of events depends on the chain being run, so
    an unclamped running total can overshoot 1.0.
    """
    return max(0.0, min(1.0, value))


class ChatDataSelfSearchCallBackHandler(StreamlitCallbackHandler):
    """Progress reporter for the self-query search flow.

    Draws a single progress bar and, when a chain reports a ``repr`` output,
    shows it as the generated filter.

    NOTE(review): ``StreamlitCallbackHandler.__init__`` is intentionally not
    called here (or in the other progress-bar handlers below). Any base-class
    callback that is not overridden and relies on base-class state will
    therefore fail; confirm which events the driving chain emits.
    """

    def __init__(self) -> None:
        self.progress_bar = st.progress(value=0.0, text="Working...")
        # Kept for backward compatibility; not read anywhere in this file.
        self.tokens_stream = ""

    def on_llm_start(self, serialized, prompts, **kwargs) -> None:
        pass

    def on_text(self, text: str, **kwargs) -> None:
        self.progress_bar.progress(value=0.2, text="Asking LLM...")

    def on_chain_end(self, outputs, **kwargs) -> None:
        """Mark the DB-search phase and show the generated filter, if any."""
        self.progress_bar.progress(value=0.6, text="Searching in DB...")
        if isinstance(outputs, dict) and "repr" in outputs:
            st.markdown("### Generated Filter")
            # The filter text comes from the LLM. A fenced code block renders
            # it fine, so raw HTML is not enabled.
            st.markdown(f"```python\n{outputs['repr']}\n```")

    def on_chain_start(self, serialized, inputs, **kwargs) -> None:
        pass


class ChatDataSelfAskCallBackHandler(StreamlitCallbackHandler):
    """Progress reporter for the retrieval QA-with-sources flow."""

    def __init__(self) -> None:
        self.progress_bar = st.progress(value=0.0, text="Searching DB...")
        self.status_bar = st.empty()
        self.prog_value = 0.0
        # Fixed checkpoints for known top-level chains; every nested
        # LLMChain start adds 0.1 on top of the current value instead.
        self.prog_map = {
            "langchain.chains.qa_with_sources.retrieval.RetrievalQAWithSourcesChain": 0.2,
            "langchain.chains.combine_documents.map_reduce.MapReduceDocumentsChain": 0.4,
            "langchain.chains.combine_documents.stuff.StuffDocumentsChain": 0.8,
        }

    def on_llm_start(self, serialized, prompts, **kwargs) -> None:
        pass

    def on_text(self, text: str, **kwargs) -> None:
        pass

    def on_chain_start(self, serialized, inputs, **kwargs) -> None:
        """Advance the bar according to which chain is starting."""
        cid = ".".join(serialized["id"])
        if cid == "langchain.chains.llm.LLMChain":
            next_value = self.prog_value + 0.1
        else:
            # Chains missing from ``prog_map`` keep the current progress
            # rather than raising KeyError.
            next_value = self.prog_map.get(cid, self.prog_value)
        # Map-reduce may start many LLMChains; clamp so st.progress accepts it.
        self.prog_value = _clamp_progress(next_value)
        self.progress_bar.progress(
            value=self.prog_value, text=f"Running Chain `{cid}`..."
        )

    def on_chain_end(self, outputs, **kwargs) -> None:
        pass


class ChatDataSQLSearchCallBackHandler(StreamlitCallbackHandler):
    """Progress reporter for the Vector-SQL search flow.

    Each chain start and each LLM completion advances the bar by
    ``prog_interval``. A completion whose text is a SELECT statement is also
    rendered as formatted SQL.
    """

    def __init__(self, prog_interval: float = 0.2) -> None:
        """Create the progress widgets.

        Args:
            prog_interval: Amount the bar advances per chain start / LLM end.
        """
        self.progress_bar = st.progress(value=0.0, text="Writing SQL...")
        self.status_bar = st.empty()
        self.prog_value = 0
        self.prog_interval = prog_interval

    def _advance(self, text: str) -> None:
        """Step the bar by ``prog_interval`` (clamped to 1.0) and relabel it."""
        self.prog_value = _clamp_progress(self.prog_value + self.prog_interval)
        self.progress_bar.progress(value=self.prog_value, text=text)

    def on_llm_start(self, serialized, prompts, **kwargs) -> None:
        pass

    def on_llm_end(
        self,
        response: LLMResult,
        *args,
        **kwargs,
    ) -> None:
        """Show generated SQL (if the first generation is a SELECT) and advance."""
        generations = response.generations
        text = generations[0][0].text if generations and generations[0] else ""
        # Ignore all whitespace, not only spaces, so SQL that starts after a
        # newline or tab is still recognised.
        if "".join(text.split()).upper().startswith("SELECT"):
            st.write("We generated Vector SQL for you:")
            st.markdown(f"""```sql\n{format_sql(text, max_len=80)}\n```""")
            print(f"Vector SQL: {text}")
        self._advance("Searching in DB...")

    def on_chain_start(self, serialized, inputs, **kwargs) -> None:
        cid = ".".join(serialized["id"])
        self._advance(f"Running Chain `{cid}`...")

    def on_chain_end(self, outputs, **kwargs) -> None:
        pass


class ChatDataSQLAskCallBackHandler(ChatDataSQLSearchCallBackHandler):
    """Vector-SQL handler for the ask flow, stepping in finer increments."""

    def __init__(self, prog_interval: float = 0.1) -> None:
        super().__init__(prog_interval=prog_interval)


class LLMThoughtWithKB(LLMThought):
    """LLMThought that renders knowledge-base tool output as a document list."""

    def on_tool_end(
        self,
        output: str,
        color=None,
        observation_prefix=None,
        llm_prefix=None,
        **kwargs: Any,
    ) -> None:
        """Render ``output`` as numbered, shortened ``page_content`` entries.

        ``output`` is expected to be a JSON list of objects with a
        ``page_content`` string. Anything else falls back to the default
        ``LLMThought`` rendering.
        """
        try:
            docs = json.loads(output)
            lines = ["### Retrieved Documents:"] + [
                f"**{i + 1}**: {textwrap.shorten(doc['page_content'], width=80)}"
                for i, doc in enumerate(docs)
            ]
        except (ValueError, TypeError, KeyError, AttributeError):
            # Not JSON (ValueError), wrong shape (TypeError/KeyError), or a
            # non-string page_content (AttributeError): use default rendering.
            super().on_tool_end(
                output, color, observation_prefix, llm_prefix, **kwargs
            )
            return
        self._container.markdown("\n\n".join(lines))


class ChatDataAgentCallBackHandler(StreamlitCallbackHandler):
    """Agent handler whose thoughts render retrieved documents nicely."""

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Start (or continue) a thought, using ``LLMThoughtWithKB``."""
        if self._current_thought is None:
            self._current_thought = LLMThoughtWithKB(
                parent_container=self._parent_container,
                expanded=self._expand_new_thoughts,
                collapse_on_complete=self._collapse_completed_thoughts,
                labeler=self._thought_labeler,
            )
        self._current_thought.on_llm_start(serialized, prompts)