|
import getpass
import html
import os
import sqlite3
from typing import Annotated, Union

from langchain_community.graphs import Neo4jGraph
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
from langgraph.checkpoint import base
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.graph import add_messages
from typing_extensions import TypedDict
|
|
|
# Module-level in-memory SQLite checkpointer.
# The saver is built on a connection opened directly rather than inside
# ``with SqliteSaver.from_conn_string(":memory:")``: leaving that ``with``
# block closes the underlying SQLite connection, so a saver bound inside it
# can no longer be used afterwards.
# check_same_thread=False allows the connection to be used from threads other
# than the one that created it.
memory = SqliteSaver(sqlite3.connect(":memory:", check_same_thread=False))
|
|
|
|
|
def format_df(df): |
|
""" |
|
Used to display the generated plan in a nice format |
|
Returns html code in a string |
|
""" |
|
def format_cell(cell): |
|
if isinstance(cell, str): |
|
|
|
return html.escape(cell).replace('\n', '<br>') |
|
return cell |
|
|
|
formatted_df = df.map(format_cell) |
|
html_table = formatted_df.to_html(escape=False, index=False) |
|
|
|
|
|
css = """ |
|
<style> |
|
table { |
|
border-collapse: collapse; |
|
width: 100%; |
|
} |
|
th, td { |
|
border: 1px solid black; |
|
padding: 8px; |
|
text-align: left; |
|
vertical-align: top; |
|
white-space: pre-wrap; |
|
max-width: 300px; |
|
max-height: 100px; |
|
overflow-y: auto; |
|
} |
|
th { |
|
background-color: #f2f2f2; |
|
} |
|
</style> |
|
""" |
|
|
|
return css + html_table |
|
|
|
def format_doc(doc: dict) -> str :
    """
    Turn a document dict into markdown-style text.

    Each entry becomes one ``**key**: value`` line terminated by a newline,
    in the dict's iteration order. An empty dict yields an empty string.
    """
    return "".join(f"**{field}**: {content}\n" for field, content in doc.items())
|
|
|
|
|
|
|
def _set_env(var: str, value: str = None):
    """
    Make sure the environment variable ``var`` holds a non-empty value.

    Nothing happens when the variable is already set to a non-empty string.
    Otherwise it is set to ``value`` when one is given (and truthy), else the
    user is prompted for it on the terminal without echo.
    """
    if os.environ.get(var):
        # Already configured: never overwrite an existing value.
        return
    os.environ[var] = value if value else getpass.getpass(f"{var}: ")
|
|
|
|
|
def init_app(openai_key : str = None, groq_key : str = None, langsmith_key : str = None):
    """
    Initialize app with user api keys and sets up proxy settings

    Each API key is stored in its environment variable through ``_set_env``
    (which keeps an existing value and prompts when no key is supplied), then
    the tracing, project and proxy variables are overwritten unconditionally.
    """
    # Same order as before so that any interactive prompts appear in the same sequence.
    api_keys = (
        ("GROQ_API_KEY", groq_key),
        ("LANGSMITH_API_KEY", langsmith_key),
        ("OPENAI_API_KEY", openai_key),
    )
    for env_name, supplied_key in api_keys:
        _set_env(env_name, value=supplied_key)

    os.environ.update(
        {
            "LANGSMITH_TRACING_V2": "true",
            "LANGCHAIN_PROJECT": "3GPP Test",
            "http_proxy": "185.46.212.98:80",
            "https_proxy": "185.46.212.98:80",
            "NO_PROXY": "thalescloud.io",
        }
    )
