|
import streamlit as st |
|
import json |
|
import textwrap |
|
from typing import Dict, Any, List |
|
from sql_formatter.core import format_sql |
|
from langchain.callbacks.streamlit.streamlit_callback_handler import ( |
|
LLMThought, |
|
StreamlitCallbackHandler, |
|
) |
|
from langchain.schema.output import LLMResult |
|
|
|
|
|
class ChatDataSelfSearchCallBackHandler(StreamlitCallbackHandler):
    """Progress-bar callback handler for the self-query search flow.

    Shows a single Streamlit progress bar and, when a chain reports a
    ``repr`` output, displays it as the generated filter in a Python code
    block.
    """

    def __init__(self) -> None:
        # NOTE(review): StreamlitCallbackHandler.__init__ is not called, so only
        # the attributes set here exist; base-class callbacks that are not
        # overridden below would not find the base handler's state.
        self.progress_bar = st.progress(value=0.0, text="Working...")
        self.tokens_stream = ""  # not read anywhere in this class

    def on_llm_start(self, serialized, prompts, **kwargs) -> None:
        """Ignore LLM start events (base-class rendering is suppressed)."""
        pass

    def on_text(self, text: str, **kwargs) -> None:
        """Mark the point where the query is sent to the LLM."""
        self.progress_bar.progress(value=0.2, text="Asking LLM...")

    def on_chain_end(self, outputs, **kwargs) -> None:
        """Advance the bar and show the generated filter, if the chain produced one.

        ``outputs['repr']`` is model-generated text, so it is rendered without
        ``unsafe_allow_html``: a fenced code block needs no raw HTML, and
        allowing it would let stray backticks in the output escape the fence
        and inject markup into the page.
        """
        self.progress_bar.progress(value=0.6, text="Searching in DB...")
        if "repr" in outputs:
            st.markdown("### Generated Filter")
            st.markdown(f"```python\n{outputs['repr']}\n```")

    def on_chain_start(self, serialized, inputs, **kwargs) -> None:
        """Ignore chain start events (base-class rendering is suppressed)."""
        pass
|
|
|
|
|
class ChatDataSelfAskCallBackHandler(StreamlitCallbackHandler):
    """Progress-bar callback handler for the self-query "ask" (QA with sources) flow.

    Chains whose fully-qualified id appears in ``prog_map`` move the bar to a
    fixed milestone when they start; every other chain (including each
    ``LLMChain`` run) nudges it forward by 0.1.
    """

    def __init__(self) -> None:
        # NOTE(review): StreamlitCallbackHandler.__init__ is not called, so only
        # the attributes set here exist on the instance.
        self.progress_bar = st.progress(value=0.0, text="Searching DB...")
        self.status_bar = st.empty()
        self.prog_value = 0.0
        # Fully-qualified chain id -> progress value reached when that chain starts.
        self.prog_map = {
            "langchain.chains.qa_with_sources.retrieval.RetrievalQAWithSourcesChain": 0.2,
            "langchain.chains.combine_documents.map_reduce.MapReduceDocumentsChain": 0.4,
            "langchain.chains.combine_documents.stuff.StuffDocumentsChain": 0.8,
        }

    def on_llm_start(self, serialized, prompts, **kwargs) -> None:
        """Ignore LLM start events (base-class rendering is suppressed)."""
        pass

    def on_text(self, text: str, **kwargs) -> None:
        """Ignore intermediate text events."""
        pass

    def on_chain_start(self, serialized, inputs, **kwargs) -> None:
        """Move the progress bar according to which chain is starting.

        Chains listed in ``prog_map`` set the bar to their milestone. Any other
        chain, including ``LLMChain``, advances it by 0.1. Previously an
        unmapped chain id raised ``KeyError``. The value is clamped to 1.0
        because ``st.progress`` rejects floats above 1.0, and the number of
        ``LLMChain`` runs per query is not fixed.
        """
        cid = ".".join((serialized or {}).get("id", []))
        milestone = self.prog_map.get(cid)
        if milestone is None:
            self.prog_value += 0.1
        else:
            self.prog_value = milestone
        self.prog_value = min(self.prog_value, 1.0)
        self.progress_bar.progress(
            value=self.prog_value, text=f"Running Chain `{cid}`..."
        )

    def on_chain_end(self, outputs, **kwargs) -> None:
        """Ignore chain end events."""
        pass
