minko186 committed on
Commit
e76dfe8
·
1 Parent(s): 95168db

add inline citations + page content

Browse files
Files changed (3) hide show
  1. .gitignore +1 -0
  2. ai_generate.py +109 -6
  3. app.py +30 -25
.gitignore CHANGED
@@ -4,4 +4,5 @@ nohup.out
4
  *.out
5
  *.log
6
  *.json
 
7
  temp.py
 
4
  *.out
5
  *.log
6
  *.json
7
+ *.pdf
8
  temp.py
ai_generate.py CHANGED
@@ -15,6 +15,8 @@ from langchain_openai import ChatOpenAI
15
  from langchain_google_genai import ChatGoogleGenerativeAI
16
  from langchain_anthropic import ChatAnthropic
17
  from dotenv import load_dotenv
 
 
18
 
19
  load_dotenv()
20
 
@@ -47,6 +49,99 @@ llm_classes = {
47
  }
48
 
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  def load_llm(model: str, api_key: str, temperature: float = 1.0, max_length: int = 2048):
51
  model_name = llm_model_translation.get(model)
52
  llm_class = llm_classes.get(model_name)
@@ -108,15 +203,23 @@ def generate_rag(
108
  rag_prompt = hub.pull("rlm/rag-prompt")
109
 
110
  def format_docs(docs):
111
- return "\n\n".join(doc.page_content for doc in docs)
 
 
 
112
 
113
  docs = retriever.get_relevant_documents(topic)
114
- formatted_docs = format_docs(docs)
 
115
  rag_chain = (
116
- {"context": lambda _: formatted_docs, "question": RunnablePassthrough()} | rag_prompt | llm | StrOutputParser()
 
 
 
117
  )
118
- return rag_chain.invoke(prompt)
119
-
 
120
 
121
  def generate_base(
122
  prompt: str, topic: str, model: str, temperature: float, max_length: int, api_key: str, sys_message=""
@@ -147,4 +250,4 @@ def generate(
147
  if path or url_content:
148
  return generate_rag(prompt, topic, model, url_content, path, temperature, max_length, api_key, sys_message)
149
  else:
150
- return generate_base(prompt, topic, model, temperature, max_length, api_key, sys_message)
 
15
  from langchain_google_genai import ChatGoogleGenerativeAI
16
  from langchain_anthropic import ChatAnthropic
17
  from dotenv import load_dotenv
18
+ from langchain_core.output_parsers import XMLOutputParser
19
+ from langchain.prompts import ChatPromptTemplate
20
 
21
  load_dotenv()
22
 
 
49
  }
50
 
51
 
52
# System prompt that instructs the LLM to emit its answer as an XML
# <cited_text> structure: the generated text is split into <chunk> elements,
# each carrying the integer ids of the sources it was grounded on. This lets
# the caller attach citations per paragraph instead of inline in the text.
# NOTE: the exact wording and the literal XML skeleton below are part of the
# model contract parsed by XMLOutputParser downstream — do not edit casually.
xml_system = """You're a helpful AI assistant. Given a user prompt and some related sources, fulfill all the requirements \
of the prompt and provide citations. If a chunk of the generated text does not use any of the sources (for example, \
introductions or general text), don't put a citation for that chunk and just leave citations empty. Otherwise, \
list all sources used for that chunk of the text. Don't add inline citations in the text itself. Add all citations to the separated \
citations section. Use explicit new lines in the text to show paragraph splits. \
Return a citation for every quote across all articles that justify the text. Use the following format for your final output:
<cited_text>
<chunk>
<text></text>
<citations>
<citation><source_id></source_id></citation>
...
</citations>
</chunk>
<chunk>
<text></text>
<citations>
<citation><source_id></source_id></citation>
...
</citations>
</chunk>
...
</cited_text>
The entire text should be wrapped in one cited_text. For References section (if asked by prompt), don't add citations.
For source id, give a valid integer alone without a key.
Here are the sources:{context}"""
# Two-message chat template: the system message above ({context} is filled with
# the XML-formatted sources) plus the human message carrying the user {input}.
xml_prompt = ChatPromptTemplate.from_messages(
    [("system", xml_system), ("human", "{input}")]
)
81
+
82
def format_docs_xml(docs: list[Document]) -> str:
    """Render retrieved documents as an XML ``<sources>`` listing.

    Each document becomes a ``<source>`` element whose ``id`` is its index in
    *docs* — the same integer the model is asked to cite — together with the
    document's file path and text snippet.
    """
    rendered = []
    for idx, document in enumerate(docs):
        rendered.append(
            f'<source id="{idx}">\n'
            f"<path>{document.metadata['source']}</path>\n"
            f"<article_snippet>{document.page_content}</article_snippet>\n"
            f"</source>"
        )
    return "\n\n<sources>" + "\n".join(rendered) + "</sources>"
93
+
94
def get_doc_content(docs, id):
    """Return the raw page text of the document at index *id* in *docs*."""
    document = docs[id]
    return document.page_content
96
+
97
+
98
def process_cited_text(data, docs):
    """Flatten the XML-parsed model output into plain text plus a citation map.

    Args:
        data: dict produced by ``XMLOutputParser``, expected shape
            ``{'cited_text': [{'chunk': [{'text': ...}, {'citations': [...]}]}]}``.
        docs: the retrieved documents; citation ids index into this list.
            Each doc exposes ``metadata['source']`` and ``page_content``.

    Returns:
        Tuple ``(combined_text, citations)``: the chunks joined by blank lines
        with ``<id-source>`` markers appended to cited chunks, and a dict
        mapping each cited id to its source path and page content.
    """
    combined_text = ""
    citations = {}
    for item in data['cited_text']:
        chunk = item['chunk']
        # XMLOutputParser returns None for an empty <text></text> element;
        # coerce to "" instead of crashing on `+=` with None.
        chunk_text = chunk[0].get('text') or ""
        combined_text += chunk_text
        citation_ids = []
        # The <citations> element is the second entry of the chunk list; it
        # may be absent entirely, or parsed as None when empty — guard both.
        if len(chunk) > 1 and chunk[1].get('citations'):
            for c in chunk[1]['citations']:
                if c and 'citation' in c:
                    citation = c['citation']
                    if isinstance(citation, dict) and "source_id" in citation:
                        citation = citation['source_id']
                    if isinstance(citation, str):
                        try:
                            cid = int(citation)
                        except ValueError:
                            continue  # not a valid integer id; skip it
                        # Guard against hallucinated ids outside the doc list.
                        if 0 <= cid < len(docs):
                            citation_ids.append(cid)
        if citation_ids:
            citation_texts = [f"<{cid}-{docs[cid].metadata['source']}>" for cid in citation_ids]
            combined_text += " " + " ".join(citation_texts)
        combined_text += "\n\n"
        # Record each cited source once, keyed by its id.
        for citation_id in citation_ids:
            if citation_id not in citations:
                citations[citation_id] = {
                    'source': docs[citation_id].metadata['source'],
                    'content': docs[citation_id].page_content,
                }
    return combined_text.strip(), citations
129
+
130
+
131
def citations_to_html(citations):
    """Render the citation map as a balanced HTML unordered list.

    Args:
        citations: mapping of source id -> {'source': path, 'content': text},
            as produced by ``process_cited_text``.

    Returns:
        A self-contained HTML fragment, or "" when there are no citations.
        (The previous version closed ``</ul></body></html>`` without ever
        emitting the opening tags, producing unbalanced HTML.)
    """
    if not citations:
        return ""
    items = [
        (
            f"<li><strong>Source ID:</strong> {citation_id}<br>"
            f"<strong>Path:</strong> {info['source']}<br>"
            f"<strong>Page Content:</strong> {info['content']}</li>"
        )
        for citation_id, info in citations.items()
    ]
    return "<ul>" + "".join(items) + "</ul>"
143
+
144
+
145
  def load_llm(model: str, api_key: str, temperature: float = 1.0, max_length: int = 2048):
146
  model_name = llm_model_translation.get(model)
147
  llm_class = llm_classes.get(model_name)
 
203
  rag_prompt = hub.pull("rlm/rag-prompt")
204
 
205
  def format_docs(docs):
206
+ if all(isinstance(doc, Document) for doc in docs):
207
+ return "\n\n".join(doc.page_content for doc in docs)
208
+ else:
209
+ raise TypeError("All items in docs must be instances of Document.")
210
 
211
  docs = retriever.get_relevant_documents(topic)
212
+
213
+ formatted_docs = format_docs_xml(docs)
214
  rag_chain = (
215
+ RunnablePassthrough.assign(context=lambda _: formatted_docs)
216
+ | xml_prompt
217
+ | llm
218
+ | XMLOutputParser()
219
  )
220
+ result = rag_chain.invoke({"input": prompt})
221
+ text, citations = process_cited_text(result, docs)
222
+ return text, citations
223
 
224
  def generate_base(
225
  prompt: str, topic: str, model: str, temperature: float, max_length: int, api_key: str, sys_message=""
 
250
  if path or url_content:
251
  return generate_rag(prompt, topic, model, url_content, path, temperature, max_length, api_key, sys_message)
252
  else:
253
+ return generate_base(prompt, topic, model, temperature, max_length, api_key, sys_message)
app.py CHANGED
@@ -19,7 +19,7 @@ from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipe
19
  from utils import remove_special_characters, split_text_allow_complete_sentences_nltk
20
  from google_search import google_search, months, domain_list, build_date
21
  from humanize import humanize_text, device
22
- from ai_generate import generate
23
 
24
  print(f"Using device: {device}")
25
 
@@ -259,12 +259,16 @@ def ai_check(text: str, option: str):
259
 
260
 
261
  def generate_prompt(settings: Dict[str, str]) -> str:
 
262
  prompt = f"""
263
- I am a {settings['role']}
264
- Write a {settings['article_length']} words (around) {settings['format']} on {settings['topic']}.
 
 
265
  Context:
266
  - {settings['context']}
267
-
 
268
  Style and Tone:
269
  - Writing style: {settings['writing_style']}
270
  - Tone: {settings['tone']}
@@ -273,21 +277,20 @@ def generate_prompt(settings: Dict[str, str]) -> str:
273
  Content:
274
  - Depth: {settings['depth_of_content']}
275
  - Structure: {', '.join(settings['structure'])}
276
-
 
 
277
  Keywords to incorporate:
278
  {', '.join(settings['keywords'])}
279
-
 
280
  Additional requirements:
281
  - Don't start with "Here is a...", start with the requested text directly
282
- - Include {settings['num_examples']} relevant examples or case studies
283
- - Incorporate data or statistics from {', '.join(settings['references'])}
284
  - End with a {settings['conclusion_type']} conclusion
285
- - Add a "References" section in the format "References:" on a new line at the end with at least 3 credible detailed sources, formatted as [1], [2], etc. with each source on their own line
286
- - Do not repeat sources
287
  - Do not make any headline, title bold.
288
-
289
- Ensure proper paragraph breaks for better readability.
290
- Avoid any references to artificial intelligence, language models, or the fact that this is generated by an AI, and do not mention something like here is the article etc.
291
  """
292
  return prompt
293
 
@@ -361,19 +364,19 @@ def generate_article(
361
  prompt = generate_prompt(settings)
362
 
363
  print("Generated Prompt...\n", prompt)
364
- article = generate(
365
  prompt=prompt,
366
  topic=topic,
367
  model=ai_model,
368
  url_content=url_content,
369
  path=pdf_file_input,
 
370
  temperature=1,
371
  max_length=2048,
372
  api_key=api_key,
373
  sys_message="",
374
  )
375
-
376
- return clean_text(article)
377
 
378
 
379
  def get_history(history):
@@ -571,7 +574,7 @@ def generate_and_format(
571
  print(f"Google Search Query: {final_query}")
572
  url_content = google_search(final_query, sorted_date, domains_to_include)
573
  topic_context = topic + ", " + context
574
- article = generate_article(
575
  input_role,
576
  topic_context,
577
  context,
@@ -629,7 +632,7 @@ def generate_and_format(
629
  )
630
  print(save_message)
631
 
632
- return reference_formatted, history
633
 
634
 
635
  def create_interface():
@@ -857,6 +860,8 @@ def create_interface():
857
  with gr.Column(scale=3):
858
  with gr.Tab("Text Generator"):
859
  output_article = gr.Textbox(label="Generated Article", lines=20)
 
 
860
  ai_comments = gr.Textbox(
861
  label="Add comments to help edit generated text", interactive=True, visible=False
862
  )
@@ -966,7 +971,7 @@ def create_interface():
966
  pdf_file_input,
967
  history,
968
  ],
969
- outputs=[output_article, history],
970
  )
971
 
972
  regenerate_btn.click(
@@ -1003,7 +1008,7 @@ def create_interface():
1003
  exclude_sites,
1004
  ai_comments,
1005
  ],
1006
- outputs=[output_article, history],
1007
  )
1008
 
1009
  ai_check_btn.click(
@@ -1035,8 +1040,8 @@ def create_interface():
1035
 
1036
  if __name__ == "__main__":
1037
  demo = create_interface()
1038
- demo.queue(
1039
- max_size=2,
1040
- default_concurrency_limit=2,
1041
- ).launch(server_name="0.0.0.0", share=True, server_port=7890)
1042
- # demo.launch(server_name="0.0.0.0")
 
19
  from utils import remove_special_characters, split_text_allow_complete_sentences_nltk
20
  from google_search import google_search, months, domain_list, build_date
21
  from humanize import humanize_text, device
22
+ from ai_generate import generate, citations_to_html
23
 
24
  print(f"Using device: {device}")
25
 
 
259
 
260
 
261
  def generate_prompt(settings: Dict[str, str]) -> str:
262
+ settings['keywords'] = [item for item in settings['keywords'] if item.strip()]
263
  prompt = f"""
264
+ Write a {settings['article_length']} words (around) {settings['format']} on {settings['topic']}.\n
265
+ """
266
+ if settings['context']:
267
+ prompt += f"""
268
  Context:
269
  - {settings['context']}
270
+ """
271
+ prompt += f"""
272
  Style and Tone:
273
  - Writing style: {settings['writing_style']}
274
  - Tone: {settings['tone']}
 
277
  Content:
278
  - Depth: {settings['depth_of_content']}
279
  - Structure: {', '.join(settings['structure'])}
280
+ """
281
+ if len(settings['keywords']) > 0:
282
+ prompt += f"""
283
  Keywords to incorporate:
284
  {', '.join(settings['keywords'])}
285
+ """
286
+ prompt += f"""
287
  Additional requirements:
288
  - Don't start with "Here is a...", start with the requested text directly
 
 
289
  - End with a {settings['conclusion_type']} conclusion
 
 
290
  - Do not make any headline, title bold.
291
+ - Ensure proper paragraph breaks for better readability.
292
+ - Avoid any references to artificial intelligence, language models, or the fact that this is generated by an AI, and do not mention something like here is the article etc.
293
+ - Adhere to any format structure provided to the system if any.
294
  """
295
  return prompt
296
 
 
364
  prompt = generate_prompt(settings)
365
 
366
  print("Generated Prompt...\n", prompt)
367
+ article, citations = generate(
368
  prompt=prompt,
369
  topic=topic,
370
  model=ai_model,
371
  url_content=url_content,
372
  path=pdf_file_input,
373
+ # path=["./final_report.pdf"], # TODO: reset
374
  temperature=1,
375
  max_length=2048,
376
  api_key=api_key,
377
  sys_message="",
378
  )
379
+ return clean_text(article), citations_to_html(citations)
 
380
 
381
 
382
  def get_history(history):
 
574
  print(f"Google Search Query: {final_query}")
575
  url_content = google_search(final_query, sorted_date, domains_to_include)
576
  topic_context = topic + ", " + context
577
+ article, citations = generate_article(
578
  input_role,
579
  topic_context,
580
  context,
 
632
  )
633
  print(save_message)
634
 
635
+ return reference_formatted, citations, history
636
 
637
 
638
  def create_interface():
 
860
  with gr.Column(scale=3):
861
  with gr.Tab("Text Generator"):
862
  output_article = gr.Textbox(label="Generated Article", lines=20)
863
+ with gr.Accordion("Citations", open=True):
864
+ output_citations = gr.HTML(label="Citations")
865
  ai_comments = gr.Textbox(
866
  label="Add comments to help edit generated text", interactive=True, visible=False
867
  )
 
971
  pdf_file_input,
972
  history,
973
  ],
974
+ outputs=[output_article, output_citations, history],
975
  )
976
 
977
  regenerate_btn.click(
 
1008
  exclude_sites,
1009
  ai_comments,
1010
  ],
1011
+ outputs=[output_article, output_citations, history],
1012
  )
1013
 
1014
  ai_check_btn.click(
 
1040
 
1041
  if __name__ == "__main__":
1042
  demo = create_interface()
1043
+ # demo.queue(
1044
+ # max_size=2,
1045
+ # default_concurrency_limit=2,
1046
+ # ).launch(server_name="0.0.0.0", share=True, server_port=7890)
1047
+ demo.launch(server_name="0.0.0.0")