Spaces:
Runtime error
Runtime error
add inline citations + page content
Browse files- .gitignore +1 -0
- ai_generate.py +109 -6
- app.py +30 -25
.gitignore
CHANGED
@@ -4,4 +4,5 @@ nohup.out
|
|
4 |
*.out
|
5 |
*.log
|
6 |
*.json
|
|
|
7 |
temp.py
|
|
|
4 |
*.out
|
5 |
*.log
|
6 |
*.json
|
7 |
+
*.pdf
|
8 |
temp.py
|
ai_generate.py
CHANGED
@@ -15,6 +15,8 @@ from langchain_openai import ChatOpenAI
|
|
15 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
16 |
from langchain_anthropic import ChatAnthropic
|
17 |
from dotenv import load_dotenv
|
|
|
|
|
18 |
|
19 |
load_dotenv()
|
20 |
|
@@ -47,6 +49,99 @@ llm_classes = {
|
|
47 |
}
|
48 |
|
49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
def load_llm(model: str, api_key: str, temperature: float = 1.0, max_length: int = 2048):
|
51 |
model_name = llm_model_translation.get(model)
|
52 |
llm_class = llm_classes.get(model_name)
|
@@ -108,15 +203,23 @@ def generate_rag(
|
|
108 |
rag_prompt = hub.pull("rlm/rag-prompt")
|
109 |
|
110 |
def format_docs(docs):
|
111 |
-
|
|
|
|
|
|
|
112 |
|
113 |
docs = retriever.get_relevant_documents(topic)
|
114 |
-
|
|
|
115 |
rag_chain = (
|
116 |
-
|
|
|
|
|
|
|
117 |
)
|
118 |
-
|
119 |
-
|
|
|
120 |
|
121 |
def generate_base(
|
122 |
prompt: str, topic: str, model: str, temperature: float, max_length: int, api_key: str, sys_message=""
|
@@ -147,4 +250,4 @@ def generate(
|
|
147 |
if path or url_content:
|
148 |
return generate_rag(prompt, topic, model, url_content, path, temperature, max_length, api_key, sys_message)
|
149 |
else:
|
150 |
-
return generate_base(prompt, topic, model, temperature, max_length, api_key, sys_message)
|
|
|
15 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
16 |
from langchain_anthropic import ChatAnthropic
|
17 |
from dotenv import load_dotenv
|
18 |
+
from langchain_core.output_parsers import XMLOutputParser
|
19 |
+
from langchain.prompts import ChatPromptTemplate
|
20 |
|
21 |
load_dotenv()
|
22 |
|
|
|
49 |
}
|
50 |
|
51 |
|
52 |
+
xml_system = """You're a helpful AI assistant. Given a user prompt and some related sources, fulfill all the requirements \
|
53 |
+
of the prompt and provide citations. If a chunk of the generated text does not use any of the sources (for example, \
|
54 |
+
introductions or general text), don't put a citation for that chunk and just leave citations empty. Otherwise, \
|
55 |
+
list all sources used for that chunk of the text. Don't add inline citations in the text itself. Add all citations to the separated \
|
56 |
+
citations section. Use explicit new lines in the text to show paragraph splits. \
|
57 |
+
Return a citation for every quote across all articles that justify the text. Use the following format for your final output:
|
58 |
+
<cited_text>
|
59 |
+
<chunk>
|
60 |
+
<text></text>
|
61 |
+
<citations>
|
62 |
+
<citation><source_id></source_id></citation>
|
63 |
+
...
|
64 |
+
</citations>
|
65 |
+
</chunk>
|
66 |
+
<chunk>
|
67 |
+
<text></text>
|
68 |
+
<citations>
|
69 |
+
<citation><source_id></source_id></citation>
|
70 |
+
...
|
71 |
+
</citations>
|
72 |
+
</chunk>
|
73 |
+
...
|
74 |
+
</cited_text>
|
75 |
+
The entire text should be wrapped in one cited_text. For References section (if asked by prompt), don't add citations.
|
76 |
+
For source id, give a valid integer alone without a key.
|
77 |
+
Here are the sources:{context}"""
|
78 |
+
xml_prompt = ChatPromptTemplate.from_messages(
|
79 |
+
[("system", xml_system), ("human", "{input}")]
|
80 |
+
)
|
81 |
+
|
82 |
+
def format_docs_xml(docs: "list[Document]") -> str:
    """Serialize retrieved documents into the ``<sources>`` block of the prompt.

    Each document becomes one ``<source id="N">`` element, where ``N`` is the
    document's position in ``docs``. That position is the id the model is
    asked to cite, so callers must resolve citations against the same list.

    Args:
        docs: Objects exposing ``metadata`` (a mapping) and ``page_content``.
            The annotation is quoted so it is not evaluated at definition
            time.

    Returns:
        A string starting with two newlines, then ``<sources>``, the
        newline-joined source elements, and ``</sources>``.
    """
    formatted = []
    for i, doc in enumerate(docs):
        # A document without a 'source' entry must not abort the whole
        # request; fall back to a placeholder path instead of raising.
        source = (doc.metadata or {}).get('source', 'unknown')
        formatted.append(
            f"<source id=\"{i}\">\n"
            f"<path>{source}</path>\n"
            f"<article_snippet>{doc.page_content}</article_snippet>\n"
            "</source>"
        )
    return "\n\n<sources>" + "\n".join(formatted) + "</sources>"
|
92 |
+
|
93 |
+
|
94 |
+
def get_doc_content(docs, id):
    """Return the ``page_content`` of the document at index ``id`` in ``docs``.

    Indexing errors from ``docs[id]`` propagate to the caller unchanged.
    """
    selected = docs[id]
    return selected.page_content
|
96 |
+
|
97 |
+
|
98 |
+
def _extract_citation_ids(raw_citations, num_docs):
    """Return the usable integer source ids found in a parsed ``citations`` node."""
    ids = []
    if not isinstance(raw_citations, list):
        # None (empty element) or stray text: nothing to cite.
        return ids
    for entry in raw_citations:
        if not isinstance(entry, dict) or 'citation' not in entry:
            continue
        value = entry['citation']
        # A nested <citation><source_id>N</source_id></citation> may arrive as
        # a list of single-key dicts or as a plain dict; a bare
        # <citation>N</citation> arrives as a string.
        if isinstance(value, list):
            value = next(
                (v['source_id'] for v in value if isinstance(v, dict) and 'source_id' in v),
                None,
            )
        elif isinstance(value, dict) and 'source_id' in value:
            value = value['source_id']
        if not isinstance(value, str):
            continue
        try:
            cid = int(value)
        except ValueError:
            continue  # not an integer id; ignore this citation
        # Ids come from the model: drop anything that does not address a
        # retrieved document (negative ids would silently index from the end).
        if 0 <= cid < num_docs:
            ids.append(cid)
    return ids


def process_cited_text(data, docs):
    """Flatten parsed ``cited_text`` output into article text plus citations.

    Args:
        data: Parsed model output, a mapping with a ``'cited_text'`` entry
            holding a list of ``{'chunk': [{'text': ...}, {'citations': ...}]}``
            items.
        docs: The documents the source ids refer to, addressed by position;
            each exposes ``metadata`` and ``page_content``.

    Returns:
        A ``(text, citations)`` tuple. ``text`` is every chunk's text, each
        followed by ``<id-source>`` markers for its citations and a blank
        line, with surrounding whitespace stripped. ``citations`` maps each
        cited id (in first-seen order) to ``{'source': ..., 'content': ...}``.
    """
    cited = data['cited_text']
    if isinstance(cited, str):
        # The model returned plain text instead of <chunk> elements.
        return cited.strip(), {}

    pieces = []
    citations = {}
    for item in cited or []:
        if not isinstance(item, dict):
            continue
        # Look parts up by key rather than position so a chunk that omits
        # <citations> (or reorders its children) does not raise.
        parts = {}
        for part in item.get('chunk') or []:
            if isinstance(part, dict):
                parts.update(part)

        # An empty <text></text> element is parsed as None.
        pieces.append(parts.get('text') or "")

        citation_ids = _extract_citation_ids(parts.get('citations'), len(docs))
        if citation_ids:
            markers = [
                f"<{cid}-{(docs[cid].metadata or {}).get('source', 'unknown')}>"
                for cid in citation_ids
            ]
            pieces.append(" " + " ".join(markers))
        pieces.append("\n\n")

        # Record each cited document once, keeping first-seen order.
        for cid in citation_ids:
            if cid not in citations:
                citations[cid] = {
                    'source': (docs[cid].metadata or {}).get('source', 'unknown'),
                    'content': docs[cid].page_content,
                }

    return "".join(pieces).strip(), citations
|
129 |
+
|
130 |
+
|
131 |
+
def citations_to_html(citations):
    """Render the unique citations as an HTML unordered-list fragment.

    Args:
        citations: Mapping of source id to a dict with ``'source'`` and
            ``'content'`` entries. May be empty or ``None``.

    Returns:
        ``"<ul>...</ul>"`` with one ``<li>`` per citation, or an empty string
        when there is nothing to show. The result is a fragment with balanced
        tags and no document-level ``<body>``/``<html>`` wrapper.
    """
    from html import escape

    if not citations:
        return ""

    # Paths and page content come from retrieved documents, so escape them
    # before embedding them in markup.
    items = [
        (
            f"<li><strong>Source ID:</strong> {escape(str(citation_id))}<br>"
            f"<strong>Path:</strong> {escape(str(citation_info['source']))}<br>"
            f"<strong>Page Content:</strong> {escape(str(citation_info['content']))}</li>"
        )
        for citation_id, citation_info in citations.items()
    ]
    return "<ul>" + "".join(items) + "</ul>"
|
143 |
+
|
144 |
+
|
145 |
def load_llm(model: str, api_key: str, temperature: float = 1.0, max_length: int = 2048):
|
146 |
model_name = llm_model_translation.get(model)
|
147 |
llm_class = llm_classes.get(model_name)
|
|
|
203 |
rag_prompt = hub.pull("rlm/rag-prompt")
|
204 |
|
205 |
def format_docs(docs):
|
206 |
+
if all(isinstance(doc, Document) for doc in docs):
|
207 |
+
return "\n\n".join(doc.page_content for doc in docs)
|
208 |
+
else:
|
209 |
+
raise TypeError("All items in docs must be instances of Document.")
|
210 |
|
211 |
docs = retriever.get_relevant_documents(topic)
|
212 |
+
|
213 |
+
formatted_docs = format_docs_xml(docs)
|
214 |
rag_chain = (
|
215 |
+
RunnablePassthrough.assign(context=lambda _: formatted_docs)
|
216 |
+
| xml_prompt
|
217 |
+
| llm
|
218 |
+
| XMLOutputParser()
|
219 |
)
|
220 |
+
result = rag_chain.invoke({"input": prompt})
|
221 |
+
text, citations = process_cited_text(result, docs)
|
222 |
+
return text, citations
|
223 |
|
224 |
def generate_base(
|
225 |
prompt: str, topic: str, model: str, temperature: float, max_length: int, api_key: str, sys_message=""
|
|
|
250 |
if path or url_content:
|
251 |
return generate_rag(prompt, topic, model, url_content, path, temperature, max_length, api_key, sys_message)
|
252 |
else:
|
253 |
+
return generate_base(prompt, topic, model, temperature, max_length, api_key, sys_message)
|
app.py
CHANGED
@@ -19,7 +19,7 @@ from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipe
|
|
19 |
from utils import remove_special_characters, split_text_allow_complete_sentences_nltk
|
20 |
from google_search import google_search, months, domain_list, build_date
|
21 |
from humanize import humanize_text, device
|
22 |
-
from ai_generate import generate
|
23 |
|
24 |
print(f"Using device: {device}")
|
25 |
|
@@ -259,12 +259,16 @@ def ai_check(text: str, option: str):
|
|
259 |
|
260 |
|
261 |
def generate_prompt(settings: Dict[str, str]) -> str:
|
|
|
262 |
prompt = f"""
|
263 |
-
|
264 |
-
|
|
|
|
|
265 |
Context:
|
266 |
- {settings['context']}
|
267 |
-
|
|
|
268 |
Style and Tone:
|
269 |
- Writing style: {settings['writing_style']}
|
270 |
- Tone: {settings['tone']}
|
@@ -273,21 +277,20 @@ def generate_prompt(settings: Dict[str, str]) -> str:
|
|
273 |
Content:
|
274 |
- Depth: {settings['depth_of_content']}
|
275 |
- Structure: {', '.join(settings['structure'])}
|
276 |
-
|
|
|
|
|
277 |
Keywords to incorporate:
|
278 |
{', '.join(settings['keywords'])}
|
279 |
-
|
|
|
280 |
Additional requirements:
|
281 |
- Don't start with "Here is a...", start with the requested text directly
|
282 |
-
- Include {settings['num_examples']} relevant examples or case studies
|
283 |
-
- Incorporate data or statistics from {', '.join(settings['references'])}
|
284 |
- End with a {settings['conclusion_type']} conclusion
|
285 |
-
- Add a "References" section in the format "References:" on a new line at the end with at least 3 credible detailed sources, formatted as [1], [2], etc. with each source on their own line
|
286 |
-
- Do not repeat sources
|
287 |
- Do not make any headline, title bold.
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
"""
|
292 |
return prompt
|
293 |
|
@@ -361,19 +364,19 @@ def generate_article(
|
|
361 |
prompt = generate_prompt(settings)
|
362 |
|
363 |
print("Generated Prompt...\n", prompt)
|
364 |
-
article = generate(
|
365 |
prompt=prompt,
|
366 |
topic=topic,
|
367 |
model=ai_model,
|
368 |
url_content=url_content,
|
369 |
path=pdf_file_input,
|
|
|
370 |
temperature=1,
|
371 |
max_length=2048,
|
372 |
api_key=api_key,
|
373 |
sys_message="",
|
374 |
)
|
375 |
-
|
376 |
-
return clean_text(article)
|
377 |
|
378 |
|
379 |
def get_history(history):
|
@@ -571,7 +574,7 @@ def generate_and_format(
|
|
571 |
print(f"Google Search Query: {final_query}")
|
572 |
url_content = google_search(final_query, sorted_date, domains_to_include)
|
573 |
topic_context = topic + ", " + context
|
574 |
-
article = generate_article(
|
575 |
input_role,
|
576 |
topic_context,
|
577 |
context,
|
@@ -629,7 +632,7 @@ def generate_and_format(
|
|
629 |
)
|
630 |
print(save_message)
|
631 |
|
632 |
-
return reference_formatted, history
|
633 |
|
634 |
|
635 |
def create_interface():
|
@@ -857,6 +860,8 @@ def create_interface():
|
|
857 |
with gr.Column(scale=3):
|
858 |
with gr.Tab("Text Generator"):
|
859 |
output_article = gr.Textbox(label="Generated Article", lines=20)
|
|
|
|
|
860 |
ai_comments = gr.Textbox(
|
861 |
label="Add comments to help edit generated text", interactive=True, visible=False
|
862 |
)
|
@@ -966,7 +971,7 @@ def create_interface():
|
|
966 |
pdf_file_input,
|
967 |
history,
|
968 |
],
|
969 |
-
outputs=[output_article, history],
|
970 |
)
|
971 |
|
972 |
regenerate_btn.click(
|
@@ -1003,7 +1008,7 @@ def create_interface():
|
|
1003 |
exclude_sites,
|
1004 |
ai_comments,
|
1005 |
],
|
1006 |
-
outputs=[output_article, history],
|
1007 |
)
|
1008 |
|
1009 |
ai_check_btn.click(
|
@@ -1035,8 +1040,8 @@ def create_interface():
|
|
1035 |
|
1036 |
if __name__ == "__main__":
|
1037 |
demo = create_interface()
|
1038 |
-
demo.queue(
|
1039 |
-
|
1040 |
-
|
1041 |
-
).launch(server_name="0.0.0.0", share=True, server_port=7890)
|
1042 |
-
|
|
|
19 |
from utils import remove_special_characters, split_text_allow_complete_sentences_nltk
|
20 |
from google_search import google_search, months, domain_list, build_date
|
21 |
from humanize import humanize_text, device
|
22 |
+
from ai_generate import generate, citations_to_html
|
23 |
|
24 |
print(f"Using device: {device}")
|
25 |
|
|
|
259 |
|
260 |
|
261 |
def generate_prompt(settings: Dict[str, str]) -> str:
|
262 |
+
settings['keywords'] = [item for item in settings['keywords'] if item.strip()]
|
263 |
prompt = f"""
|
264 |
+
Write a {settings['article_length']} words (around) {settings['format']} on {settings['topic']}.\n
|
265 |
+
"""
|
266 |
+
if settings['context']:
|
267 |
+
prompt += f"""
|
268 |
Context:
|
269 |
- {settings['context']}
|
270 |
+
"""
|
271 |
+
prompt += f"""
|
272 |
Style and Tone:
|
273 |
- Writing style: {settings['writing_style']}
|
274 |
- Tone: {settings['tone']}
|
|
|
277 |
Content:
|
278 |
- Depth: {settings['depth_of_content']}
|
279 |
- Structure: {', '.join(settings['structure'])}
|
280 |
+
"""
|
281 |
+
if len(settings['keywords']) > 0:
|
282 |
+
prompt += f"""
|
283 |
Keywords to incorporate:
|
284 |
{', '.join(settings['keywords'])}
|
285 |
+
"""
|
286 |
+
prompt += f"""
|
287 |
Additional requirements:
|
288 |
- Don't start with "Here is a...", start with the requested text directly
|
|
|
|
|
289 |
- End with a {settings['conclusion_type']} conclusion
|
|
|
|
|
290 |
- Do not make any headline, title bold.
|
291 |
+
- Ensure proper paragraph breaks for better readability.
|
292 |
+
- Avoid any references to artificial intelligence, language models, or the fact that this is generated by an AI, and do not mention something like here is the article etc.
|
293 |
+
- Adhere to any format structure provided to the system if any.
|
294 |
"""
|
295 |
return prompt
|
296 |
|
|
|
364 |
prompt = generate_prompt(settings)
|
365 |
|
366 |
print("Generated Prompt...\n", prompt)
|
367 |
+
article, citations = generate(
|
368 |
prompt=prompt,
|
369 |
topic=topic,
|
370 |
model=ai_model,
|
371 |
url_content=url_content,
|
372 |
path=pdf_file_input,
|
373 |
+
# path=["./final_report.pdf"], # TODO: reset
|
374 |
temperature=1,
|
375 |
max_length=2048,
|
376 |
api_key=api_key,
|
377 |
sys_message="",
|
378 |
)
|
379 |
+
return clean_text(article), citations_to_html(citations)
|
|
|
380 |
|
381 |
|
382 |
def get_history(history):
|
|
|
574 |
print(f"Google Search Query: {final_query}")
|
575 |
url_content = google_search(final_query, sorted_date, domains_to_include)
|
576 |
topic_context = topic + ", " + context
|
577 |
+
article, citations = generate_article(
|
578 |
input_role,
|
579 |
topic_context,
|
580 |
context,
|
|
|
632 |
)
|
633 |
print(save_message)
|
634 |
|
635 |
+
return reference_formatted, citations, history
|
636 |
|
637 |
|
638 |
def create_interface():
|
|
|
860 |
with gr.Column(scale=3):
|
861 |
with gr.Tab("Text Generator"):
|
862 |
output_article = gr.Textbox(label="Generated Article", lines=20)
|
863 |
+
with gr.Accordion("Citations", open=True):
|
864 |
+
output_citations = gr.HTML(label="Citations")
|
865 |
ai_comments = gr.Textbox(
|
866 |
label="Add comments to help edit generated text", interactive=True, visible=False
|
867 |
)
|
|
|
971 |
pdf_file_input,
|
972 |
history,
|
973 |
],
|
974 |
+
outputs=[output_article, output_citations, history],
|
975 |
)
|
976 |
|
977 |
regenerate_btn.click(
|
|
|
1008 |
exclude_sites,
|
1009 |
ai_comments,
|
1010 |
],
|
1011 |
+
outputs=[output_article, output_citations, history],
|
1012 |
)
|
1013 |
|
1014 |
ai_check_btn.click(
|
|
|
1040 |
|
1041 |
if __name__ == "__main__":
|
1042 |
demo = create_interface()
|
1043 |
+
# demo.queue(
|
1044 |
+
# max_size=2,
|
1045 |
+
# default_concurrency_limit=2,
|
1046 |
+
# ).launch(server_name="0.0.0.0", share=True, server_port=7890)
|
1047 |
+
demo.launch(server_name="0.0.0.0")
|