File size: 4,432 Bytes
6aaddef |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
import getpass
import html
import os
import sqlite3
from typing import Annotated, Union

from typing_extensions import TypedDict
from langchain_community.graphs import Neo4jGraph
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.checkpoint import base
from langgraph.graph import add_messages
# Module-level checkpointer backed by an in-memory SQLite database.
# Exiting a `with SqliteSaver.from_conn_string(...)` block closes the underlying
# connection, so a saver captured from inside such a block cannot be used
# afterwards. Build the saver on a connection that lives as long as the module.
# check_same_thread=False allows the connection to be used from threads other
# than the one that created it (NOTE(review): assumed needed by graph runs).
memory = SqliteSaver(sqlite3.connect(":memory:", check_same_thread=False))
def format_df(df):
"""
Used to display the generated plan in a nice format
Returns html code in a string
"""
def format_cell(cell):
if isinstance(cell, str):
# Encode special characters, but preserve line breaks
return html.escape(cell).replace('\n', '<br>')
return cell
# Convert the DataFrame to HTML with custom CSS
formatted_df = df.map(format_cell)
html_table = formatted_df.to_html(escape=False, index=False)
# Add custom CSS to allow multiple lines and scrolling in cells
css = """
<style>
table {
border-collapse: collapse;
width: 100%;
}
th, td {
border: 1px solid black;
padding: 8px;
text-align: left;
vertical-align: top;
white-space: pre-wrap;
max-width: 300px;
max-height: 100px;
overflow-y: auto;
}
th {
background-color: #f2f2f2;
}
</style>
"""
return css + html_table
def format_doc(doc: dict) -> str:
    """
    Render a document as Markdown-style lines of the form ``**key**: value``.

    One newline-terminated line is produced per key, in the document's
    iteration order; values are formatted with ``str()`` via the f-string.
    An empty document yields an empty string.
    """
    # str.join builds the result in one pass instead of re-copying the
    # accumulated string on every `+=`.
    return "".join(f"**{key}**: {doc[key]}\n" for key in doc)
def _set_env(var: str, value: str = None):
if not os.environ.get(var):
if value:
os.environ[var] = value
else:
os.environ[var] = getpass.getpass(f"{var}: ")
def init_app(openai_key: str = None, groq_key: str = None, langsmith_key: str = None):
    """
    Initialize the app: API keys, LangSmith tracing settings and proxy settings.

    Each API key is exported only if its environment variable is not already
    set; a key that is neither set nor supplied is prompted for interactively.
    The tracing, project and proxy variables are always overwritten.
    """
    # Order matters for the interactive prompts: GROQ, then LANGSMITH, then OPENAI.
    key_sources = (
        ("GROQ_API_KEY", groq_key),
        ("LANGSMITH_API_KEY", langsmith_key),
        ("OPENAI_API_KEY", openai_key),
    )
    for env_name, supplied in key_sources:
        _set_env(env_name, value=supplied)

    # NOTE(review): the proxy address and NO_PROXY host are hard-coded for one
    # specific network environment.
    os.environ.update({
        "LANGSMITH_TRACING_V2": "true",
        "LANGCHAIN_PROJECT": "3GPP Test",
        "http_proxy": "185.46.212.98:80",
        "https_proxy": "185.46.212.98:80",
        "NO_PROXY": "thalescloud.io",
    })
def clear_memory(memory, thread_id: str) -> None:
    """
    Reset the checkpointed state of one thread on the given checkpointer.

    Writes an empty checkpoint for *thread_id* into *memory*; presumably the
    checkpointer then serves it as the thread's latest state, so the next run
    on that thread starts blank. Earlier checkpoints are not deleted.

    Args:
        memory: the checkpointer (e.g. a SqliteSaver) holding the thread's state.
        thread_id: identifier of the thread to reset.
    """
    # Write to the caller's checkpointer. Opening a fresh ":memory:" saver here
    # (as before) shadowed the argument and left the real state untouched.
    config = {"configurable": {"thread_id": thread_id}}
    # NOTE(review): this keeps the original `put(config, checkpoint, metadata)`
    # call shape; newer langgraph checkpointers also require `new_versions`
    # (and a `checkpoint_ns` in the config) — confirm against the installed version.
    memory.put(config=config, checkpoint=base.empty_checkpoint(), metadata={})
def get_model(model: str = "mixtral-8x7b-32768"):
    """
    Build the chat model object for the given model name.

    "gpt-4o" is served through ChatOpenAI against the custom base URL below;
    every other name is handed to ChatGroq unchanged.
    """
    if model != "gpt-4o":
        return ChatGroq(model=model)
    return ChatOpenAI(model=model, base_url="https://llm.synapse.thalescloud.io/")
class ConfigSchema(TypedDict):
    """
    Run-time configuration options (presumably supplied to the graph as its
    configurable values — confirm against the caller).
    """
    graph: Neo4jGraph  # LangChain Neo4jGraph handle
    plan_method: str  # NOTE(review): accepted values are not visible in this file
    use_detailed_query: bool  # feature flag; its effect is defined elsewhere
class State(TypedDict):
    """
    Main graph state schema (presumably for the top-level graph).
    """
    # Updates are combined by langgraph's add_messages reducer instead of replacing the list.
    messages : Annotated[list, add_messages]
    store_plan : list[str]  # plan steps, one string per step
    current_plan_step : int  # presumably an index into store_plan — confirm
    valid_docs : list[str]  # NOTE(review): other schemas type this as list[Union[str, dict]]
class DocRetrieverState(TypedDict):
    """
    State schema for the document-retrieval step.
    """
    # Updates are combined by langgraph's add_messages reducer instead of replacing the list.
    messages: Annotated[list, add_messages]
    query: str
    docs: list[dict]
    cyphers: list[str]  # presumably Cypher query strings for the Neo4j graph — confirm
    current_plan_step : int
    valid_docs: list[Union[str, dict]]
class HumanValidationState(TypedDict):
    """
    State schema for the human-validation step.
    """
    human_validated : bool
    process_steps : list[str]
def update_doc_history(left : list | None, right : list | None) -> list:
"""
Reducer for the 'docs_in_processing' field.
Doesn't work currently because of bad handlinf of duplicates
TODO : make this work (reference : https://langchain-ai.github.io/langgraph/how-tos/subgraph/#custom-reducer-functions-to-manage-state)
"""
if not left:
# This shouldn't happen
left = [[]]
if not right:
right = []
for i in range(len(right)):
left[i].append(right[i])
return left
class DocProcessorState(TypedDict):
    """
    State schema for the document-processing step.
    """
    valid_docs : list[Union[str, dict]]
    # NOTE(review): not annotated with a reducer here, so update_doc_history is
    # not wired to this field in the visible code.
    docs_in_processing : list
    process_steps : list[Union[str,dict]]
    current_process_step : int  # presumably an index into process_steps — confirm
|