|
|
|
def clear_memory(memory, thread_id: str) -> None:
    """
    Reset the checkpointer state stored for a given thread_id.

    An empty checkpoint is written for the thread into the checkpointer that
    the caller passed in. The previous version replaced ``memory`` with a
    brand-new in-memory saver created in a ``with`` block and then used it
    after the block had closed its connection, so the caller's checkpointer
    was never touched.

    Args:
        memory: checkpointer to reset; it must provide
            ``put(config=..., checkpoint=..., metadata=...)``.
        thread_id: identifier of the conversation thread to reset.

    NOTE(review): the original was flagged as broken / TODO. Whether writing
    an empty checkpoint is enough to fully clear a thread depends on the
    installed langgraph checkpointer version - confirm.
    """
    config = {"configurable": {"thread_id": thread_id}}
    memory.put(config=config, checkpoint=base.empty_checkpoint(), metadata={})
|
|
|
def get_model(model : str = "mixtral-8x7b-32768"):
    """
    Build the chat model object matching the 'model' param.

    "gpt-4o" yields a ChatOpenAI client pointed at the custom base URL; any
    other name is handed to ChatGroq.
    """
    if model != "gpt-4o":
        return ChatGroq(model=model)
    return ChatOpenAI(model=model, base_url="https://llm.synapse.thalescloud.io/")
|
|
|
|
|
class ConfigSchema(TypedDict):
    """Typed layout of a configuration dict (presumably the langgraph config schema - confirm)."""
    # Neo4jGraph instance (langchain_community) to work against.
    graph: Neo4jGraph
    # Presumably names the planning strategy to use - confirm against the graph nodes.
    plan_method: str
    # Presumably toggles a more detailed retrieval query - confirm against callers.
    use_detailed_query: bool
|
|
|
class State(TypedDict):
    """Typed layout of a graph state (presumably the top-level graph's state - confirm)."""
    # Message list; updates are merged through langgraph's add_messages reducer.
    messages : Annotated[list, add_messages]
    # Presumably the steps of the generated plan - confirm.
    store_plan : list[str]
    # Presumably the index of the plan step being handled - confirm.
    current_plan_step : int
    # Presumably the documents kept after validation - confirm.
    valid_docs : list[str]
|
|
|
class DocRetrieverState(TypedDict):
    """Typed layout of a graph state (presumably for the document-retrieval subgraph - confirm)."""
    # Message list; updates are merged through langgraph's add_messages reducer.
    messages: Annotated[list, add_messages]
    # Presumably the retrieval query text - confirm.
    query: str
    # Documents as dicts; their key schema is not visible here.
    docs: list[dict]
    # Presumably the Cypher queries that were generated or run - confirm.
    cyphers: list[str]
    # Presumably the index of the plan step being handled - confirm.
    current_plan_step : int
    # Documents held either as strings or as dicts.
    valid_docs: list[Union[str, dict]]
|
|
|
class HumanValidationState(TypedDict):
    """Typed layout of a graph state (presumably for a human-validation step - confirm)."""
    # Presumably True once a human has approved - confirm against the node that sets it.
    human_validated : bool
    # Processing steps as strings.
    process_steps : list[str]
|
|
|
def update_doc_history(left : list | None, right : list | None) -> list: |
|
""" |
|
Reducer for the 'docs_in_processing' field. |
|
Doesn't work currently because of bad handlinf of duplicates |
|
TODO : make this work (reference : https://langchain-ai.github.io/langgraph/how-tos/subgraph/#custom-reducer-functions-to-manage-state) |
|
""" |
|
if not left: |
|
|
|
left = [[]] |
|
if not right: |
|
right = [] |
|
|
|
for i in range(len(right)): |
|
left[i].append(right[i]) |
|
return left |
|
|
|
|
|
class DocProcessorState(TypedDict):
    """Typed layout of a graph state (presumably for the document-processing subgraph - confirm)."""
    # Documents held either as strings or as dicts.
    valid_docs : list[Union[str, dict]]
    # Plain list: no reducer is attached to this field in this declaration.
    docs_in_processing : list
    # Processing steps, each a string or a dict.
    process_steps : list[Union[str,dict]]
    # Presumably the index of the processing step being executed - confirm.
    current_process_step : int
|
|