Commit 1c3ef38 · Parent: cfa680c

Reworked citation system

Files changed:
- app.py: +105 −52
- prompts.py: +23 −8
app.py CHANGED
@@ -1,9 +1,11 @@
 import getpass
 import os
 import random
+import re
 
 from langchain_openai import ChatOpenAI
 from langchain_core.globals import set_llm_cache
+from langchain_core.documents import Document
 from langchain_community.cache import SQLiteCache
 from langchain_community.vectorstores import FAISS
 from langchain_openai import OpenAIEmbeddings
@@ -65,6 +67,14 @@ class GradeAnswer(BaseModel):
         description="Answer addresses the question, 'yes' or 'no'"
     )
 
+class AnswerWithCitations(BaseModel):
+    answer: str = Field(
+        description="Comprehensive answer to the user's question with citations.",
+    )
+    citations: List[str] = Field(
+        description="List of the first 20 characters of sources cited in the answer."
+    )
+
 class GraphState(TypedDict):
     """
     Represents the state of our graph.
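Review note: `AnswerWithCitations` is the pivot of the rework. Binding it to the model via `with_structured_output` (done further down, in the `__main__` block) is what lets `generate` read `answer.answer` and `answer.citations` as typed fields instead of parsing footnotes out of free text. A minimal sketch of the pattern in isolation, not code from this commit; the prompt text is a placeholder and it assumes `OPENAI_API_KEY` is set:

```python
# Sketch only: a Pydantic schema bound to the chat model makes the reply
# arrive as typed fields instead of free text.
from typing import List

from pydantic import BaseModel, Field
from langchain_openai import ChatOpenAI

class AnswerWithCitations(BaseModel):
    answer: str = Field(description="Comprehensive answer with citations.")
    citations: List[str] = Field(description="First 20 characters of each cited source.")

llm = ChatOpenAI(model="gpt-4o").with_structured_output(AnswerWithCitations)
reply = llm.invoke("What does DORA regulate? Cite your sources.")
print(reply.answer)     # str
print(reply.citations)  # list[str], e.g. ["Article 8 Identificat"]
```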
@@ -82,6 +92,7 @@ class GraphState(TypedDict):
     dora_docs: List[str]
     dora_rts_docs: List[str]
     dora_news_docs: List[str]
+    citations: List[str]
 
 def _set_env(var: str):
     if os.environ.get(var):
@@ -92,13 +103,13 @@ def load_vectorstores(paths: list):
     # The dora vectorstore
     embd = OpenAIEmbeddings()
     model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
-    compressor = CrossEncoderReranker(model=model, top_n=
+    compressor = CrossEncoderReranker(model=model, top_n=4)
 
     vectorstores = [FAISS.load_local(path, embd, allow_dangerous_deserialization=True) for path in paths]
     base_retrievers = [vectorstore.as_retriever(search_type="mmr", search_kwargs={
-        "k":
-        "fetch_k":
-        "score_threshold": 0.
+        "k": 7,
+        "fetch_k": 10,
+        "score_threshold": 0.8,
     }) for vectorstore in vectorstores]
 
     retrievers = [ContextualCompressionRetriever(
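For orientation, the hunk above pins down a two-stage retrieval: MMR pulls `fetch_k=10` candidates and keeps `k=7`, then the BGE cross-encoder reranks those and keeps `top_n=4`. The same pipeline for a single store, as a standalone sketch; the path matches one used under `__main__` below, the query text is made up:

```python
# Standalone sketch of the retrieval stack configured above.
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings

vectorstore = FAISS.load_local("./dora_vectorstore_data_faiss.vst", OpenAIEmbeddings(),
                               allow_dangerous_deserialization=True)
base = vectorstore.as_retriever(
    search_type="mmr",
    search_kwargs={"k": 7, "fetch_k": 10, "score_threshold": 0.8},
)
reranker = CrossEncoderReranker(
    model=HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base"), top_n=4
)
retriever = ContextualCompressionRetriever(base_compressor=reranker, base_retriever=base)

docs = retriever.invoke("What does DORA require for ICT incident reporting?")  # at most 4 documents
```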
@@ -106,7 +117,48 @@ def load_vectorstores(paths: list):
     ) for retriever in base_retrievers]
 
     return retrievers
-
+
+def starts_with_ignoring_blanks(full_text, prefix):
+    # Normalize all types of blanks to regular spaces
+    normalized_full_text = re.sub(r'\s+', ' ', full_text.strip())
+    normalized_prefix = re.sub(r'\s+', ' ', prefix.strip())
+
+    # Check if the normalized full text starts with the normalized prefix
+    return normalized_full_text.startswith(normalized_prefix)
+
+def match_citations_to_documents(citations: List[str], documents: List[Document]):
+    """
+    Matches the citations to the documents by searching for the source and section in the documents.
+
+    Args:
+        citations (List[str]): List of citations to match
+        documents (List[Document]): List of documents to search in
+
+    Returns:
+        dict: Dictionary with the matched documents, where the key is the citation number and the value is the matched document
+    """
+    matched_documents = {}
+
+    for num, citation in enumerate(citations, 1):
+        # Find the document whose content starts with this citation
+        print(f"checking citation {num}: {citation}")
+        for doc in documents:
+            print(f"Does '{doc.page_content[:30]}' start with '{citation}'?")
+            print(f"{doc.page_content[:40] = }")
+            print(f"{citation = }")
+            print(f"{doc.page_content[:40].startswith(citation) = }")
+            if starts_with_ignoring_blanks(doc.page_content[:40], citation):  # Strangely, the ~25 characters of a citation often become ~35, hence the 40-character window
+                print("yes")
+                if doc.metadata.get("section", None):
+                    matched_documents[f"<sup>{num}</sup>"] = f"***{doc.metadata['source']} section {doc.metadata['section']}***: {doc.page_content}"
+                else:
+                    matched_documents[f"<sup>{num}</sup>"] = f"***{doc.metadata['source']}***: {doc.page_content}"
+                break
+            else:
+                print("no")
+
+    return matched_documents
+
 # Put all chains in functions
 def dora_rewrite(state):
     """
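The two helpers exist because the model cites by echoing a source prefix, and that echo tends to come back with mangled whitespace (see the `\xa0` examples in prompts.py below), so matching is whitespace-insensitive and runs against a 40-character window rather than the promised 20. A toy invocation, with invented documents and citation:

```python
# Toy usage of match_citations_to_documents; all data here is made up.
from langchain_core.documents import Document

docs = [
    Document(page_content="Article 8 Identification 1. As part of the ICT risk management framework...",
             metadata={"source": "DORA", "section": "8"}),
    Document(page_content="Preliminary assessment of cyber threats...",
             metadata={"source": "DORA RTS"}),
]
matched = match_citations_to_documents(["Article 8 Identif"], docs)
# {"<sup>1</sup>": "***DORA section 8***: Article 8 Identification 1. ..."}
```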
@@ -168,8 +220,14 @@ def generate(state):
     documents = state["documents"]
 
     # RAG generation
-
-
+    answer = answer_chain.invoke({"context": documents, "question": question})
+
+    generation = answer.answer
+    print(f"{answer.citations = }")
+    citations = match_citations_to_documents(answer.citations, documents)
+    print(f"{len(citations)} found, is that correct?")
+
+    return {"generation": generation, "citations": citations}
 
 def transform_query(state):
     """
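In LangGraph, a node returns a partial state update, so returning `{"generation": ..., "citations": ...}` merges both keys into `GraphState`; that is how `generate_response` can later read `state["citations"]`. A self-contained illustration of the mechanic, not this app's graph:

```python
# Minimal LangGraph node: the returned dict is merged into the typed state.
from typing import TypedDict

from langgraph.graph import StateGraph, START, END

class State(TypedDict):
    question: str
    generation: str
    citations: dict

def generate(state: State):
    return {"generation": f"answer to: {state['question']}",
            "citations": {"<sup>1</sup>": "source A"}}

graph = StateGraph(State)
graph.add_node("generate", generate)
graph.add_edge(START, "generate")
graph.add_edge("generate", END)
print(graph.compile().invoke({"question": "what is DORA?"}))
```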
@@ -271,10 +329,7 @@ def grade_generation_v_documents_and_question(state):
         print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
         return "not useful"
     else:
-        for document in documents:
-            print(document.page_content)
         print("---DECISION: THOSE DOCUMENTS ARE NOT GROUNDING THIS GENERATION---")
-        print(f"{generation = }")
         return "not supported"
 
 # Then compile the graph
@@ -308,7 +363,7 @@ def compile_graph():
         "generate",
         grade_generation_v_documents_and_question,
         {
-            "not supported": "
+            "not supported": "transform_query",
             "useful": END,
             "not useful": "transform_query",
         },
@@ -323,19 +378,20 @@ def generate_response(question: str, dora: bool, rts: bool, news: bool):
     state = app.invoke({"question": question, "selected_sources": selected_sources})
     return (
         state["generation"],
-        ('\n\n'.join([f"
-        ('\n\n'.join([f"***{doc.metadata['source']}
-        ('\n\n'.join([f"***{doc.metadata['source']}***: {doc.page_content}" for doc in state["
+        ('\n\n'.join([f"{num} - {doc}" for num, doc in state["citations"].items()])) if "citations" in state and state["citations"] else 'No citations available.',
+        # ('\n\n'.join([f"***{doc.metadata['source']} section {doc.metadata['section']}***: {doc.page_content}" for doc in state["dora_docs"]])) if "dora_docs" in state and state["dora_docs"] else 'No documents available.',
+        # ('\n\n'.join([f"***{doc.metadata['source']}, section {doc.metadata['section']}***: {doc.page_content}" for doc in state["dora_rts_docs"]])) if "dora_rts_docs" in state and state["dora_rts_docs"] else 'No documents available.',
+        # ('\n\n'.join([f"***{doc.metadata['source']}***: {doc.page_content}" for doc in state["dora_news_docs"]])) if "dora_news_docs" in state and state["dora_news_docs"] else 'No documents available.',
     )
 
 def show_loading(prompt: str):
-    return [prompt, "loading", "loading"
+    return [prompt, "loading", "loading"]
 
 def on_click():
     return "I would love to hear your opinion: \[email protected]"
 
 def clear_results():
-    return "", "", ""
+    return "", "", ""
 
 def random_prompt():
     return random.choice([
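`generate_response` now returns two live values, the generation and the joined citations; the three per-source document panels are retired as comments. That matches the `[llm_generation, citations]` outputs wired into `.then(...)` further down. The join that turns the matcher's dict into Markdown, run on invented data:

```python
# The citation join from generate_response, on invented data.
citations = {
    "<sup>1</sup>": "***DORA section 8***: Article 8 Identification 1. ...",
    "<sup>2</sup>": "***Bafin News***: (EU) 2022/2554; (i) the cla...",
}
print("\n\n".join(f"{num} - {doc}" for num, doc in citations.items()))
# <sup>1</sup> - ***DORA section 8***: ...
#
# <sup>2</sup> - ***Bafin News***: ...
```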
@@ -360,8 +416,31 @@ def load_css():
     with open('./style.css', 'r') as file:
         return file.read()
 
-
-
+if __name__ == "__main__":
+    _set_env("OPENAI_API_KEY")
+    set_llm_cache(SQLiteCache(database_path=".cache.db"))
+
+    dora_retriever, dora_rts_retriever, dora_news_retriever = load_vectorstores(
+        ["./dora_vectorstore_data_faiss.vst",
+         "./rts_eur_lex_vectorstore_faiss.vst",
+         "./bafin_news_vectorstore_faiss.vst",]
+    )
+
+    fast_llm = ChatOpenAI(model="gpt-3.5-turbo")
+    tool_llm = ChatOpenAI(model="gpt-4o")
+    rewrite_llm = ChatOpenAI(model="gpt-4o", temperature=1, cache=False)
+
+    dora_question_rewriter = IMPROVE_PROMPT | tool_llm | StrOutputParser()
+    answer_chain = ANSWER_PROMPT | tool_llm.with_structured_output(
+        AnswerWithCitations, include_raw=False
+    ).with_config(run_name="GenerateAnswer")
+    hallucination_grader = HALLUCINATION_PROMPT | fast_llm.with_structured_output(GradeHallucinations)
+    answer_grader = RESOLVER_PROMPT | fast_llm.with_structured_output(GradeAnswer)
+    question_rewriter = REWRITER_PROMPT | rewrite_llm | StrOutputParser()
+
+    app = compile_graph()
+
+    with gr.Blocks(title='Artificial Compliance', css=load_css(), fill_width=True, fill_height=True,) as demo:
         # theme=gr.themes.Monochrome(),
         # Adding a sliding navbar
         with gr.Column(scale=1, elem_id='navbar'):
@@ -401,11 +480,11 @@ def run_gradio():
         llm_generation = gr.Markdown(label="LLM Generation", elem_id="llm_generation")
 
         gr.Markdown("----------------------------------------------------------------------------")
+
+        with gr.Row(elem_id='text_block'):
+            citations = gr.Markdown(label="citations", elem_id="llm_generation")
 
-
-        dora_documents = gr.Markdown(label="DORA Documents")
-        dora_rts_documents = gr.Markdown(label="DORA RTS Documents")
-        dora_news_documents = gr.Markdown(label="Bafin supporting Documents")
+        gr.Markdown("----------------------------------------------------------------------------")
 
         # Adding a footer with impressum and contact
         with gr.Row(elem_classes="footer"):
@@ -415,10 +494,10 @@ def run_gradio():
         gr.on(
             triggers=[question_prompt.submit, submit_button.click],
             inputs=[question_prompt],
-            outputs=[display_prompt, llm_generation,
+            outputs=[display_prompt, llm_generation, citations],
             fn=show_loading
         ).then(
-            outputs=[llm_generation,
+            outputs=[llm_generation, citations],
             inputs=[question_prompt, dora_chatbot_button, document_workbench_button, newsfeed_button],
             fn=generate_response
         )
@@ -431,35 +510,9 @@ def run_gradio():
         )
 
         # Clearing out all results when the appropriate button is clicked
-        clear_results_button.click(fn=clear_results, outputs=[display_prompt, llm_generation,
-
-    gradio_ui.launch()
-
-
-if __name__ == "__main__":
-    _set_env("OPENAI_API_KEY")
-    set_llm_cache(SQLiteCache(database_path=".cache.db"))
-
-    dora_retriever, dora_rts_retriever, dora_news_retriever = load_vectorstores(
-        ["./dora_vectorstore_data_faiss.vst",
-         "./rts_eur_lex_vectorstore_faiss.vst",
-         "./bafin_news_vectorstore_faiss.vst",]
-    )
-
-    fast_llm = ChatOpenAI(model="gpt-3.5-turbo")
-    tool_llm = ChatOpenAI(model="gpt-4o")
-    rewrite_llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=1, cache=False)
-
-    dora_question_rewriter = IMPROVE_PROMPT | tool_llm | StrOutputParser()
-    answer_chain = ANSWER_PROMPT | tool_llm | StrOutputParser()
-    hallucination_grader = HALLUCINATION_PROMPT | fast_llm.with_structured_output(GradeHallucinations)
-    answer_grader = RESOLVER_PROMPT | fast_llm.with_structured_output(GradeAnswer)
-    question_rewriter = REWRITER_PROMPT | rewrite_llm | StrOutputParser()
-
-    app = compile_graph()
+        clear_results_button.click(fn=clear_results, outputs=[display_prompt, llm_generation, citations])
 
-
-    run_gradio()
+    demo.launch()
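Structural note: the `run_gradio()` wrapper and its `gradio_ui` handle are gone; the UI is now assembled inside the `__main__` block (see the hunk at old line 360 above) and launched as `demo`. Stripped to the pieces that matter for the citation wiring, the new shape is roughly the following skeleton, not the real layout:

```python
# Skeleton of the new entry-point shape; the real file also builds the
# navbar, buttons, and footer, and wires the gr.on() handlers.
import gradio as gr

if __name__ == "__main__":
    with gr.Blocks(title="Artificial Compliance") as demo:
        question_prompt = gr.Textbox(label="Question")
        llm_generation = gr.Markdown(label="LLM Generation")
        citations = gr.Markdown(label="citations")  # new output panel
    demo.launch()
```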
prompts.py CHANGED
@@ -20,14 +20,29 @@ ANSWER_PROMPT = ChatPromptTemplate.from_messages(
     [
         (
             "system",
-            "You are
-
-
-
-
-
-
-
+            """You are an experienced IT auditor specializing in information security and regulatory compliance.
+            Your task is to assist a colleague who has a question. You have access to the following context: {context}.
+            Ensure your response is comprehensive and includes as much information from the context as possible.
+            Strive to include citations from as many different documents as relevant.
+            Make your response as informative as possible and make sure every sentence is supported by the provided information.
+            Each claim in the response must be backed up by a citation from at least one of the information sources.
+            Each citation should be the first 20 characters from the source content used.
+            If you do not have a citation from the provided source material in the message, explicitly state: 'No citations found.' Never generate a citation if no source material is provided.
+
+            Example Answer:
+            Deploying a Security Information and Event Management (SIEM) system with Extended Detection and Response (XDR) is ok <sup>[^1]</sup>. But it is not ok to deploy a SIEM system with Extended Incident Management (XIM) <sup>[^2]</sup>.
+
+            Example Footnotes:
+            [^1]: "Article\xa08Identification1."
+            [^2]: "Article\xa029Preliminary ass"
+
+            Example Answer 2:
+            The Digital Operational Resilience Act (DORA) outlines several key requirements and obligations for ICT risk management within financial entities <sup>[^1]</sup>. One of the primary obligations is the implementation of ICT security policies <sup>[^2]</sup>.
+
+            Example Footnotes 2:
+            [^1]: "the implementation of the"
+            [^2]: "(EU) 2022/2554;(i)the cla"
+            """
         ),
         ("user", "{question}"),
     ]