import getpass
import os
import random
import re
from typing import List

import gradio as gr
from pydantic import BaseModel, Field
from typing_extensions import TypedDict

from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_core.globals import set_llm_cache
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_community.cache import SQLiteCache
from langchain_community.vectorstores import FAISS
from langgraph.graph import END, StateGraph, START

# For the reranking step
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder

from prompts import IMPROVE_PROMPT, ANSWER_PROMPT, HALLUCINATION_PROMPT, RESOLVER_PROMPT, REWRITER_PROMPT

TOPICS = [
    "ICT strategy management",
    "IT governance management & internal controls system",
    "Internal audit & compliance management",
    "ICT asset & architecture management",
    "ICT risk management",
    "Information security & human resource security management",
    "IT configuration management",
    "Cryptography, certificates & key management",
    "Secure network & infrastructure management",
    "Backup",
    "Security testing",
    "Threat-led penetration testing",
    "Logging",
    "Data and ICT system security",
    "Physical and environmental security",
    "Vulnerability & patch management",
    "Identity and access management",
    "ICT change management",
    "IT project & project portfolio management",
    "Acquisition, development & maintenance of ICT systems & EUA",
    "ICT incident management",
    "Monitoring, availability, capacity & performance management",
    "ICT outsourcing & third-party risk management",
    "Subcontracting management",
    "ICT provider & service level management",
    "ICT business continuity management"
]

# Fallback sentinel produced by the IMPROVE_PROMPT chain for out-of-scope
# questions; the wording must match the prompt's output verbatim.
UNSUITABLE_QUESTION = "Thats an interesting question, but I dont think I can answer it based on my Dora knowledge."

class GradeHallucinations(BaseModel):
    """Binary score for hallucinations present in the generated answer."""
    binary_score: str = Field(
        description="Answer is grounded in the facts, 'yes' or 'no'"
    )


class GradeAnswer(BaseModel):
    """Binary score assessing whether the answer addresses the question."""
    binary_score: str = Field(
        description="Answer addresses the question, 'yes' or 'no'"
    )


class AnswerWithCitations(BaseModel):
    answer: str = Field(
        description="Comprehensive answer to the user's question with citations.",
    )
    citations: List[str] = Field(
        description="List of the first 20 characters of sources cited in the answer."
    )

class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        question: the (possibly rewritten) user question
        selected_sources: which sources to query (DORA, RTS, BaFin news)
        generation: LLM generation
        documents: all retrieved documents
        dora_docs: documents retrieved from the DORA vectorstore
        dora_rts_docs: documents retrieved from the RTS vectorstore
        dora_news_docs: documents retrieved from the BaFin news vectorstore
        citations: citation markers mapped to their matched documents
    """
    question: str
    selected_sources: List[bool]
    generation: str
    documents: List[Document]
    dora_docs: List[Document]
    dora_rts_docs: List[Document]
    dora_news_docs: List[Document]
    citations: dict

def _set_env(var: str):
    if os.environ.get(var):
        return
    os.environ[var] = getpass.getpass(var + ":")

def load_vectorstores(paths: list):
    """Load the FAISS vectorstores and wrap each in a cross-encoder reranking retriever."""
    embd = OpenAIEmbeddings()
    model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
    compressor = CrossEncoderReranker(model=model, top_n=4)
    vectorstores = [FAISS.load_local(path, embd, allow_dangerous_deserialization=True) for path in paths]
    # Note: score_threshold is only honoured by the "similarity_score_threshold"
    # search type; with "mmr" it is effectively ignored.
    base_retrievers = [vectorstore.as_retriever(search_type="mmr", search_kwargs={
        "k": 7,
        "fetch_k": 10,
        "score_threshold": 0.8,
    }) for vectorstore in vectorstores]
    retrievers = [ContextualCompressionRetriever(
        base_compressor=compressor, base_retriever=retriever
    ) for retriever in base_retrievers]
    return retrievers

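# A minimal sketch of how vectorstore folders like the ones loaded above could
# be built offline, assuming pre-chunked Documents (the helper name and path
# are hypothetical; this function is never called by the app itself).
# load_local needs allow_dangerous_deserialization=True because FAISS indexes
# carry pickled metadata, so only load files you created yourself.
def _build_vectorstore_sketch(docs: List[Document], path: str = "./example_faiss.vst"):
    vectorstore = FAISS.from_documents(docs, OpenAIEmbeddings())
    vectorstore.save_local(path)
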
def starts_with_ignoring_blanks(full_text, prefix):
    """Check whether full_text starts with prefix, ignoring whitespace differences."""
    # Normalize all types of blanks (tabs, newlines, runs of spaces) to single spaces
    normalized_full_text = re.sub(r'\s+', ' ', full_text.strip())
    normalized_prefix = re.sub(r'\s+', ' ', prefix.strip())
    # Check if the normalized full text starts with the normalized prefix
    return normalized_full_text.startswith(normalized_prefix)

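# For example (hypothetical inputs), tabs, newlines, and runs of spaces all
# collapse to single spaces before the prefix check:
#
#   starts_with_ignoring_blanks("Article  5\n ICT risk", "Article 5")  # True
#   starts_with_ignoring_blanks("Article 5", "Article 6")              # False
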
def match_citations_to_documents(citations: List[str], documents: List[Document]):
    """
    Matches the citations to the documents by searching for the cited prefix in the documents.

    Args:
        citations (List[str]): List of citations to match
        documents (List[Document]): List of documents to search in

    Returns:
        dict: Dictionary of matched documents, where the key is the citation number and the value is the matched document
    """
    matched_documents = {}
    for num, citation in enumerate(citations, 1):
        print(f"checking citation {num}: {citation}")
        for doc in documents:
            print(f"Does this: '{doc.page_content[:30]}' start with this: '{citation}'?")
            print(f"{doc.page_content[:40] = }")
            print(f"{citation = }")
            print(f"{doc.page_content[:40].startswith(citation) = }")
            # Strangely, a citation requested at ~25 characters often comes back
            # closer to 35, so compare against the first 40 characters of the document.
            if starts_with_ignoring_blanks(doc.page_content[:40], citation):
                print("yes")
                if doc.metadata.get("section", None):
                    matched_documents[f"<sup>{num}</sup>"] = f"***{doc.metadata['source']} section {doc.metadata['section']}***: {doc.page_content}"
                else:
                    matched_documents[f"<sup>{num}</sup>"] = f"***{doc.metadata['source']}***: {doc.page_content}"
                break
            else:
                print("no")
    return matched_documents

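# Illustrative call with hypothetical data: a cited prefix should line up with
# the start of exactly one retrieved document, modulo whitespace:
#
#   doc = Document(page_content="Article 5 ICT risk management framework ...",
#                  metadata={"source": "DORA"})
#   match_citations_to_documents(["Article 5 ICT risk"], [doc])
#   # -> {'<sup>1</sup>': '***DORA***: Article 5 ICT risk management framework ...'}
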
# Put all chains in functions
def dora_rewrite(state):
    """
    Rewrites the question to fit DORA wording.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Updates the question key with a re-phrased question
    """
    print("---TRANSLATE TO DORA---")
    question = state["question"]
    new_question = dora_question_rewriter.invoke({"question": question, "topics": TOPICS})
    if new_question == UNSUITABLE_QUESTION:
        return {"question": new_question, "generation": new_question}
    else:
        return {"question": new_question}

def retrieve(state):
    """
    Retrieve documents.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New keys added to state containing the retrieved documents, overall and per source
    """
    print("---RETRIEVE---")
    question = state["question"]
    selected_sources = state["selected_sources"]
    # Retrieval
    dora_docs = dora_retriever.invoke(question) if selected_sources[0] else []
    dora_rts_docs = dora_rts_retriever.invoke(question) if selected_sources[1] else []
    dora_news_docs = dora_news_retriever.invoke(question) if selected_sources[2] else []
    documents = dora_docs + dora_rts_docs + dora_news_docs
    return {"documents": documents, "dora_docs": dora_docs, "dora_rts_docs": dora_rts_docs, "dora_news_docs": dora_news_docs}

def generate(state):
    """
    Generate answer.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New keys added to state, generation and citations
    """
    print("---GENERATE---")
    question = state["question"]
    documents = state["documents"]
    # RAG generation
    answer = answer_chain.invoke({"context": documents, "question": question})
    generation = answer.answer
    print(f"{answer.citations = }")
    citations = match_citations_to_documents(answer.citations, documents)
    print(f"{len(citations)} citations matched, is that correct?")
    return {"generation": generation, "citations": citations}

def transform_query(state):
    """
    Transform the query to produce a better question.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Updates question key with a re-phrased question
    """
    print("---TRANSFORM QUERY---")
    question = state["question"]
    # Re-write question
    better_question = question_rewriter.invoke({"question": question})
    print(f"{better_question = }")
    return {"question": better_question}

### Edges ###

def suitable_question(state):
    """
    Determines whether the question is suitable.

    Args:
        state (dict): The current graph state

    Returns:
        str: Binary decision for next node to call
    """
    print("---ASSESSING THE QUESTION---")
    question = state["question"]
    if question == UNSUITABLE_QUESTION:
        return "end"
    else:
        return "retrieve"

def decide_to_generate(state):
    """
    Determines whether to generate an answer, or re-generate a question.

    Args:
        state (dict): The current graph state

    Returns:
        str: Binary decision for next node to call
    """
    print("---ASSESS GRADED DOCUMENTS---")
    documents = state["documents"]
    if not documents:
        # All documents have been filtered out by the relevance check,
        # so we re-generate a new query
        print(
            "---DECISION: ALL DOCUMENTS ARE IRRELEVANT TO QUESTION, TRANSFORM QUERY---"
        )
        return "transform_query"
    else:
        # We have relevant documents, so generate answer
        print(f"---DECISION: GENERATE WITH {len(documents)} DOCUMENTS---")
        return "generate"

def grade_generation_v_documents_and_question(state):
    """
    Determines whether the generation is grounded in the documents and answers the question.

    Args:
        state (dict): The current graph state

    Returns:
        str: Decision for next node to call
    """
    print("---CHECK HALLUCINATIONS---")
    question = state["question"]
    documents = state["documents"]
    generation = state["generation"]
    score = hallucination_grader.invoke(
        {"documents": documents, "generation": generation}
    )
    grade = score.binary_score
    # Check hallucination
    if grade == "yes":
        print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
        # Check question-answering
        print("---GRADE GENERATION vs QUESTION---")
        score = answer_grader.invoke({"question": question, "generation": generation})
        grade = score.binary_score
        if grade == "yes":
            print("---DECISION: GENERATION ADDRESSES QUESTION---")
            return "useful"
        else:
            print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
            return "not useful"
    else:
        print("---DECISION: THE DOCUMENTS DO NOT GROUND THIS GENERATION---")
        return "not supported"

# Then compile the graph
def compile_graph():
    workflow = StateGraph(GraphState)

    # Define the nodes
    workflow.add_node("dora_rewrite", dora_rewrite)
    workflow.add_node("retrieve", retrieve)
    workflow.add_node("generate", generate)
    workflow.add_node("transform_query", transform_query)

    # Define the edges
    workflow.add_edge(START, "dora_rewrite")
    workflow.add_conditional_edges(
        "dora_rewrite",
        suitable_question,
        {
            "retrieve": "retrieve",
            "end": END,
        },
    )
    workflow.add_conditional_edges(
        "retrieve",
        decide_to_generate,
        {
            "transform_query": "transform_query",
            "generate": "generate",
        },
    )
    workflow.add_edge("transform_query", "retrieve")
    workflow.add_conditional_edges(
        "generate",
        grade_generation_v_documents_and_question,
        {
            "not supported": "transform_query",
            "useful": END,
            "not useful": "transform_query",
        },
    )

    # Compile
    app = workflow.compile()
    return app

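# A minimal sketch of driving the compiled graph directly, outside Gradio
# (assumes the retrievers and chains from the __main__ block below have
# already been initialised; the question text is only an example):
#
#   app = compile_graph()
#   result = app.invoke({
#       "question": "What are DORA's incident reporting deadlines?",
#       "selected_sources": [True, False, False],  # DORA text only
#   })
#   print(result["generation"])
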
# Function to interact with Gradio
def generate_response(question: str, dora: bool, rts: bool, news: bool):
    # Fall back to the DORA source alone if the user deselected everything
    selected_sources = [dora, rts, news] if any([dora, rts, news]) else [True, False, False]
    state = app.invoke({"question": question, "selected_sources": selected_sources})
    return (
        state["generation"],
        ('\n\n'.join([f"{num} - {doc}" for num, doc in state["citations"].items()])) if "citations" in state and state["citations"] else 'No citations available.',
        # ('\n\n'.join([f"***{doc.metadata['source']} section {doc.metadata['section']}***: {doc.page_content}" for doc in state["dora_docs"]])) if "dora_docs" in state and state["dora_docs"] else 'No documents available.',
        # ('\n\n'.join([f"***{doc.metadata['source']}, section {doc.metadata['section']}***: {doc.page_content}" for doc in state["dora_rts_docs"]])) if "dora_rts_docs" in state and state["dora_rts_docs"] else 'No documents available.',
        # ('\n\n'.join([f"***{doc.metadata['source']}***: {doc.page_content}" for doc in state["dora_news_docs"]])) if "dora_news_docs" in state and state["dora_news_docs"] else 'No documents available.',
    )

def show_loading(prompt: str):
    return [prompt, "loading", "loading"]


def on_click():
    return "I would love to hear your opinion: \[email protected]"


def clear_results():
    return "", "", ""

def random_prompt():
    return random.choice([
        "How does DORA define critical ICT services and who must comply?",
        "What are the key requirements for ICT risk management under DORA?",
        "What are the reporting obligations under DORA for major incidents?",
        "What third-party risk management requirements does DORA impose?",
        "How does DORA's testing framework compare with the UK's CBEST framework?",
        "Do ICT service providers fall under DORA's regulatory requirements?",
        "How should I prepare for DORA's Threat-Led Penetration Testing (TLPT)?",
        "What role do financial supervisors play in DORA compliance?",
        "What penalties are applicable if an organization fails to comply with DORA?",
        "How does DORA align with the NIS2 Directive in Europe?",
        "Do insurance companies also fall under DORA's requirements?",
        "What are the main differences between DORA and GDPR regarding incident reporting?",
        "Are there specific resilience requirements for cloud service providers under DORA?",
        "What are the main deadlines for compliance under DORA?",
        "What steps should I take to ensure my third-party vendors are compliant with DORA?"
    ])

def load_css():
    with open('./style.css', 'r') as file:
        return file.read()

if __name__ == "__main__":
    _set_env("OPENAI_API_KEY")
    set_llm_cache(SQLiteCache(database_path=".cache.db"))

    dora_retriever, dora_rts_retriever, dora_news_retriever = load_vectorstores([
        "./dora_vectorstore_data_faiss.vst",
        "./rts_eur_lex_vectorstore_faiss.vst",
        "./bafin_news_vectorstore_faiss.vst",
    ])

    fast_llm = ChatOpenAI(model="gpt-3.5-turbo")
    tool_llm = ChatOpenAI(model="gpt-4o")
    rewrite_llm = ChatOpenAI(model="gpt-4o", temperature=1, cache=False)

    dora_question_rewriter = IMPROVE_PROMPT | tool_llm | StrOutputParser()
    answer_chain = ANSWER_PROMPT | tool_llm.with_structured_output(
        AnswerWithCitations, include_raw=False
    ).with_config(run_name="GenerateAnswer")
    hallucination_grader = HALLUCINATION_PROMPT | fast_llm.with_structured_output(GradeHallucinations)
    answer_grader = RESOLVER_PROMPT | fast_llm.with_structured_output(GradeAnswer)
    question_rewriter = REWRITER_PROMPT | rewrite_llm | StrOutputParser()

    app = compile_graph()
    with gr.Blocks(title='Artificial Compliance', css=load_css(), fill_width=True, fill_height=True) as demo:
        # theme=gr.themes.Monochrome(),
        # Adding a sliding navbar
        with gr.Column(scale=1, elem_id='navbar'):
            gr.Image(
                './logo.png',
                interactive=False,
                show_label=False,
                width=200,
                height=200
            )
            with gr.Column():
                dora_chatbot_button = gr.Checkbox(label="Dora", value=True, elem_classes=["navbar-button"])
                document_workbench_button = gr.Checkbox(label="Published RTS documents", value=True, elem_classes=["navbar-button"])
                newsfeed_button = gr.Checkbox(label="Bafin documents", value=True, elem_classes=["navbar-button"])
            question_prompt = gr.Textbox(
                value=random_prompt(),
                label='What you always wanted to know about Dora:',
                elem_classes=['textbox'],
                lines=6
            )
            with gr.Row():
                clear_results_button = gr.Button('Clear Results', variant='secondary', size="sm")
                submit_button = gr.Button('Submit', variant='primary', size="sm")

        # Adding a header
        gr.Markdown("# The Doracle", elem_id="header")
        gr.Markdown("----------------------------------------------------------------------------")
        display_prompt = gr.Markdown(
            value="",
            label="question_prompt",
            elem_id="header"
        )
        gr.Markdown("----------------------------------------------------------------------------")
        with gr.Column(scale=1):
            with gr.Row(elem_id='text_block'):
                llm_generation = gr.Markdown(label="LLM Generation", elem_id="llm_generation")
            gr.Markdown("----------------------------------------------------------------------------")
            with gr.Row(elem_id='text_block'):
                citations = gr.Markdown(label="citations", elem_id="llm_generation")
            gr.Markdown("----------------------------------------------------------------------------")

        # Adding a footer with impressum and contact
        with gr.Row(elem_classes="footer"):
            gr.Markdown("Contact", elem_id="clickable_markdown")
            invisible_btn = gr.Button("", elem_id="invisible_button")

        gr.on(
            triggers=[question_prompt.submit, submit_button.click],
            inputs=[question_prompt],
            outputs=[display_prompt, llm_generation, citations],
            fn=show_loading
        ).then(
            inputs=[question_prompt, dora_chatbot_button, document_workbench_button, newsfeed_button],
            outputs=[llm_generation, citations],
            fn=generate_response
        )

        # Use gr.on() with the invisible button's click event
        gr.on(
            triggers=[invisible_btn.click],
            fn=on_click,
            outputs=[llm_generation]
        )

        # Clearing out all results when the appropriate button is clicked
        clear_results_button.click(fn=clear_results, outputs=[display_prompt, llm_generation, citations])

    demo.launch()