File size: 1,834 Bytes
e931b70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import streamlit as st
from langchain.callbacks.streamlit.streamlit_callback_handler import (
    StreamlitCallbackHandler,
)
from langchain.schema.output import LLMResult
from sql_formatter.core import format_sql


class VectorSQLSearchDBCallBackHandler(StreamlitCallbackHandler):
    """Streamlit callback handler for vector-search SQL generation.

    Renders SQL produced by the LLM as a formatted code block and advances a
    Streamlit progress bar each time a chain starts or SQL is generated.

    NOTE(review): ``StreamlitCallbackHandler.__init__`` is not invoked here,
    so any inherited (non-overridden) callback that relies on state set up
    there may fail if it fires — confirm which hooks the calling chain uses.
    """

    def __init__(self) -> None:
        self.progress_bar = st.progress(value=0.0, text="Writing SQL...")
        # Empty placeholder slot; not used by the methods of this class.
        self.status_bar = st.empty()
        self.prog_value = 0
        # Fraction of the bar added per progress event.
        self.prog_interval = 0.2

    def _advance_progress(self, text: str) -> None:
        """Step the progress bar by ``prog_interval`` and relabel it.

        The value is capped at 1.0 because ``st.progress`` rejects floats
        outside [0.0, 1.0]. Without the cap, enough chain/LLM events would
        push the value past 1.0 and make every later update fail.
        """
        self.prog_value = min(self.prog_value + self.prog_interval, 1.0)
        self.progress_bar.progress(value=self.prog_value, text=text)

    def on_llm_start(self, serialized, prompts, **kwargs) -> None:
        """Do nothing when an LLM call starts."""
        pass

    def on_llm_end(
        self,
        response: LLMResult,
        *args,
        **kwargs,
    ) -> None:
        """Render the generated SQL if the LLM output is a SELECT statement.

        Only the first generation of the first prompt is inspected; outputs
        that do not start with ``SELECT`` (case-insensitive) are ignored.

        Args:
            response: The LLM result whose first generation is checked.
        """
        # An empty result carries no SQL; bail out instead of IndexError.
        if not response.generations or not response.generations[0]:
            return
        text = response.generations[0][0].text
        # LLM output often begins with newlines/tabs, not just spaces, so
        # strip all leading whitespace before checking for SELECT.
        if not text.lstrip().upper().startswith("SELECT"):
            return
        st.markdown("### Generated Vector Search SQL Statement \n"
                    "> This sql statement is generated by LLM \n\n")
        st.markdown(f"""```sql\n{format_sql(text, max_len=80)}\n```""")
        self._advance_progress("Searching in DB...")

    def on_chain_start(self, serialized, inputs, **kwargs) -> None:
        """Advance the progress bar and show which chain is running."""
        # The segments of serialized["id"] are joined with dots to form the
        # chain's display name.
        cid = ".".join(serialized["id"])
        self._advance_progress(f"Running Chain `{cid}`...")

    def on_chain_end(self, outputs, **kwargs) -> None:
        """Do nothing when a chain finishes."""
        pass


class VectorSQLSearchLLMCallBackHandler(VectorSQLSearchDBCallBackHandler):
    """Variant of ``VectorSQLSearchDBCallBackHandler`` that records a table.

    Reuses the parent's widget setup (progress bar, status placeholder,
    progress counter) but advances the bar in smaller steps of 0.1.
    """

    def __init__(self, table: str) -> None:
        """Create the Streamlit widgets and remember *table*.

        Args:
            table: Target table name, stored as ``self.table``.
        """
        # Delegate widget creation to the parent instead of duplicating it,
        # so both handlers stay in sync if the setup ever changes.
        super().__init__()
        # Finer-grained progress steps than the parent's 0.2.
        self.prog_interval = 0.1
        self.table = table