Spaces:
Sleeping
Sleeping
import getpass | |
import os | |
import random | |
from langchain_openai import ChatOpenAI | |
from langchain_core.globals import set_llm_cache | |
from langchain_community.cache import SQLiteCache | |
from langchain_community.vectorstores import FAISS | |
from langchain_openai import OpenAIEmbeddings | |
from langgraph.graph import END, StateGraph, START | |
from langchain_core.output_parsers import StrOutputParser | |
import asyncio | |
from typing import List | |
from typing_extensions import TypedDict | |
import gradio as gr | |
from pydantic import BaseModel, Field | |
from prompts import IMPROVE_PROMPT, RELEVANCE_PROMPT, ANSWER_PROMPT, HALLUCINATION_PROMPT, RESOLVER_PROMPT, REWRITER_PROMPT | |
# Topic catalogue handed to the DORA question rewriter together with the
# user's question (see dora_rewrite).
TOPICS = [
    "ICT strategy management",
    "IT governance management & internal controls system",
    "Internal audit & compliance management",
    "ICT asset & architecture management",
    "ICT risk management",
    "Information security & human resource security management",
    "IT configuration management",
    "Cryptography, certificates & key management",
    "Secure network & infrastructure management",
    "Backup",
    "Security testing",
    "Threat-led penetration testing",
    "Logging",
    "Data and ICT system security",
    "Physical and environmental security",
    "Vulnerability & patch management",
    "Identity and access management",
    "ICT change management",
    "IT project & project portfolio management",
    "Acquisition, development & maintenance of ICT systems & EUA",
    "ICT incident management",
    "Monitoring, availability, capacity & performance management",
    "ICT outsourcing & third-party risk management",
    "Subcontracting management",
    "ICT provider & service level management",
    "ICT business continuity management",
]
class GradeDocuments(BaseModel):
    """Structured grader output: is a retrieved document relevant to the question?"""

    # Compared against the literal "yes" by grade_documents.
    binary_score: str = Field(description="Documents are relevant to the question, 'yes' or 'no'")
class GradeHallucinations(BaseModel):
    """Structured grader output: is the generated answer grounded in the documents?"""

    # "yes" means grounded; any other value is treated as not supported.
    binary_score: str = Field(description="Answer is grounded in the facts, 'yes' or 'no'")
class GradeAnswer(BaseModel):
    """Structured grader output: does the generated answer address the question?"""

    # "yes" means the answer is useful; any other value triggers a query rewrite.
    binary_score: str = Field(description="Answer addresses the question, 'yes' or 'no'")
