import streamlit as st
from langchain.callbacks.streamlit.streamlit_callback_handler import (
    StreamlitCallbackHandler,
)
from langchain.schema.output import LLMResult
from sql_formatter.core import format_sql


class VectorSQLSearchDBCallBackHandler(StreamlitCallbackHandler):
    """Streamlit callback handler that reports vector-SQL search progress.

    A progress bar is created on construction. It is advanced by
    ``prog_interval`` on every chain start and whenever the LLM output is a
    ``SELECT`` statement; that statement is also rendered as formatted SQL.
    """

    def __init__(self) -> None:
        # NOTE(review): StreamlitCallbackHandler.__init__ is not called here,
        # so the parent's internal state is never set up; only the progress
        # widgets below exist. Callbacks not overridden in this class fall
        # through to the parent — confirm they are not triggered.
        self.progress_bar = st.progress(value=0.0, text="Writing SQL...")
        self.status_bar = st.empty()
        self.prog_value = 0
        self.prog_interval = 0.2

    def _advance(self, text: str) -> None:
        """Advance the progress bar by one ``prog_interval`` step.

        The value is capped at 1.0 because ``st.progress`` rejects float
        values above 1.0; without the cap, a run with more than
        ``1 / prog_interval`` steps would raise inside the callback.

        Args:
            text: Label shown next to the progress bar.
        """
        self.prog_value = min(self.prog_value + self.prog_interval, 1.0)
        self.progress_bar.progress(value=self.prog_value, text=text)

    def on_llm_start(self, serialized, prompts, **kwargs) -> None:
        # Intentional no-op: replaces the parent's implementation, which
        # presumably relies on state from the skipped parent __init__.
        pass

    def on_llm_end(
        self,
        response: LLMResult,
        *args,
        **kwargs,
    ) -> None:
        """Render the first generation as SQL when it is a SELECT statement.

        Args:
            response: LLM result; only ``generations[0][0].text`` is read.
        """
        # Nothing to show if the LLM produced no generations.
        if not response.generations or not response.generations[0]:
            return
        sql = response.generations[0][0].text.strip()
        # Leading whitespace/newlines are common in LLM output; ignore them.
        if sql.upper().startswith("SELECT"):
            st.markdown("### Generated Vector Search SQL Statement \n"
                        "> This sql statement is generated by LLM \n\n")
            st.markdown(f"""```sql\n{format_sql(sql, max_len=80)}\n```""")
            self._advance("Searching in DB...")

    def on_chain_start(self, serialized, inputs, **kwargs) -> None:
        """Advance the progress bar, labelled with the chain's dotted id.

        Args:
            serialized: Serialized chain description; its ``"id"`` entry
                (a list of strings) is joined with ``.`` for the label.
            inputs: Chain inputs (unused).
        """
        # Guard against a missing/None ``serialized`` or a missing ``"id"``.
        ids = (serialized or {}).get("id") or []
        cid = ".".join(ids) if ids else "unknown"
        self._advance(f"Running Chain `{cid}`...")

    def on_chain_end(self, outputs, **kwargs) -> None:
        # Intentional no-op: replaces the parent's implementation.
        pass


class VectorSQLSearchLLMCallBackHandler(VectorSQLSearchDBCallBackHandler):
    """Variant of the DB handler with a finer progress step (0.1).

    It also stores a table name.
    """

    def __init__(self, table: str) -> None:
        # Same widgets, created in the same order, as the parent handler.
        super().__init__()
        self.prog_interval = 0.1
        # Not read anywhere in this file; presumably consumed by callers
        # (TODO confirm).
        self.table = table