|
|
|
|
|
class ChatDataSQLSearchCallBackHandler(StreamlitCallbackHandler):
    """Progress-bar callback handler for the vector-SQL search flow.

    Advances a progress bar by ``prog_interval`` on every chain start and LLM
    completion. When an LLM completion looks like a ``SELECT`` statement, it
    is displayed as formatted SQL.
    """

    def __init__(self) -> None:
        # NOTE(review): StreamlitCallbackHandler.__init__ is not called, so only
        # the attributes set here exist on the instance.
        self.progress_bar = st.progress(value=0.0, text="Writing SQL...")
        self.status_bar = st.empty()
        self.prog_value = 0
        # Step added to the bar per chain start / LLM end; subclasses override it.
        self.prog_interval = 0.2

    def _advance_progress(self, text: str) -> None:
        """Advance the bar by ``prog_interval`` and show *text*.

        The value is clamped to 1.0 because ``st.progress`` rejects floats
        above 1.0, and the number of callbacks fired per query is not fixed.
        """
        self.prog_value = min(self.prog_value + self.prog_interval, 1.0)
        self.progress_bar.progress(value=self.prog_value, text=text)

    def on_llm_start(self, serialized, prompts, **kwargs) -> None:
        """Ignore LLM start events (base-class rendering is suppressed)."""
        pass

    def on_llm_end(
        self,
        response: LLMResult,
        *args,
        **kwargs,
    ) -> None:
        """Show the generated SQL (if the completion is a SELECT) and advance the bar.

        An empty ``generations`` list no longer raises ``IndexError``; the bar
        still advances.
        """
        generations = response.generations
        if generations and generations[0]:
            text = generations[0][0].text
            # Completions often start with spaces *or* newlines/tabs; skip all
            # leading whitespace before checking for the SELECT keyword.
            if text.lstrip().upper().startswith("SELECT"):
                st.write("We generated Vector SQL for you:")
                st.markdown(f"""```sql\n{format_sql(text, max_len=80)}\n```""")
                print(f"Vector SQL: {text}")
        self._advance_progress("Searching in DB...")

    def on_chain_start(self, serialized, inputs, **kwargs) -> None:
        """Advance the bar and name the chain that is starting."""
        cid = ".".join((serialized or {}).get("id", []))
        self._advance_progress(f"Running Chain `{cid}`...")

    def on_chain_end(self, outputs, **kwargs) -> None:
        """Ignore chain end events."""
        pass
|
|
|
|
|
class ChatDataSQLAskCallBackHandler(ChatDataSQLSearchCallBackHandler):
    """Vector-SQL "ask" variant of the SQL search handler.

    Reuses the parent's progress bar and callbacks, but advances the bar in
    steps of 0.1 instead of the parent's 0.2.
    """

    def __init__(self) -> None:
        # The parent sets up the same progress bar ("Writing SQL..."), status
        # placeholder and starting value; only the step size differs here.
        super().__init__()
        self.prog_interval = 0.1
|
|
|
|
|
class LLMThoughtWithKB(LLMThought):
    """LLMThought that renders knowledge-base tool output as a numbered list.

    If the tool output is JSON for a sequence of objects with a
    ``page_content`` field, each document is listed, shortened to 80
    characters. Any other output falls back to the default
    ``LLMThought.on_tool_end`` rendering.
    """

    def on_tool_end(
        self,
        output: str,
        color=None,
        observation_prefix=None,
        llm_prefix=None,
        **kwargs: Any,
    ) -> None:
        """Render retrieved documents, or defer to the base class for other output.

        Only parsing/formatting of *output* is guarded. Errors from the
        Streamlit render call itself are no longer silently turned into the
        fallback path.
        """
        try:
            entries = [
                f"**{i+1}**: {textwrap.shorten(r['page_content'], width=80)}"
                for i, r in enumerate(json.loads(output))
            ]
        except (ValueError, TypeError, KeyError, AttributeError):
            # ValueError: not JSON (json.JSONDecodeError subclasses it).
            # TypeError: output not a str, JSON not iterable, or items not mappings.
            # KeyError: items lack 'page_content'.
            # AttributeError: 'page_content' is not a string.
            super().on_tool_end(output, color, observation_prefix, llm_prefix, **kwargs)
            return
        self._container.markdown("\n\n".join(["### Retrieved Documents:"] + entries))
|
|
|
|
|
class ChatDataAgentCallBackHandler(StreamlitCallbackHandler):
    """StreamlitCallbackHandler whose thoughts are LLMThoughtWithKB instances.

    This makes knowledge-base tool results render as a document list inside
    the agent's thought containers.
    """

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Forward the LLM start to the active thought, creating one if needed."""
        thought = self._current_thought
        if thought is None:
            thought = LLMThoughtWithKB(
                parent_container=self._parent_container,
                expanded=self._expand_new_thoughts,
                collapse_on_complete=self._collapse_completed_thoughts,
                labeler=self._thought_labeler,
            )
            self._current_thought = thought
        thought.on_llm_start(serialized, prompts)
|
|