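"""Gradio front end for a DORA-focused retrieval-augmented-generation assistant.

The LangGraph workflow rewrites the user question into DORA wording, retrieves
chunks from FAISS vectorstores (the DORA regulation, published RTS documents and
BaFin news), generates a cited answer and grades it for groundedness before it
is shown in the UI.
"""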
import getpass
import os
import random
import re

from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_core.globals import set_llm_cache
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_community.cache import SQLiteCache
from langchain_community.vectorstores import FAISS
from langgraph.graph import END, StateGraph, START

from typing import List
from typing_extensions import TypedDict
import gradio as gr
from pydantic import BaseModel, Field

# For the reranking step
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder

from prompts import IMPROVE_PROMPT, ANSWER_PROMPT, HALLUCINATION_PROMPT, RESOLVER_PROMPT, REWRITER_PROMPT

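# DORA-related topic areas passed to the question-rewriting prompt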
TOPICS = [
"ICT strategy management",
"IT governance management & internal controls system",
"Internal audit & compliance management",
"ICT asset & architecture management",
"ICT risk management",
"Information security & human resource security management",
"IT configuration management",
"Cryptography, certificates & key management",
"Secure network & infrastructure management",
"Backup",
"Security testing",
"Threat-led penetration testing",
"Logging",
"Data and ICT system security",
"Physical and environmental security",
"Vulnerability & patch management",
"Identity and access management",
"ICT change management",
"IT project & project portfolio management",
"Acquisition, development & maintenance of ICT systems & EUA",
"ICT incident management",
"Monitoring, availability, capacity & performance management",
"ICT outsourcing & third-party risk management",
"Subcontracting management",
"ICT provider & service level management",
"ICT business continuity management"    
]

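# Structured-output schemas used by the grading and answer chains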
class GradeHallucinations(BaseModel):
    """Binary score for hallucination present in generation answer."""

    binary_score: str = Field(
        description="Answer is grounded in the facts, 'yes' or 'no'"
    )

class GradeAnswer(BaseModel):
    """Binary score to assess answer addresses question."""

    binary_score: str = Field(
        description="Answer addresses the question, 'yes' or 'no'"
    )

class AnswerWithCitations(BaseModel):
    answer: str = Field(
        description="Comprehensive answer to the user's question with citations.",
    )
    citations: List[str] = Field(
        description="List of the first 20 characters of sources cited in the answer."
    )

class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        question: the (possibly rewritten) user question
        selected_sources: flags for the DORA, RTS and BaFin news sources
        generation: LLM generation
        documents: all retrieved documents
        dora_docs: documents retrieved from the DORA vectorstore
        dora_rts_docs: documents retrieved from the RTS vectorstore
        dora_news_docs: documents retrieved from the BaFin news vectorstore
        citations: citations matched back to their source documents
    """

    question: str
    selected_sources: List[bool]
    generation: str
    documents: List[Document]
    dora_docs: List[Document]
    dora_rts_docs: List[Document]
    dora_news_docs: List[Document]
    citations: dict

def _set_env(var: str):
    if os.environ.get(var):
        return
    os.environ[var] = getpass.getpass(var + ":")

def load_vectorstores(paths: list):
    # Shared embedding model and cross-encoder reranker used for every vectorstore
    embd = OpenAIEmbeddings()
    model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
    compressor = CrossEncoderReranker(model=model, top_n=4)

    vectorstores = [FAISS.load_local(path, embd, allow_dangerous_deserialization=True) for path in paths]
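    # MMR keeps the retrieved chunks diverse: each base retriever fetches 10 candidates and returns up to 7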
    base_retrievers = [vectorstore.as_retriever(search_type="mmr", search_kwargs={
            "k": 7,  
            "fetch_k": 10,      
            "score_threshold": 0.8,
    }) for vectorstore in vectorstores]

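    # Wrap each base retriever so the cross-encoder re-scores its candidates and only the 4 best chunks are kept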
    retrievers = [ContextualCompressionRetriever(
        base_compressor=compressor, base_retriever=retriever
    ) for retriever in base_retrievers]

    return retrievers

def starts_with_ignoring_blanks(full_text, prefix):
    # Normalize all types of blanks to regular spaces
    normalized_full_text = re.sub(r'\s+', ' ', full_text.strip())
    normalized_prefix = re.sub(r'\s+', ' ', prefix.strip())
    
    # Check if the normalized full text starts with the normalized prefix
    return normalized_full_text.startswith(normalized_prefix)

def match_citations_to_documents(citations: List[str], documents: List[Document]):
    """
    Matches each citation to a retrieved document by checking whether the document text starts with the citation prefix
    
    Args:
        citations (List[str]): List of citations to match
        documents (List[Document]): List of documents to search in
    
    Returns:
        dict: Dictionary with the matched documents, where the key is the citation number and the value is the matched document
    """
    matched_documents = {}
    
    for num, citation in enumerate(citations, 1):
        # Find the document whose text starts with this citation prefix
        print(f"checking citation {num}: {citation}")
        for doc in documents:
            print(f"Does '{doc.page_content[:40]}' start with '{citation}'?")
            print(f"{doc.page_content[:40] = }")
            print(f"{citation = }")
            print(f"{doc.page_content[:40].startswith(citation) = }")
            if starts_with_ignoring_blanks(doc.page_content[:40], citation):  # strangely, the 25 characters of the citation often become 35, so compare against a longer slice
                print("yes")
                if doc.metadata.get("section", None):
                    matched_documents[f"<sup>{num}</sup>"] = f"***{doc.metadata['source']} section {doc.metadata['section']}***: {doc.page_content}"
                else: 
                    matched_documents[f"<sup>{num}</sup>"] = f"***{doc.metadata['source']}***: {doc.page_content}"
                break
            else: 
                print("no")

    return matched_documents

# Put all chains in functions (these become the graph nodes)
def dora_rewrite(state):
    """
    Rewrites the question to fit DORA wording

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Updated question key (and a generation key if the question is out of scope)
    """
    print("---TRANSLATE TO DORA---")
    question = state["question"]

    new_question = dora_question_rewriter.invoke({"question": question, "topics": TOPICS})

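    # NOTE: this sentinel must match the fallback phrase produced by the IMPROVE_PROMPT rewriter exactly (same spelling, including the missing apostrophes)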
    if new_question == "Thats an interesting question, but I dont think I can answer it based on my Dora knowledge.":
        return {"question": new_question, "generation": new_question}
    else:
        return {"question": new_question}

def retrieve(state):
    """
    Retrieve documents

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, documents, that contains retrieved documents
    """
    print("---RETRIEVE---")
    question = state["question"]
    selected_sources = state["selected_sources"]

    # Retrieve from each vectorstore only if its source flag is selected (order: dora, rts, news)
    dora_docs = dora_retriever.invoke(question) if selected_sources[0] else []
    dora_rts_docs = dora_rts_retriever.invoke(question) if selected_sources[1] else []
    dora_news_docs = dora_news_retriever.invoke(question) if selected_sources[2] else []
    
    documents = dora_docs + dora_rts_docs + dora_news_docs

    return {"documents": documents, "dora_docs": dora_docs, "dora_rts_docs": dora_rts_docs, "dora_news_docs": dora_news_docs}


def generate(state):
    """
    Generate answer

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, generation, that contains LLM generation
    """
    print("---GENERATE---")
    question = state["question"]
    documents = state["documents"]

    # RAG generation
    answer = answer_chain.invoke({"context": documents, "question": question})

    generation = answer.answer
    print(f"{answer.citations = }")
    citations = match_citations_to_documents(answer.citations, documents)
    print(f"{len(citations)} found, is that correct?")

    return {"generation": generation, "citations": citations}

def transform_query(state):
    """
    Transform the query to produce a better question.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Updates question key with a re-phrased question
    """

    print("---TRANSFORM QUERY---")
    question = state["question"]

    # Re-write question
    better_question = question_rewriter.invoke({"question": question})
    print(f"{better_question =}")
    return {"question": better_question}

### Edges ###
def suitable_question(state):
    """
    Determines whether the question is suitable.

    Args:
        state (dict): The current graph state

    Returns:
        str: Binary decision for next node to call
    """

    print("---ASSESSING THE QUESTION---")
    question = state["question"]
    #print(f"{question = }")
    if question == "Thats an interesting question, but I dont think I can answer it based on my Dora knowledge.":
        return "end"
    else:
        return "retrieve"

def decide_to_generate(state):
    """
    Determines whether to generate an answer, or re-generate a question.

    Args:
        state (dict): The current graph state

    Returns:
        str: Binary decision for next node to call
    """

    print("---ASSESS GRADED DOCUMENTS---")
    documents = state["documents"]

    if not documents:
        # No documents were retrieved, so re-phrase the query and try again
        print(
            "---DECISION: ALL DOCUMENTS ARE IRRELEVANT TO QUESTION, TRANSFORM QUERY---"
        )
        return "transform_query"
    else:
        # We have relevant documents, so generate answer
        print(f"---DECISION: GENERATE WITH {len(documents)} DOCUMENTS---")
        return "generate"

def grade_generation_v_documents_and_question(state):
    """
    Determines whether the generation is grounded in the document and answers question.

    Args:
        state (dict): The current graph state

    Returns:
        str: Decision for next node to call
    """

    print("---CHECK HALLUCINATIONS---")
    question = state["question"]
    documents = state["documents"]
    generation = state["generation"]

    score = hallucination_grader.invoke(
        {"documents": documents, "generation": generation}
    )
    grade = score.binary_score

    # Check hallucination
    if grade == "yes":
        print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
        # Check question-answering
        print("---GRADE GENERATION vs QUESTION---")
        score = answer_grader.invoke({"question": question, "generation": generation})
        grade = score.binary_score
        if grade == "yes":
            print("---DECISION: GENERATION ADDRESSES QUESTION---")
            return "useful"
        else:
            print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
            return "not useful"
    else:
        print("---DECISION: THOSE DOCUMENTS ARE NOT GROUNDING THIS GENERATION---")
        return "not supported"

# Then compile the graph
def compile_graph():
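    """Build the LangGraph workflow: dora_rewrite -> retrieve -> generate, with transform_query as the retry path for empty retrievals or ungrounded answers."""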
    workflow = StateGraph(GraphState)
    # Define the nodes
    workflow.add_node("dora_rewrite", dora_rewrite)  
    workflow.add_node("retrieve", retrieve)  
    workflow.add_node("generate", generate)  
    workflow.add_node("transform_query", transform_query)  
    # Define the edges
    workflow.add_edge(START, "dora_rewrite")
    workflow.add_conditional_edges(
        "dora_rewrite",
        suitable_question,
        {
            "retrieve": "retrieve",
            "end": END,
        },
    )
    workflow.add_conditional_edges(
        "retrieve",
        decide_to_generate,
        {
            "transform_query": "transform_query",
            "generate": "generate",
        },
    )
    workflow.add_edge("transform_query", "retrieve")
    workflow.add_conditional_edges(
        "generate",
        grade_generation_v_documents_and_question,
        {
            "not supported": "transform_query",
            "useful": END,
            "not useful": "transform_query",
        },
    )
    # Compile
    app = workflow.compile()
    return app

# Function to interact with Gradio
def generate_response(question: str, dora: bool, rts: bool, news: bool):
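    """Run the graph for one question; the three flags mirror the source checkboxes (DORA, RTS, BaFin news)."""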
    selected_sources = [dora, rts, news] if any([dora, rts, news]) else [True, False, False]
    state = app.invoke({"question": question, "selected_sources": selected_sources})
    return (
        state["generation"],
        ('\n\n'.join([f"{num} - {doc}" for num, doc in state["citations"].items()])) if "citations" in state and state["citations"] else 'No citations available.',
        # ('\n\n'.join([f"***{doc.metadata['source']} section {doc.metadata['section']}***: {doc.page_content}" for doc in state["dora_docs"]])) if "dora_docs" in state and state["dora_docs"] else 'No documents available.',
        # ('\n\n'.join([f"***{doc.metadata['source']}, section {doc.metadata['section']}***: {doc.page_content}" for doc in state["dora_rts_docs"]])) if "dora_rts_docs" in state and state["dora_rts_docs"] else 'No documents available.',
        # ('\n\n'.join([f"***{doc.metadata['source']}***: {doc.page_content}" for doc in state["dora_news_docs"]])) if "dora_news_docs" in state and state["dora_news_docs"] else 'No documents available.',
    )

def show_loading(prompt: str):
    return [prompt, "loading", "loading"]

def on_click():
    return "I would love to hear your opinion: \[email protected]"

def clear_results():
    return "", "", ""

def random_prompt():
    return random.choice([
        "How does DORA define critical ICT services and who must comply?",
        "What are the key requirements for ICT risk management under DORA?",
        "What are the reporting obligations under DORA for major incidents?",
        "What third-party risk management requirements does DORA impose?",
        "How does DORA's testing framework compare with the UK's CBEST framework?",
        "Do ICT service providers fall under DORA's regulatory requirements?",
        "How should I prepare for DORA's Threat-Led Penetration Testing (TLPT)?",
        "What role do financial supervisors play in DORA compliance?",
        "What penalties are applicable if an organization fails to comply with DORA?",
        "How does DORA align with the NIS2 Directive in Europe?",
        "Do insurance companies also fall under DORA's requirements?",
        "What are the main differences between DORA and GDPR regarding incident reporting?",
        "Are there specific resilience requirements for cloud service providers under DORA?",
        "What are the main deadlines for compliance under DORA?",
        "What steps should I take to ensure my third-party vendors are compliant with DORA?"
    ])

def load_css():
    with open('./style.css', 'r') as file:
        return file.read()

if __name__ == "__main__":
    _set_env("OPENAI_API_KEY")
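    # Cache LLM responses in a local SQLite file so repeated questions do not trigger new API calls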
    set_llm_cache(SQLiteCache(database_path=".cache.db"))

    dora_retriever, dora_rts_retriever, dora_news_retriever = load_vectorstores(
        ["./dora_vectorstore_data_faiss.vst",
        "./rts_eur_lex_vectorstore_faiss.vst",
        "./bafin_news_vectorstore_faiss.vst",]
    )

    fast_llm = ChatOpenAI(model="gpt-3.5-turbo")
    tool_llm = ChatOpenAI(model="gpt-4o")
    rewrite_llm = ChatOpenAI(model="gpt-4o", temperature=1, cache=False)

    dora_question_rewriter = IMPROVE_PROMPT | tool_llm | StrOutputParser()
    answer_chain = ANSWER_PROMPT | tool_llm.with_structured_output(
            AnswerWithCitations, include_raw=False
            ).with_config(run_name="GenerateAnswer")
    hallucination_grader = HALLUCINATION_PROMPT | fast_llm.with_structured_output(GradeHallucinations)
    answer_grader = RESOLVER_PROMPT | fast_llm.with_structured_output(GradeAnswer)
    question_rewriter = REWRITER_PROMPT | rewrite_llm | StrOutputParser()

    app = compile_graph()
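    # The compiled graph can also be invoked without the UI, e.g.:
    #   result = app.invoke({"question": "What are DORA's incident reporting obligations?",
    #                        "selected_sources": [True, False, False]})
    #   print(result["generation"])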

    with gr.Blocks(title='Artificial Compliance', css=load_css(), fill_width=True, fill_height=True,) as demo:
        # theme=gr.themes.Monochrome(), 
        # Adding a sliding navbar
        with gr.Column(scale=1, elem_id='navbar'):
            gr.Image(
                './logo.png', 
                interactive=False, 
                show_label=False, 
                width=200,
                height=200
            )
            with gr.Column():
                dora_chatbot_button = gr.Checkbox(label="Dora", value=True, elem_classes=["navbar-button"])
                document_workbench_button = gr.Checkbox(label="Published RTS documents", value=True, elem_classes=["navbar-button"])
                newsfeed_button = gr.Checkbox(label="Bafin documents", value=True, elem_classes=["navbar-button"])
            question_prompt = gr.Textbox(
                    value=random_prompt(),
                    label='What you always wanted to know about Dora:',
                    elem_classes=['textbox'],
                    lines=6
                )
            with gr.Row():
                clear_results_button = gr.Button('Clear Results', variant='secondary', size="sm")
                submit_button = gr.Button('Submit', variant='primary', size="sm")

        # Adding a header
        gr.Markdown("# The Doracle", elem_id="header")
        gr.Markdown("----------------------------------------------------------------------------")
        display_prompt = gr.Markdown(
            value="", 
            label="question_prompt", 
            elem_id="header"
        )
        gr.Markdown("----------------------------------------------------------------------------")

        with gr.Column(scale=1):
            with gr.Row(elem_id='text_block'):
                llm_generation = gr.Markdown(label="LLM Generation", elem_id="llm_generation")

            gr.Markdown("----------------------------------------------------------------------------")
            
            with gr.Row(elem_id='text_block'):
                citations = gr.Markdown(label="citations", elem_id="llm_generation")

            gr.Markdown("----------------------------------------------------------------------------")

        # Adding a footer with impressum and contact
        with gr.Row(elem_classes="footer"):
            gr.Markdown("Contact", elem_id="clickable_markdown")
            invisible_btn = gr.Button("", elem_id="invisible_button")

        gr.on(
            triggers=[question_prompt.submit, submit_button.click],
            inputs=[question_prompt],
            outputs=[display_prompt, llm_generation, citations],
            fn=show_loading
        ).then(
            outputs=[llm_generation, citations],
            inputs=[question_prompt, dora_chatbot_button, document_workbench_button, newsfeed_button],
            fn=generate_response
        )

        # Use gr.on() with the invisible button's click event
        gr.on(
            triggers=[invisible_btn.click],
            fn=on_click,
            outputs=[llm_generation]
        ) 

        # Clearing out all results when the appropriate button is clicked
        clear_results_button.click(fn=clear_results, outputs=[display_prompt, llm_generation, citations])

    demo.launch()