class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        question: the user's question (replaced by the rewrite nodes)
        selected_sources: one flag per source, ordered DORA, RTS, news
        generation: LLM generation
        documents: documents returned by the latest retrieval
        fitting_documents: documents graded as relevant so far
        dora_docs: fitting documents whose source starts with "Dora"
        dora_rts_docs: fitting documents whose source starts with "Commission"
        dora_news_docs: fitting documents whose source starts with "https"
    """

    question: str
    # A flat list of three booleans: it is built as [dora, rts, news] and read
    # by position in retrieve, so it is not a list of lists.
    selected_sources: List[bool]
    generation: str
    # NOTE(review): the nodes read .page_content and .metadata on the items of
    # the lists below, so they hold document objects rather than plain
    # strings; consider tightening these annotations.
    documents: List[str]
    fitting_documents: List[str]
    dora_docs: List[str]
    dora_rts_docs: List[str]
    dora_news_docs: List[str]
def _set_env(var: str): | |
if os.environ.get(var): | |
return | |
os.environ[var] = getpass.getpass(var + ":") | |
def load_vectorstores(paths: list):
    """
    Load one FAISS index per path and wrap each in an MMR retriever.

    Args:
        paths (list): Locations of locally saved FAISS vector stores

    Returns:
        list: One retriever per entry of ``paths``, in the same order
    """
    embeddings = OpenAIEmbeddings()
    # allow_dangerous_deserialization opts in to loading the pickled part of a
    # saved index - only point this at index files you trust.
    stores = [
        FAISS.load_local(path, embeddings, allow_dangerous_deserialization=True)
        for path in paths
    ]
    retrievers = []
    for store in stores:
        # NOTE(review): score_threshold is normally paired with
        # search_type="similarity_score_threshold"; confirm that it has any
        # effect together with "mmr".
        retrievers.append(
            store.as_retriever(
                search_type="mmr",
                search_kwargs={
                    "k": 7,
                    "fetch_k": 10,
                    "score_threshold": 0.7,
                },
            )
        )
    return retrievers
# Put all chains in functions
async def dora_rewrite(state):
    """
    Rewrites the question to fit dora wording

    The question and the TOPICS list are passed to the dora_question_rewriter
    chain. If the chain answers with the fixed refusal sentence, that sentence
    is stored both as the question and as the generation; suitable_question
    ends the graph on that sentence, so it becomes the final answer.

    Args:
        state (dict): The current graph state
    Returns:
        state (dict): Updated question key, plus a generation key when the
            rewriter refused the question
    """
    refusal = "Thats an interesting question, but I dont think I can answer it based on my Dora knowledge."

    print("---TRANSLATE TO DORA---")
    rewritten = await dora_question_rewriter.ainvoke(
        {"question": state["question"], "topics": TOPICS}
    )
    # Surrounding whitespace in the model output would defeat the exact
    # comparison below (and the identical one in suitable_question), so
    # normalise before comparing.
    rewritten = rewritten.strip()
    if rewritten == refusal:
        return {"question": refusal, "generation": refusal}
    return {"question": rewritten}
async def retrieve(state):
    """
    Retrieve documents from every selected source.

    The selected retrievers are queried concurrently; the results are
    concatenated in the fixed order DORA, RTS, news.

    Args:
        state (dict): The current graph state
    Returns:
        state (dict): New key added to state, documents, that contains retrieved documents
    """
    print("---RETRIEVE---")
    question = state["question"]
    selected_sources = state["selected_sources"]

    # Position i of selected_sources switches retriever i on or off.
    retrievers = (dora_retriever, dora_rts_retriever, dora_news_retriever)
    active = [
        retriever
        for enabled, retriever in zip(selected_sources, retrievers)
        if enabled
    ]

    # The lookups are independent, so run them together; gather returns the
    # results in the order the awaitables were passed in.
    batches = await asyncio.gather(
        *(retriever.ainvoke(question) for retriever in active)
    )
    documents = [doc for batch in batches for doc in batch]
    return {"documents": documents, "question": question}
async def grade_documents(state):
    """
    Determines whether the retrieved documents are relevant to the question.

    Documents graded "yes" are added to the fitting documents carried over
    from earlier passes; a document that is already in that list is skipped.

    Args:
        state (dict): The current graph state
    Returns:
        state (dict): fitting_documents key holding the relevant documents
            collected so far
    """
    print("---CHECK DOCUMENTS RELEVANCE TO QUESTION---")
    question = state["question"]
    # Work on a copy so the list stored in the incoming state is not mutated
    # in place; the key is absent on the first pass.
    fitting_documents = list(state.get("fitting_documents") or [])

    for doc in state["documents"]:
        # Already accepted: skip before spending a grader call on it.
        if doc in fitting_documents:
            continue
        score = await retrieval_grader.ainvoke(
            {"question": question, "document": doc.page_content}
        )
        if score.binary_score == "yes":
            fitting_documents.append(doc)
    return {"fitting_documents": fitting_documents}
async def generate(state):
    """
    Generate an answer from the fitting documents.

    Args:
        state (dict): The current graph state
    Returns:
        state (dict): generation key with the LLM answer, plus the fitting
            documents split by the prefix of their "source" metadata
    """
    print("---GENERATE---")
    question = state["question"]
    context_docs = state["fitting_documents"]

    # State key -> prefix of metadata["source"] that identifies the source.
    source_prefixes = {
        "dora_docs": "Dora",
        "dora_rts_docs": "Commission",
        "dora_news_docs": "https",
    }
    grouped = {
        key: [doc for doc in context_docs if doc.metadata["source"].startswith(prefix)]
        for key, prefix in source_prefixes.items()
    }

    # RAG generation
    generation = await answer_chain.ainvoke({"context": context_docs, "question": question})
    return {"generation": generation, **grouped}
async def transform_query(state):
    """
    Re-phrase the current question with the question_rewriter chain.

    Args:
        state (dict): The current graph state
    Returns:
        state (dict): question key replaced by the re-phrased question
    """
    print("---TRANSFORM QUERY---")
    better_question = await question_rewriter.ainvoke({"question": state["question"]})
    print(f"{better_question =}")
    return {"question": better_question}
### Edges ### | |
async def suitable_question(state):
    """
    Route on whether the rewritten question is the fixed refusal sentence.

    Args:
        state (dict): The current graph state
    Returns:
        str: "end" when the question equals the refusal sentence, otherwise "retrieve"
    """
    print("---ASSESSING THE QUESTION---")
    refusal = "Thats an interesting question, but I dont think I can answer it based on my Dora knowledge."
    return "end" if state["question"] == refusal else "retrieve"
async def decide_to_generate(state):
    """
    Choose between answering and re-writing the question.

    Args:
        state (dict): The current graph state
    Returns:
        str: "generate" when there are fitting documents, otherwise "transform_query"
    """
    print("---ASSESS GRADED DOCUMENTS---")
    relevant = state["fitting_documents"]
    if relevant:
        # We have relevant documents, so generate answer
        print(f"---DECISION: GENERATE WITH {len(relevant)} DOCUMENTS---")
        return "generate"

    # Nothing survived the relevance check, so the question gets re-written.
    print("---DECISION: ALL DOCUMENTS ARE IRRELEVANT TO QUESTION, TRANSFORM QUERY---")
    return "transform_query"
async def grade_generation_v_documents_and_question(state):
    """
    Check the generation for grounding first, then for answering the question.

    Args:
        state (dict): The current graph state
    Returns:
        str: "not supported" when the generation is not grounded in the fitting
            documents, otherwise "useful" or "not useful" depending on whether
            it addresses the question
    """
    print("---CHECK HALLUCINATIONS---")
    question = state["question"]
    fitting_documents = state["fitting_documents"]
    generation = state["generation"]

    grounding = await hallucination_grader.ainvoke(
        {"documents": fitting_documents, "generation": generation}
    )
    if grounding.binary_score != "yes":
        # Dump the evidence and the rejected answer for debugging.
        for document in fitting_documents:
            print(document.page_content)
        print("---DECISION: THOSE DOCUMENTS ARE NOT GROUNDING THIS GENERATION---")
        print(f"{generation = }")
        return "not supported"

    print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
    print("---GRADE GENERATION vs QUESTION---")
    relevance = await answer_grader.ainvoke({"question": question, "generation": generation})
    if relevance.binary_score == "yes":
        print("---DECISION: GENERATION ADDRESSES QUESTION---")
        return "useful"

    print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
    return "not useful"
# Then compile the graph | |
def compile_graph():
    """
    Assemble and compile the question-answering graph.

    Flow: START -> dora_rewrite -> (retrieve | END); retrieve ->
    grade_documents -> (generate | transform_query); transform_query ->
    retrieve; generate -> (END | generate | transform_query).

    Returns:
        The compiled graph
    """
    workflow = StateGraph(GraphState)

    # Node name -> coroutine implementing it.
    nodes = {
        "dora_rewrite": dora_rewrite,
        "retrieve": retrieve,
        "grade_documents": grade_documents,
        "generate": generate,
        "transform_query": transform_query,
    }
    for node_name, node_fn in nodes.items():
        workflow.add_node(node_name, node_fn)

    workflow.add_edge(START, "dora_rewrite")
    workflow.add_conditional_edges(
        "dora_rewrite",
        suitable_question,
        {"retrieve": "retrieve", "end": END},
    )
    workflow.add_edge("retrieve", "grade_documents")
    workflow.add_conditional_edges(
        "grade_documents",
        decide_to_generate,
        {"transform_query": "transform_query", "generate": "generate"},
    )
    workflow.add_edge("transform_query", "retrieve")
    workflow.add_conditional_edges(
        "generate",
        grade_generation_v_documents_and_question,
        {
            "not supported": "generate",  # retry the generation on the same documents
            "useful": END,
            "not useful": "transform_query",
        },
    )

    return workflow.compile()
# Function to interact with Gradio | |
async def generate_response(question: str, dora: bool, rts: bool, news: bool):
    """
    Run the graph for one question and format the result for the UI.

    Args:
        question: The user's question
        dora, rts, news: Which sources to search; when none is selected only
            the DORA source is searched

    Returns:
        tuple: The generated answer followed by the formatted DORA, RTS and
            news documents (a placeholder text where a group is empty)
    """
    requested = [dora, rts, news]
    selected_sources = requested if any(requested) else [True, False, False]
    state = await app.ainvoke({"question": question, "selected_sources": selected_sources})

    def render(key, heading):
        # A missing key and an empty list both produce the placeholder.
        docs = state[key] if key in state else None
        if not docs:
            return 'No documents available.'
        return '\n\n'.join([f"***{heading(doc)}***: {doc.page_content}" for doc in docs])

    return (
        state["generation"],
        render("dora_docs", lambda doc: f"{doc.metadata['source']} section {doc.metadata['section']}"),
        render("dora_rts_docs", lambda doc: f"{doc.metadata['source']}, section {doc.metadata['section']}"),
        render("dora_news_docs", lambda doc: f"{doc.metadata['source']}"),
    )
def show_loading(prompt: str):
    """Return the prompt followed by four "loading" placeholders for the result panes."""
    return [prompt] + ["loading"] * 4
def on_click():
    """
    Return the contact text shown when the hidden contact button is clicked.

    NOTE(review): the address reads "[email protected]", which looks like a
    placeholder left by an e-mail obfuscation filter, and the backslash in
    front of it may originally have been part of a line break - restore the
    real contact text.
    """
    # The backslash is escaped explicitly: "\[" is not a valid escape sequence
    # and makes recent Python versions warn. The returned text is unchanged.
    return "I would love to hear your opinion: \\[email protected]"
def clear_results():
    """Return five empty strings, one for each output the clear button resets."""
    return ("",) * 5
def random_prompt():
    """Pick one of the built-in example questions at random."""
    examples = [
        "Was ist der Unterschied zwischen TIBER-EU und DORA TLPT?",
        "Ich möchte ein SIEM einführen. Bitte gib mir eine Checkliste, was ich beachten muss.",
        "Was ist der Geltungsbereich der DORA? Bin ich als Finanzdienstleister im Leasinggeschäft betroffen?",
        "Ich hatte einen Ransomwarevorfall mit erheblichen Auswirkungen auf den Geschäftsbetrieb. Muss ich etwas melden?",
        "Was ist dieses DORA überhaupt?",
    ]
    return random.choice(examples)
def load_css():
    """
    Read the stylesheet for the UI.

    Returns:
        str: Contents of ``style.css``, resolved against the current working directory

    Raises:
        OSError: If the file cannot be opened
    """
    # Explicit encoding: without it the platform default is used, which is
    # not UTF-8 everywhere.
    with open('style.css', 'r', encoding='utf-8') as file:
        return file.read()
def run_gradio():
    """
    Build the Gradio UI, wire up its events and launch it.

    NOTE(review): the nesting of the layout containers below was reconstructed
    from a copy of this file that had lost its indentation - confirm it
    against the intended layout.
    """
    divider = "----------------------------------------------------------------------------"

    with gr.Blocks(
        title='Artificial Compliance',
        theme=gr.themes.Monochrome(),
        css=load_css(),
        fill_width=True,
        fill_height=True,
    ) as gradio_ui:
        # Navbar: logo, source selection, question box and action buttons
        with gr.Column(scale=1, elem_id='navbar'):
            gr.Image(
                './logo.png',
                interactive=False,
                show_label=False,
                scale=1,
                width="50%",
                height="50%",
            )
            with gr.Column():
                dora_checkbox = gr.Checkbox(label="Dora", value=True, elem_classes=["navbar-button"])
                rts_checkbox = gr.Checkbox(label="Published RTS documents", value=True, elem_classes=["navbar-button"])
                news_checkbox = gr.Checkbox(label="Bafin documents", value=True, elem_classes=["navbar-button"])
            question_prompt = gr.Textbox(
                value=random_prompt(),
                label='What you always wanted to know about Dora:',
                elem_classes=['textbox'],
                lines=6,
            )
            with gr.Row():
                clear_results_button = gr.Button('Clear Results', variant='secondary', size="m")
                submit_button = gr.Button('Submit', variant='primary', size="m")

        # Header, with the submitted question echoed between two dividers
        gr.Markdown("# The Doracle", elem_id="header")
        gr.Markdown(divider)
        display_prompt = gr.Markdown(
            value="",
            label="question_prompt",
            elem_id="header",
        )
        gr.Markdown(divider)

        # Result area: the answer on top, the supporting documents below
        with gr.Column(scale=1):
            with gr.Row(elem_id='text_block'):
                llm_generation = gr.Markdown(label="LLM Generation", elem_id="llm_generation")
            gr.Markdown(divider)
            with gr.Row(elem_id='text_block'):
                dora_documents = gr.Markdown(label="DORA Documents")
                dora_rts_documents = gr.Markdown(label="DORA RTS Documents")
                dora_news_documents = gr.Markdown(label="Bafin supporting Documents")

        # Footer with impressum and contact
        with gr.Row(elem_classes="footer"):
            gr.Markdown("Contact", elem_id="clickable_markdown")
            invisible_btn = gr.Button("", elem_id="invisible_button")

        # Submitting first shows placeholders, then runs the graph.
        gr.on(
            triggers=[question_prompt.submit, submit_button.click],
            fn=show_loading,
            inputs=[question_prompt],
            outputs=[display_prompt, llm_generation, dora_documents, dora_rts_documents, dora_news_documents],
        ).then(
            fn=generate_response,
            inputs=[question_prompt, dora_checkbox, rts_checkbox, news_checkbox],
            outputs=[llm_generation, dora_documents, dora_rts_documents, dora_news_documents],
        )

        # The invisible button's click event puts the contact text into the answer pane
        gr.on(
            triggers=[invisible_btn.click],
            fn=on_click,
            outputs=[llm_generation],
        )

        # Clearing out all results when the appropriate button is clicked
        clear_results_button.click(
            fn=clear_results,
            outputs=[display_prompt, llm_generation, dora_documents, dora_rts_documents, dora_news_documents],
        )

    gradio_ui.launch()
if __name__ == "__main__":
    # Credentials and the on-disk LLM response cache
    _set_env("OPENAI_API_KEY")
    set_llm_cache(SQLiteCache(database_path=".cache.db"))

    # One retriever per source, in the order DORA, RTS, news
    vectorstore_paths = [
        "./dora_vectorstore_data_faiss.vst",
        "./rts_eur_lex_vectorstore_faiss.vst",
        "./bafin_news_vectorstore_faiss.vst",
    ]
    dora_retriever, dora_rts_retriever, dora_news_retriever = load_vectorstores(vectorstore_paths)

    # Models
    fast_llm = ChatOpenAI(model="gpt-3.5-turbo")
    smart_llm = ChatOpenAI(model="gpt-4-turbo", temperature=0.2, max_tokens=4096)
    tool_llm = ChatOpenAI(model="gpt-4o")
    # Cache disabled and temperature 1 - presumably so that a repeated rewrite
    # of the same question can produce a different wording.
    rewrite_llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=1, cache=False)

    # Chains used by the graph nodes and edges
    dora_question_rewriter = IMPROVE_PROMPT | tool_llm | StrOutputParser()
    retrieval_grader = RELEVANCE_PROMPT | fast_llm.with_structured_output(GradeDocuments)
    answer_chain = ANSWER_PROMPT | tool_llm | StrOutputParser()  # former RAG chain
    hallucination_grader = HALLUCINATION_PROMPT | fast_llm.with_structured_output(GradeHallucinations)
    answer_grader = RESOLVER_PROMPT | fast_llm.with_structured_output(GradeAnswer)
    question_rewriter = REWRITER_PROMPT | rewrite_llm | StrOutputParser()

    app = compile_graph()

    # And finally, run the app
    run_gradio()