TillLangbein committed
Commit 1c3ef38 · 1 Parent(s): cfa680c

Reworked citation system

Files changed (2)
  1. app.py +105 -52
  2. prompts.py +23 -8
app.py CHANGED
@@ -1,9 +1,11 @@
 import getpass
 import os
 import random
+import re
 
 from langchain_openai import ChatOpenAI
 from langchain_core.globals import set_llm_cache
+from langchain_core.documents import Document
 from langchain_community.cache import SQLiteCache
 from langchain_community.vectorstores import FAISS
 from langchain_openai import OpenAIEmbeddings
@@ -65,6 +67,14 @@ class GradeAnswer(BaseModel):
         description="Answer addresses the question, 'yes' or 'no'"
     )
 
+class AnswerWithCitations(BaseModel):
+    answer: str = Field(
+        description="Comprehensive answer to the user's question with citations.",
+    )
+    citations: List[str] = Field(
+        description="List of the first 20 characters of sources cited in the answer."
+    )
+
 class GraphState(TypedDict):
     """
     Represents the state of our graph.
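Note on the new model: in the `__main__` block further down, the answer chain is rebuilt with `tool_llm.with_structured_output(AnswerWithCitations)`, so invoking it returns this Pydantic object rather than a plain string. A minimal sketch of the shape (the question and values are invented for illustration):

    chain = ANSWER_PROMPT | tool_llm.with_structured_output(AnswerWithCitations)
    result = chain.invoke({"context": docs, "question": "What does DORA require?"})
    result.answer      # str containing <sup>[n]</sup> citation markers
    result.citations   # e.g. ["Article\xa08Identificat"], ~20-char source prefixes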
@@ -82,6 +92,7 @@ class GraphState(TypedDict):
     dora_docs: List[str]
     dora_rts_docs: List[str]
     dora_news_docs: List[str]
+    citations: List[str]
 
 def _set_env(var: str):
     if os.environ.get(var):
@@ -92,13 +103,13 @@ def load_vectorstores(paths: list):
     # The dora vectorstore
    embd = OpenAIEmbeddings()
    model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
-    compressor = CrossEncoderReranker(model=model, top_n=7)
+    compressor = CrossEncoderReranker(model=model, top_n=4)
 
     vectorstores = [FAISS.load_local(path, embd, allow_dangerous_deserialization=True) for path in paths]
     base_retrievers = [vectorstore.as_retriever(search_type="mmr", search_kwargs={
-        "k": 10,
-        "fetch_k": 20,
-        "score_threshold": 0.7,
+        "k": 7,
+        "fetch_k": 10,
+        "score_threshold": 0.8,
     }) for vectorstore in vectorstores]
 
     retrievers = [ContextualCompressionRetriever(
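The retrieval budget is tightened here: each query now over-fetches 10 candidates, MMR keeps 7 diverse ones, and the cross-encoder reranker passes only the best 4 on to generation, which keeps the citation matching below tractable. (Whether `score_threshold` has any effect under `search_type="mmr"` depends on the vector store; it is primarily honored by the `similarity_score_threshold` search type.) A sketch of one assembled retriever, with an invented query:

    retriever = ContextualCompressionRetriever(
        base_compressor=compressor,        # CrossEncoderReranker, top_n=4
        base_retriever=base_retrievers[0],
    )
    docs = retriever.invoke("ICT incident reporting deadlines")  # at most 4 Documents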
@@ -106,7 +117,48 @@ def load_vectorstores(paths: list):
     ) for retriever in base_retrievers]
 
     return retrievers
-
+
+def starts_with_ignoring_blanks(full_text, prefix):
+    # Normalize all types of blanks to regular spaces
+    normalized_full_text = re.sub(r'\s+', ' ', full_text.strip())
+    normalized_prefix = re.sub(r'\s+', ' ', prefix.strip())
+
+    # Check if the normalized full text starts with the normalized prefix
+    return normalized_full_text.startswith(normalized_prefix)
+
+def match_citations_to_documents(citations: List[str], documents: List[Document]):
+    """
+    Matches citations to documents by searching for the source and section in the documents.
+
+    Args:
+        citations (List[str]): List of citations to match
+        documents (List[Document]): List of documents to search in
+
+    Returns:
+        dict: Dictionary of matched documents, keyed by citation number, with the formatted document as value
+    """
+    matched_documents = {}
+
+    for num, citation in enumerate(citations, 1):
+        # Try to find a document whose text begins with this citation
+        print(f"Checking citation {num}: {citation}")
+        for doc in documents:
+            print(f"Does this: '{doc.page_content[:30]}' start with this: '{citation}'?")
+            print(f"{doc.page_content[:40] = }")
+            print(f"{citation = }")
+            print(f"{doc.page_content[:40].startswith(citation) = }")
+            if starts_with_ignoring_blanks(doc.page_content[:40], citation):  # Strangely, the 25 characters of the citation often become 35
+                print("yes")
+                if doc.metadata.get("section", None):
+                    matched_documents[f"<sup>{num}</sup>"] = f"***{doc.metadata['source']} section {doc.metadata['section']}***: {doc.page_content}"
+                else:
+                    matched_documents[f"<sup>{num}</sup>"] = f"***{doc.metadata['source']}***: {doc.page_content}"
+                break
+            else:
+                print("no")
+
+    return matched_documents
+
 # Put all chains in functions
 def dora_rewrite(state):
     """
@@ -168,8 +220,14 @@ def generate(state):
     documents = state["documents"]
 
     # RAG generation
-    generation = answer_chain.invoke({"context": documents, "question": question})
-    return {"generation": generation}
+    answer = answer_chain.invoke({"context": documents, "question": question})
+
+    generation = answer.answer
+    print(f"{answer.citations = }")
+    citations = match_citations_to_documents(answer.citations, documents)
+    print(f"{len(citations)} found, is that correct?")
+
+    return {"generation": generation, "citations": citations}
 
 def transform_query(state):
     """
@@ -271,10 +329,7 @@ def grade_generation_v_documents_and_question(state):
             print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
             return "not useful"
     else:
-        for document in documents:
-            print(document.page_content)
         print("---DECISION: THOSE DOCUMENTS ARE NOT GROUNDING THIS GENERATION---")
-        print(f"{generation = }")
         return "not supported"
 
 # Then compile the graph
@@ -308,7 +363,7 @@ def compile_graph():
         "generate",
         grade_generation_v_documents_and_question,
         {
-            "not supported": "generate",
+            "not supported": "transform_query",
             "useful": END,
             "not useful": "transform_query",
         },
@@ -323,19 +378,20 @@ def generate_response(question: str, dora: bool, rts: bool, news: bool):
     state = app.invoke({"question": question, "selected_sources": selected_sources})
     return (
         state["generation"],
-        ('\n\n'.join([f"***{doc.metadata['source']} section {doc.metadata['section']}***: {doc.page_content}" for doc in state["dora_docs"]])) if "dora_docs" in state and state["dora_docs"] else 'No documents available.',
-        ('\n\n'.join([f"***{doc.metadata['source']}, section {doc.metadata['section']}***: {doc.page_content}" for doc in state["dora_rts_docs"]])) if "dora_rts_docs" in state and state["dora_rts_docs"] else 'No documents available.',
-        ('\n\n'.join([f"***{doc.metadata['source']}***: {doc.page_content}" for doc in state["dora_news_docs"]])) if "dora_news_docs" in state and state["dora_news_docs"] else 'No documents available.',
+        ('\n\n'.join([f"{num} - {doc}" for num, doc in state["citations"].items()])) if "citations" in state and state["citations"] else 'No citations available.',
+        # ('\n\n'.join([f"***{doc.metadata['source']} section {doc.metadata['section']}***: {doc.page_content}" for doc in state["dora_docs"]])) if "dora_docs" in state and state["dora_docs"] else 'No documents available.',
+        # ('\n\n'.join([f"***{doc.metadata['source']}, section {doc.metadata['section']}***: {doc.page_content}" for doc in state["dora_rts_docs"]])) if "dora_rts_docs" in state and state["dora_rts_docs"] else 'No documents available.',
+        # ('\n\n'.join([f"***{doc.metadata['source']}***: {doc.page_content}" for doc in state["dora_news_docs"]])) if "dora_news_docs" in state and state["dora_news_docs"] else 'No documents available.',
     )
 
 def show_loading(prompt: str):
-    return [prompt, "loading", "loading", "loading", "loading"]
+    return [prompt, "loading", "loading"]
 
 def on_click():
     return "I would love to hear your opinion: \[email protected]"
 
 def clear_results():
-    return "", "", "", "", ""
+    return "", "", ""
 
 def random_prompt():
     return random.choice([
@@ -360,8 +416,31 @@ def load_css():
     with open('./style.css', 'r') as file:
         return file.read()
 
-def run_gradio():
-    with gr.Blocks(title='Artificial Compliance', css=load_css(), fill_width=True, fill_height=True,) as gradio_ui:
+if __name__ == "__main__":
+    _set_env("OPENAI_API_KEY")
+    set_llm_cache(SQLiteCache(database_path=".cache.db"))
+
+    dora_retriever, dora_rts_retriever, dora_news_retriever = load_vectorstores(
+        ["./dora_vectorstore_data_faiss.vst",
+         "./rts_eur_lex_vectorstore_faiss.vst",
+         "./bafin_news_vectorstore_faiss.vst",]
+    )
+
+    fast_llm = ChatOpenAI(model="gpt-3.5-turbo")
+    tool_llm = ChatOpenAI(model="gpt-4o")
+    rewrite_llm = ChatOpenAI(model="gpt-4o", temperature=1, cache=False)
+
+    dora_question_rewriter = IMPROVE_PROMPT | tool_llm | StrOutputParser()
+    answer_chain = ANSWER_PROMPT | tool_llm.with_structured_output(
+        AnswerWithCitations, include_raw=False
+    ).with_config(run_name="GenerateAnswer")
+    hallucination_grader = HALLUCINATION_PROMPT | fast_llm.with_structured_output(GradeHallucinations)
+    answer_grader = RESOLVER_PROMPT | fast_llm.with_structured_output(GradeAnswer)
+    question_rewriter = REWRITER_PROMPT | rewrite_llm | StrOutputParser()
+
+    app = compile_graph()
+
+    with gr.Blocks(title='Artificial Compliance', css=load_css(), fill_width=True, fill_height=True,) as demo:
         # theme=gr.themes.Monochrome(),
         # Adding a sliding navbar
         with gr.Column(scale=1, elem_id='navbar'):
@@ -401,11 +480,11 @@ def run_gradio():
             llm_generation = gr.Markdown(label="LLM Generation", elem_id="llm_generation")
 
             gr.Markdown("----------------------------------------------------------------------------")
 
             with gr.Row(elem_id='text_block'):
-                dora_documents = gr.Markdown(label="DORA Documents")
-                dora_rts_documents = gr.Markdown(label="DORA RTS Documents")
-                dora_news_documents = gr.Markdown(label="Bafin supporting Documents")
+                citations = gr.Markdown(label="citations", elem_id="llm_generation")
+
+            gr.Markdown("----------------------------------------------------------------------------")
 
         # Adding a footer with impressum and contact
         with gr.Row(elem_classes="footer"):
@@ -415,10 +494,10 @@ def run_gradio():
         gr.on(
             triggers=[question_prompt.submit, submit_button.click],
             inputs=[question_prompt],
-            outputs=[display_prompt, llm_generation, dora_documents, dora_rts_documents, dora_news_documents],
+            outputs=[display_prompt, llm_generation, citations],
             fn=show_loading
         ).then(
-            outputs=[llm_generation, dora_documents, dora_rts_documents, dora_news_documents],
+            outputs=[llm_generation, citations],
             inputs=[question_prompt, dora_chatbot_button, document_workbench_button, newsfeed_button],
             fn=generate_response
         )
@@ -431,35 +510,9 @@ def run_gradio():
         )
 
         # Clearing out all results when the appropriate button is clicked
-        clear_results_button.click(fn=clear_results, outputs=[display_prompt, llm_generation, dora_documents, dora_rts_documents, dora_news_documents])
-
-        gradio_ui.launch()
-
-
-if __name__ == "__main__":
-    _set_env("OPENAI_API_KEY")
-    set_llm_cache(SQLiteCache(database_path=".cache.db"))
-
-    dora_retriever, dora_rts_retriever, dora_news_retriever = load_vectorstores(
-        ["./dora_vectorstore_data_faiss.vst",
-         "./rts_eur_lex_vectorstore_faiss.vst",
-         "./bafin_news_vectorstore_faiss.vst",]
-    )
-
-    fast_llm = ChatOpenAI(model="gpt-3.5-turbo")
-    tool_llm = ChatOpenAI(model="gpt-4o")
-    rewrite_llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=1, cache=False)
-
-    dora_question_rewriter = IMPROVE_PROMPT | tool_llm | StrOutputParser()
-    answer_chain = ANSWER_PROMPT | tool_llm | StrOutputParser()
-    hallucination_grader = HALLUCINATION_PROMPT | fast_llm.with_structured_output(GradeHallucinations)
-    answer_grader = RESOLVER_PROMPT | fast_llm.with_structured_output(GradeAnswer)
-    question_rewriter = REWRITER_PROMPT | rewrite_llm | StrOutputParser()
-
-    app = compile_graph()
-
-    # And finally, run the app
-    run_gradio()
+        clear_results_button.click(fn=clear_results, outputs=[display_prompt, llm_generation, citations])
 
+    demo.launch()
 
prompts.py CHANGED
@@ -20,14 +20,29 @@ ANSWER_PROMPT = ChatPromptTemplate.from_messages(
     [
         (
             "system",
-            "You are a highly experienced IT auditor, specializing in information security and regulatory compliance. "
-            "Your task is to assist a colleague who has approached you with a question. "
-            "You have access to relevant context, provided here: {context}. "
-            "Make your response as informative as possible and make sure every sentence is supported by the provided context."
-            "Each information must be backed up by a citation from at least one of the information sources in the context, formatted as a footnote, reproducing the source after your response."
-            "Your answer should be structured and suitable for regulatory documentation or audit reporting. "
-            "If you do not have a citation from the provided source material in the message, explicitly state: 'No citations found.' Never generate a citation if no source material is provided."
-            "Ensure all relevant details from the context are included in your response."
+            """You are an experienced IT auditor specializing in information security and regulatory compliance.
+            Your task is to assist a colleague who has a question. You have access to the following context: {context}.
+            Ensure your response is comprehensive and includes as much information from the context as possible.
+            Strive to include citations from as many different documents as relevant.
+            Make your response as informative as possible and make sure every sentence is supported by the provided information.
+            Each claim in the response must be backed up by a citation from at least one of the information sources.
+            Each citation should be the first 20 characters of the source content used.
+            If you do not have a citation from the provided source material in the message, explicitly state: 'No citations found.' Never generate a citation if no source material is provided.
+
+            Example Answer:
+            Deploying a Security Information and Event Management (SIEM) system with Extended Detection and Response (XDR) is ok <sup>[1]</sup>. But it is not ok to deploy a SIEM system with Extended Incident Management (XIM) <sup>[2]</sup>.
+
+            Example Footnotes:
+            [^1]: "Article\xa08Identification1."
+            [^2]: "Article\xa029Preliminary ass"
+
+            Example Answer 2:
+            The Digital Operational Resilience Act (DORA) outlines several key requirements and obligations for ICT risk management within financial entities <sup>[1]</sup>. One of the primary obligations is the implementation of ICT security policies <sup>[2]</sup>.
+
+            Example Footnotes 2:
+            [^1]: "the implementation of the"
+            [^2]: "(EU) 2022/2554;(i)the cla"
+            """
         ),
         ("user", "{question}"),
     ]
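Read together with AnswerWithCitations in app.py, a response produced under this prompt might parse into something like the following (contents invented for illustration):

    AnswerWithCitations(
        answer="Financial entities must implement an ICT risk management framework <sup>[1]</sup>.",
        citations=["Article\xa08Identificat"],  # ~20-character prefix of the cited chunk
    )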