# ki_gen/utils.py — shared utilities (Hugging Face Space upload by heymenn, commit 6aaddef)
import getpass
import html
import os
import sqlite3
from typing import Annotated, Union

from typing_extensions import TypedDict

from langchain_community.graphs import Neo4jGraph
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
from langgraph.checkpoint import base
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.graph import add_messages
# Module-level checkpointer used as the app's default LangGraph memory.
# Fix: the previous `with SqliteSaver.from_conn_string(":memory:") as mem`
# pattern closed the underlying SQLite connection as soon as the `with`
# block exited, leaving `memory` bound to an unusable (closed) saver.
# Keep the connection open for the process lifetime instead.
# check_same_thread=False because LangGraph may invoke the checkpointer
# from worker threads -- TODO confirm against the app's execution model.
memory = SqliteSaver(conn=sqlite3.connect(":memory:", check_same_thread=False))
def format_df(df):
"""
Used to display the generated plan in a nice format
Returns html code in a string
"""
def format_cell(cell):
if isinstance(cell, str):
# Encode special characters, but preserve line breaks
return html.escape(cell).replace('\n', '<br>')
return cell
# Convert the DataFrame to HTML with custom CSS
formatted_df = df.map(format_cell)
html_table = formatted_df.to_html(escape=False, index=False)
# Add custom CSS to allow multiple lines and scrolling in cells
css = """
<style>
table {
border-collapse: collapse;
width: 100%;
}
th, td {
border: 1px solid black;
padding: 8px;
text-align: left;
vertical-align: top;
white-space: pre-wrap;
max-width: 300px;
max-height: 100px;
overflow-y: auto;
}
th {
background-color: #f2f2f2;
}
</style>
"""
return css + html_table
def format_doc(doc: dict) -> str:
    """
    Render a document dict as one '**key**: value' line per entry.

    Each line is newline-terminated; an empty dict yields an empty string.
    """
    return "".join(f"**{key}**: {value}\n" for key, value in doc.items())
def _set_env(var: str, value: str = None):
if not os.environ.get(var):
if value:
os.environ[var] = value
else:
os.environ[var] = getpass.getpass(f"{var}: ")
def init_app(openai_key : str = None, groq_key : str = None, langsmith_key : str = None):
    """
    Initialize app with user api keys and sets up proxy settings

    Any key left as None triggers an interactive getpass prompt via
    _set_env. Also enables LangSmith tracing and configures the Thales
    HTTP/HTTPS proxy.
    """
    for var, key in (
        ("GROQ_API_KEY", groq_key),
        ("LANGSMITH_API_KEY", langsmith_key),
        ("OPENAI_API_KEY", openai_key),
    ):
        _set_env(var, value=key)

    os.environ.update({
        "LANGSMITH_TRACING_V2": "true",
        "LANGCHAIN_PROJECT": "3GPP Test",
        "http_proxy": "185.46.212.98:80",
        "https_proxy": "185.46.212.98:80",
        "NO_PROXY": "thalescloud.io",
    })
def clear_memory(memory, thread_id: str) -> None:
    """
    Reset the checkpoint state for `thread_id` on the given checkpointer.

    Fix for the original TODO: the previous implementation shadowed the
    `memory` argument with a brand-new in-memory SqliteSaver (already
    closed by its `with` block), so the caller's checkpointer was never
    touched. We now write an empty checkpoint to the checkpointer that
    was actually passed in.

    Args:
        memory: the checkpointer instance whose state should be cleared.
        thread_id: conversation thread whose checkpoint is reset.
    """
    # NOTE(review): newer langgraph releases add a `new_versions` kwarg to
    # put(); keep the original call shape for this pinned version.
    memory.put(
        config={"configurable": {"thread_id": thread_id}},
        checkpoint=base.empty_checkpoint(),
        metadata={},
    )
def get_model(model : str = "mixtral-8x7b-32768"):
    """
    Wrapper to return the correct llm object depending on the 'model' param

    "gpt-4o" is served through the Thales-hosted OpenAI-compatible
    endpoint; every other model name is handed to Groq.
    """
    if model != "gpt-4o":
        return ChatGroq(model=model)
    return ChatOpenAI(model=model, base_url="https://llm.synapse.thalescloud.io/")
class ConfigSchema(TypedDict):
    """Static app configuration passed to the LangGraph graph (not per-run state)."""
    graph: Neo4jGraph        # live Neo4j connection used for retrieval queries
    plan_method: str         # name of the planning strategy -- valid values not visible here, confirm with callers
    use_detailed_query: bool # toggles a more detailed query mode -- semantics defined by the graph nodes, not here
class State(TypedDict):
    """Top-level LangGraph state for the main graph."""
    messages : Annotated[list, add_messages]  # chat history; add_messages reducer appends/merges new messages
    store_plan : list[str]                    # generated plan, one step per entry
    current_plan_step : int                   # index into store_plan of the step being executed
    valid_docs : list[str]                    # documents accepted so far
class DocRetrieverState(TypedDict):
    """State for the document-retrieval subgraph."""
    messages: Annotated[list, add_messages]  # chat history with the add_messages reducer
    query: str                               # retrieval query being processed
    docs: list[dict]                         # raw documents returned by retrieval
    cyphers: list[str]                       # Cypher queries issued against the Neo4j graph -- presumably; confirm with the retriever nodes
    current_plan_step : int                  # plan step this retrieval serves (mirrors State)
    valid_docs: list[Union[str, dict]]       # documents accepted so far (string or dict form)
class HumanValidationState(TypedDict):
    """State slice for the human-in-the-loop validation step."""
    human_validated : bool     # set once a human has approved the retrieved docs
    process_steps : list[str]  # processing steps to run after validation -- exact format defined by the processor nodes
def update_doc_history(left : list | None, right : list | None) -> list:
"""
Reducer for the 'docs_in_processing' field.
Doesn't work currently because of bad handlinf of duplicates
TODO : make this work (reference : https://langchain-ai.github.io/langgraph/how-tos/subgraph/#custom-reducer-functions-to-manage-state)
"""
if not left:
# This shouldn't happen
left = [[]]
if not right:
right = []
for i in range(len(right)):
left[i].append(right[i])
return left
class DocProcessorState(TypedDict):
    """State for the document-processing subgraph."""
    valid_docs : list[Union[str, dict]]       # documents accepted so far (string or dict form)
    docs_in_processing : list                 # per-doc working histories; merged by the update_doc_history reducer -- confirm wiring in the graph definition
    process_steps : list[Union[str,dict]]     # ordered processing steps to apply
    current_process_step : int                # index into process_steps of the active step