File size: 4,432 Bytes
6aaddef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import getpass
import html
import os
import sqlite3

from typing import Annotated, Union

from typing_extensions import TypedDict

from langchain_community.graphs import Neo4jGraph
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
from langgraph.checkpoint import base
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.graph import add_messages

# NOTE(review): the previous version did
#     with SqliteSaver.from_conn_string(":memory:") as mem:
#         memory = mem
# which tears the saver down as soon as the `with` block exits, leaving
# `memory` bound to a checkpointer whose connection is already closed.
# Build the saver from a long-lived in-memory SQLite connection instead.
# `check_same_thread=False` because langgraph may touch the checkpointer
# from worker threads.
memory = SqliteSaver(conn=sqlite3.connect(":memory:", check_same_thread=False))


def format_df(df):
    """
    Used to display the generated plan in a nice format
    Returns html code in a string
    """
    def format_cell(cell):
        if isinstance(cell, str):
            # Encode special characters, but preserve line breaks
            return html.escape(cell).replace('\n', '<br>')
        return cell
    # Convert the DataFrame to HTML with custom CSS
    formatted_df = df.map(format_cell)
    html_table = formatted_df.to_html(escape=False, index=False)
    
    # Add custom CSS to allow multiple lines and scrolling in cells
    css = """
    <style>
        table {
            border-collapse: collapse;
            width: 100%;
        }
        th, td {
            border: 1px solid black;
            padding: 8px;
            text-align: left;
            vertical-align: top;
            white-space: pre-wrap;
            max-width: 300px;
            max-height: 100px;
            overflow-y: auto;
        }
        th {
            background-color: #f2f2f2;
        }
    </style>
    """
    
    return css + html_table

def format_doc(doc: dict) -> str:
    """
    Render a document dict as one "**key**: value" line per entry.
    """
    return "".join(f"**{key}**: {value}\n" for key, value in doc.items())



def _set_env(var: str, value: str = None):
    if not os.environ.get(var):
        if value:
            os.environ[var] = value
        else:
            os.environ[var] = getpass.getpass(f"{var}: ")


def init_app(openai_key : str = None, groq_key : str = None, langsmith_key : str = None):
    """
    Initialize the app: register the user's API keys in the environment
    (prompting for any that are missing) and configure LangSmith tracing
    plus the proxy settings.
    """
    key_table = (
        ("GROQ_API_KEY", groq_key),
        ("LANGSMITH_API_KEY", langsmith_key),
        ("OPENAI_API_KEY", openai_key),
    )
    for env_var, provided in key_table:
        _set_env(env_var, value=provided)

    os.environ["LANGSMITH_TRACING_V2"] = "true"
    os.environ["LANGCHAIN_PROJECT"] = "3GPP Test"
    os.environ["http_proxy"] = "185.46.212.98:80"
    os.environ["https_proxy"] = "185.46.212.98:80"
    os.environ["NO_PROXY"] = "thalescloud.io"

def clear_memory(memory, thread_id: str) -> None:
    """
    Reset the checkpointer state for `thread_id` to an empty checkpoint.

    The previous implementation (marked "broken for now" in its own
    docstring) shadowed the `memory` parameter with a brand-new in-memory
    saver whose connection was already closed when the `with` block exited,
    so the caller's checkpointer was never touched. Write the empty
    checkpoint into the checkpointer that was actually passed in.

    :param memory: the checkpointer (e.g. SqliteSaver) whose state to clear
    :param thread_id: thread whose checkpoint is reset
    """
    checkpoint = base.empty_checkpoint()
    # NOTE(review): newer langgraph releases add a required `new_versions`
    # argument to `put` — confirm against the installed version.
    memory.put(
        config={"configurable": {"thread_id": thread_id}},
        checkpoint=checkpoint,
        metadata={},
    )

def get_model(model : str = "mixtral-8x7b-32768"):
    """
    Return the chat-model object for the requested model name.

    "gpt-4o" is routed through the ChatOpenAI proxy endpoint; any other
    model name is handed to ChatGroq unchanged.
    """
    if model == "gpt-4o":
        return ChatOpenAI(model=model, base_url="https://llm.synapse.thalescloud.io/")
    return ChatGroq(model=model)


class ConfigSchema(TypedDict):
    # Per-invocation configuration passed to the graph.
    graph: Neo4jGraph  # Neo4j connection handle used by the app's nodes
    plan_method: str  # plan-generation strategy name — consumers not visible here; TODO confirm values
    use_detailed_query: bool  # toggles a more detailed query mode — verify against caller

class State(TypedDict):
    # Top-level graph state.
    messages : Annotated[list, add_messages]  # chat history; add_messages reducer appends new messages
    store_plan : list[str]  # plan steps as text
    current_plan_step : int  # index of the step currently being executed
    valid_docs : list[str]  # documents accepted so far

class DocRetrieverState(TypedDict):
    # State of the document-retrieval subgraph.
    messages: Annotated[list, add_messages]  # chat history; add_messages reducer appends
    query: str  # retrieval query for the current step
    docs: list[dict]  # candidate documents returned by retrieval
    cyphers: list[str]  # Cypher queries issued — presumably against the Neo4j graph; TODO confirm
    current_plan_step : int  # index of the plan step being served
    valid_docs: list[Union[str, dict]]  # documents accepted so far

class HumanValidationState(TypedDict):
    # State for the human-in-the-loop validation step.
    human_validated : bool  # set once a human has approved the current output
    process_steps : list[str]  # processing steps chosen/confirmed by the human

def update_doc_history(left : list | None, right : list | None) -> list:
    """
    Reducer for the 'docs_in_processing' field.
    Doesn't work currently because of bad handlinf of duplicates
    TODO : make this work (reference : https://langchain-ai.github.io/langgraph/how-tos/subgraph/#custom-reducer-functions-to-manage-state)
    """
    if not left:
        # This shouldn't happen
        left = [[]]
    if not right:
        right = []

    for i in range(len(right)):
        left[i].append(right[i])
    return left


class DocProcessorState(TypedDict):
    # State of the document-processing subgraph.
    valid_docs : list[Union[str, dict]]  # documents accepted so far
    docs_in_processing : list  # per-document work-in-progress; reduced by update_doc_history
    process_steps : list[Union[str,dict]]  # processing steps to apply
    current_process_step : int  # index of the step currently running