|
import streamlit as st |
|
from langchain.callbacks.streamlit.streamlit_callback_handler import ( |
|
StreamlitCallbackHandler, |
|
) |
|
from langchain.schema.output import LLMResult |
|
from sql_formatter.core import format_sql |
|
|
|
|
|
class VectorSQLSearchDBCallBackHandler(StreamlitCallbackHandler):
    """Show vector-SQL generation progress on the current Streamlit page.

    A progress bar advances by ``prog_interval`` on every chain start and on
    every LLM completion. When the LLM output looks like a ``SELECT``
    statement, it is rendered as a formatted SQL code block.

    NOTE(review): ``StreamlitCallbackHandler.__init__`` is not called here
    (presumably because it needs a parent container). Inherited hooks that
    this class does not override may therefore depend on state that was
    never initialised. Confirm that only the overridden hooks matter for the
    chains this handler is attached to.
    """

    def __init__(self) -> None:
        self.progress_bar = st.progress(value=0.0, text="Writing SQL...")
        self.status_bar = st.empty()
        # Current fill level of the progress bar; kept within [0.0, 1.0].
        self.prog_value = 0
        # How far the bar advances per chain start / LLM completion.
        self.prog_interval = 0.2

    def _advance_progress(self, text: str) -> None:
        """Advance the progress bar by one interval and set its label.

        ``st.progress`` rejects float values outside [0.0, 1.0]. The value is
        clamped so that a run with more steps than ``1 / prog_interval`` shows
        a full bar instead of raising.

        Args:
            text: Label shown next to the progress bar.
        """
        self.prog_value = min(1.0, self.prog_value + self.prog_interval)
        self.progress_bar.progress(value=self.prog_value, text=text)

    def on_llm_start(self, serialized, prompts, **kwargs) -> None:
        pass

    def on_llm_end(
        self,
        response: LLMResult,
        *args,
        **kwargs,
    ):
        """Render the generated SQL (if any) and advance the progress bar.

        Args:
            response: LLM result. Only the first generation of the first
                prompt is inspected; an empty result is treated as no SQL.
        """
        generations = response.generations
        text = generations[0][0].text if generations and generations[0] else ""
        # Ignore any leading whitespace (spaces, newlines, tabs) before the
        # keyword. Removing only spaces would miss output like "\nSELECT ...".
        if text.lstrip().upper().startswith("SELECT"):
            st.markdown("### Generated Vector Search SQL Statement \n"
                        "> This sql statement is generated by LLM \n\n")
            st.markdown(f"""```sql\n{format_sql(text, max_len=80)}\n```""")
        self._advance_progress("Searching in DB...")

    def on_chain_start(self, serialized, inputs, **kwargs) -> None:
        """Advance the progress bar, labelled with the starting chain's id.

        Args:
            serialized: Mapping whose ``"id"`` entry is joined with ``"."``
                to form the label. A missing or ``None`` payload falls back
                to ``"unknown"``.
            inputs: Chain inputs (unused).
        """
        ids = (serialized or {}).get("id") or []
        cid = ".".join(ids) if ids else "unknown"
        self._advance_progress(f"Running Chain `{cid}`...")

    def on_chain_end(self, outputs, **kwargs) -> None:
        pass
|
|
|
|
|
class VectorSQLSearchLLMCallBackHandler(VectorSQLSearchDBCallBackHandler):
    """Variant of ``VectorSQLSearchDBCallBackHandler`` with a finer progress step.

    The progress bar advances by 0.1 per step instead of 0.2, and the given
    table name is stored on the handler.
    """

    def __init__(self, table: str) -> None:
        """Create the progress widgets and remember ``table``.

        Args:
            table: Table name, stored as ``self.table``.
        """
        # Reuse the parent's widget setup (progress bar, then empty status
        # slot) instead of duplicating it, and override only what differs.
        super().__init__()
        self.prog_interval = 0.1
        self.table = table
|
|
|
|
|
|