|
import json |
|
from collections import defaultdict |
|
import openai |
|
import re |
|
from config import CFG_APP |
|
from text_embedder import SentenceTransformersTextEmbedder |
|
from datetime import datetime |
|
import tiktoken |
|
|
|
# Per-document metadata (id, short_name, url, ...) used to attach citation
# info to retrieved passages.
# Fix: close the metadata file deterministically — the original used
# json.load(open(...)) and leaked the file handle.
with open(CFG_APP.DOC_METADATA_PATH, "r") as _doc_metadata_file:
    doc_metadata = json.load(_doc_metadata_file)


# Only sentence-transformers embedding models are supported; the FAISS
# index is loaded from disk rather than rebuilt at startup.
if "sentence-transformers" in CFG_APP.EMBEDDING_MODEL:
    text_embedder = SentenceTransformersTextEmbedder(
        model_name=CFG_APP.EMBEDDING_MODEL,
        paragraphs_path=CFG_APP.DATA_FOLDER,
        device=CFG_APP.DEVICE,
        load_existing_index=True,
    )
else:
    raise ValueError("Embedding model not found !")
|
|
|
|
|
|
|
def retrieve_doc_metadata(doc_metadata, doc_id):
    """Return the metadata entry whose "id" equals *doc_id*.

    Args:
        doc_metadata: iterable of metadata dicts, each carrying an "id" key.
        doc_id: the document id to look up.

    Returns:
        The matching metadata dict, or None when no entry matches.
    """
    return next((entry for entry in doc_metadata if entry["id"] == doc_id), None)
|
|
|
|
|
def get_reformulation_prompt(query: str) -> list:
    """Build the one-turn chat prompt that asks the model to rewrite
    *query* as a standalone question (template: CFG_APP.REFORMULATION_PROMPT).

    Returns:
        list: a single-element OpenAI message list with role "user".
    """
    content = f"""{CFG_APP.REFORMULATION_PROMPT}
---
query: {query}
standalone question: """
    return [{"role": "user", "content": content}]
|
|
|
def get_hyde_prompt(query: str) -> list:
    """Build the one-turn chat prompt for the HyDE technique: ask the model
    to produce a hypothetical answer to *query* (template: CFG_APP.HYDE_PROMPT)
    that will be used as the retrieval query.

    Returns:
        list: a single-element OpenAI message list with role "user".
    """
    content = f"""{CFG_APP.HYDE_PROMPT}
---
query: {query}
output: """
    return [{"role": "user", "content": content}]
|
|
|
|
|
def make_pairs(lst):
    """From a list of even length, make tuple pairs.

    Args:
        lst (list): a list of even length

    Returns:
        list: the list as tuple pairs

    Raises:
        AssertionError: if the list length is odd (note: asserts are
        stripped under ``python -O``).
    """
    # Bug fix: the original walrus bound len(lst) % 2 (always 1 on failure),
    # so the error message reported "lenght 1" instead of the real length.
    assert len(lst) % 2 == 0, f"your list is of length {len(lst)} which is not even"
    return [(lst[i], lst[i + 1]) for i in range(0, len(lst), 2)]
|
|
|
|
|
def make_html_source(paragraph, meta_doc, i):
    """Render one retrieved paragraph as an HTML excerpt card.

    Args:
        paragraph: dict with "content" (the excerpt text) and "meta"
            (must contain "page_number").
        meta_doc: document metadata dict ("num_doc", "short_name", "url").
        i: 1-based excerpt index, used in the element id and heading.

    Returns:
        str: an HTML <div class="card"> block linking to the source PDF page.
    """
    text = paragraph["content"]
    page = paragraph["meta"]["page_number"]
    pdf_href = f"{meta_doc['url']}#page={page}"
    return f"""
    <div class="card" id="document-{i}">
        <div class="card-content">
            <h2>Excerpts {i} - Document {meta_doc['num_doc']} - Page {page}</h2>
            <p>{text}</p>
        </div>
        <div class="card-footer">
            <span>{meta_doc['short_name']}</span>
            <a href="{pdf_href}" target="_blank" class="pdf-link">
                <span role="img" aria-label="Open PDF">π</span>
            </a>
        </div>
    </div>
    """
|
|
|
def make_citations_source(citation_dic, query, Hyde: bool = False):
    """Render the "Sources" HTML header listing every cited document.

    Args:
        citation_dic: maps document short_name -> [doc number,
            comma-separated excerpt ids as a string].
        query: the query string that was used for retrieval.
        Hyde: True when the query came from the HyDE fallback. Bug fix:
            the original signature was ``Hyde: False`` — ``False`` used as
            a type *annotation*, not a default — which made the parameter
            effectively required; it now defaults to False.

    Returns:
        str: an HTML <div class="source"> block.
    """
    citation_list = [
        f'Doc {values[0]} - {keys} (excerpts {values[1]})'
        for keys, values in citation_dic.items()
    ]

    html_output = '<div class="source">\n'
    html_output += '    <div class="title">Sources</div>\n'
    if Hyde:
        html_output += f'    <div>Query used for retrieval (with the HyDE technique after no response): {query}</div>\n'
    else:
        html_output += f'    <div>Query used for retrieval: {query}</div>\n'
    html_output += '    <br>\n'
    html_output += '    <ul>\n'

    for row in citation_list:
        html_output += f'<li>{row}</li>'

    html_output += '    </ul>\n'
    html_output += '</div>\n'

    return html_output
|
|
|
|
|
def preprocess_message(text: str, docs_url: dict) -> str:
    """Replace every "[doc N]" marker in *text* with an HTML link.

    Args:
        text: the model's (partial) answer text.
        docs_url: maps a doc number to the link target for that document.
            NOTE(review): the caller keys this dict with int indices while
            the regex group here is a str — with a plain dict that would
            raise KeyError; with the caller's defaultdict(str) it silently
            yields "". Confirm intended.

    Returns:
        str: text with each marker wrapped in an <a class="pdf-link"> tag.
    """
    def _linkify(match):
        target = docs_url[match.group(1)]
        return f'<a href="{target}" target="_blank" class="pdf-link">{match.group(0)}</a>'

    return re.sub(r"\[doc (\d+)\]", _linkify, text)
|
|
|
|
|
def parse_glossary(query):
    """Expand glossary terms found in *query* with their definitions.

    Each whitespace-separated word that case-insensitively matches a key of
    glossary.json is rewritten as "word (definition)"; other words are left
    untouched.

    Args:
        query (str): the raw user query.

    Returns:
        str: the query with inline glossary expansions.
    """
    file = "glossary.json"
    # Bug fix: the original left the file handle open (json.load(open(...)));
    # use a context manager so it is always closed.
    with open(file, "r") as f:
        glossary = json.load(f)
    words_query = query.split(" ")
    for i, word in enumerate(words_query):
        for key in glossary.keys():
            if word.lower() == key.lower():
                words_query[i] = words_query[i] + f" ({glossary[key]})"
    return " ".join(words_query)
|
|
|
|
|
def num_tokens_from_string(string: str, encoding_name: str) -> int:
    """Count how many tokens *string* occupies for a given OpenAI model.

    Args:
        string: the text to tokenize.
        encoding_name: despite the name, this is an OpenAI *model* name —
            it is passed to tiktoken.encoding_for_model.

    Returns:
        int: the token count under that model's tokenizer.
    """
    tokenizer = tiktoken.encoding_for_model(encoding_name)
    return len(tokenizer.encode(string))
|
|
|
|
|
def chat(
    query: str,
    history: list,
    threshold: float = CFG_APP.THRESHOLD,
    k_total: int = CFG_APP.K_TOTAL,
) -> tuple:
    """Retrieve relevant passages from the FAISS store, then stream an
    answer from the OpenAI chat model.

    Flow: (1) reformulate the user query into a standalone question (the
    prompt also asks for the answer language); (2) retrieve passages above
    the similarity threshold; (3) if nothing is found, retry retrieval with
    a HyDE-style reformulation; (4) stream the model's answer, rewriting
    "[doc N]" markers into HTML links. The two retrieval branches are
    near-duplicates, differing only in the prompt used and Hyde flag.

    Args:
        query (str): user message.
        history (list): conversation history in OpenAI message format
            (first element is presumably the system prompt — its content is
            excluded from the gradio pairs below).
        threshold (float, optional): similarity threshold, don't increase
            more than 0.568. Defaults to CFG_APP.THRESHOLD.
        k_total (int, optional): max passages to retrieve. Defaults to
            CFG_APP.K_TOTAL.

    Yields:
        tuple: (chat in gradio pair format, chat in OpenAI message format,
        sources HTML — or a plain warning string when nothing was found).
    """

    # Step 1: rewrite the query as a standalone question; glossary terms
    # are expanded inline first.
    reformulated_query = openai.ChatCompletion.create(
        model=CFG_APP.MODEL_NAME,
        messages=get_reformulation_prompt(parse_glossary(query)),
        temperature=0,
        max_tokens=CFG_APP.MAX_TOKENS_REF_QUESTION,
    )

    reformulated_query = reformulated_query["choices"][0]["message"]["content"]

    # Expected model output: "<question>\n...: <language>"; fall back to
    # English (and the first line only) when it does not follow the format.
    if len(reformulated_query.split("\n")) == 2:
        reformulated_query, language = reformulated_query.split("\n")
        language = language.split(":")[1].strip()
    else:
        reformulated_query = reformulated_query.split("\n")[0]
        language = "English"

    sources, scores = text_embedder.retrieve_faiss(
        reformulated_query,
        k_total=k_total,
        threshold=threshold,
    )

    if CFG_APP.DEBUG == True:
        print("Scores : \n", scores)

    messages = history + [{"role": "user", "content": query}]

    # Maps excerpt index -> opening <a> tag for that excerpt's PDF page.
    # NOTE(review): keys are ints here but preprocess_message() looks them
    # up with str keys; being a defaultdict(str) the lookup silently
    # yields "" instead of the URL — confirm intended.
    docs_url = defaultdict(str)

    if len(sources) > 0:
        docs_string = []
        docs_html = []
        citations = {}

        # Token budget: start from the sources-prompt cost and add passages
        # until MAX_TOKENS_API would be exceeded.
        num_tokens = num_tokens_from_string(CFG_APP.SOURCES_PROMPT, CFG_APP.MODEL_NAME)
        num_doc = 1

        for i, data in enumerate(sources, 1):
            meta_doc = retrieve_doc_metadata(doc_metadata, data["meta"]["document_id"])
            doc_content = f"π Doc {i}: \n{data['content']}"
            num_tokens_doc = num_tokens_from_string(doc_content, CFG_APP.MODEL_NAME)
            if num_tokens + num_tokens_doc > CFG_APP.MAX_TOKENS_API:
                break
            num_tokens += num_tokens_doc
            docs_string.append(doc_content)

            # Group excerpts by document: citations maps short_name ->
            # [document number, comma-separated excerpt ids].
            if meta_doc['short_name'] in citations.keys():
                citations[meta_doc['short_name']][1] += f', {i}'
            else:
                citations[meta_doc['short_name']] = [num_doc, f'{i}']
                num_doc += 1

            # Mutates the shared metadata entry so make_html_source can
            # show the per-conversation document number.
            meta_doc["num_doc"] = citations[meta_doc['short_name']][0]

            docs_html.append(make_html_source(data, meta_doc, i))

            url_doc = f'<a href="{meta_doc["url"]}#page={data["meta"]["page_number"]}" target="_blank" class="pdf-link">'
            docs_url[i] = url_doc

        html_cit = [make_citations_source(citations, reformulated_query, Hyde=False)]

        docs_string = "\n\n".join([f"Query used for retrieval:\n{reformulated_query}"] + docs_string)

        docs_html = "\n\n".join(html_cit + docs_html)

        # Temporarily append the sources as a system message; it is popped
        # before streaming and replaced by the assistant's answer.
        messages.append(
            {
                "role": "system",
                "content": f"{CFG_APP.SOURCES_PROMPT}\n\n{docs_string}\n\nAnswer in {language}:",
            }
        )

        if CFG_APP.DEBUG == True:
            print(f" π¨βπ» question asked by the user : {query}")
            print(f" π time : {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

            print(" π messages sent to the API :")
            # NOTE(review): this debug copy differs from the actual payload
            # below ("Answer in {language}:" vs "Very important : Answer
            # in {language}." placement) — confirm which is intended.
            api_messages = [
                {"role": "system", "content": CFG_APP.INIT_PROMPT},
                {"role": "user", "content": reformulated_query},
                {
                    "role": "system",
                    "content": f"{CFG_APP.SOURCES_PROMPT}\n\n{docs_string}\n\nAnswer in {language}:",
                },
            ]
            for message in api_messages:
                print(
                    f"length : {len(message['content'])}, content : {message['content']}"
                )

        # Step 4: stream the answer.
        response = openai.ChatCompletion.create(
            model=CFG_APP.MODEL_NAME,
            messages=[
                {"role": "system", "content": CFG_APP.INIT_PROMPT},
                {"role": "user", "content": reformulated_query},
                {
                    "role": "system",
                    "content": f"{CFG_APP.SOURCES_PROMPT}\n\nVery important : Answer in {language}.\n\n{docs_string}:",
                },
            ],
            temperature=0,
            stream=True,
            max_tokens=CFG_APP.MAX_TOKENS_ANSWER,
        )
        complete_response = ""
        # Swap the system "sources" message for a growing assistant message.
        messages.pop()
        messages.append({"role": "assistant", "content": complete_response})
        for chunk in response:
            chunk_message = chunk["choices"][0]["delta"].get("content")
            if chunk_message:
                complete_response += chunk_message
                # Re-run link rewriting on the whole accumulated text each
                # chunk, since a "[doc N]" marker may span chunk boundaries.
                complete_response = preprocess_message(complete_response, docs_url)
                messages[-1]["content"] = complete_response
                # Drop messages[0] (system prompt) from the gradio view.
                gradio_format = make_pairs([a["content"] for a in messages[1:]])
                yield gradio_format, messages, docs_html

    else:
        # Step 3 fallback: nothing retrieved — build a HyDE query (a
        # hypothetical answer) and retry retrieval with it.
        reformulated_query = openai.ChatCompletion.create(
            model=CFG_APP.MODEL_NAME,
            messages=get_hyde_prompt(parse_glossary(query)),
            temperature=0,
            max_tokens=CFG_APP.MAX_TOKENS_REF_QUESTION,
        )

        reformulated_query = reformulated_query["choices"][0]["message"]["content"]

        # Same "<question>\n...: <language>" parsing as above.
        if len(reformulated_query.split("\n")) == 2:
            reformulated_query, language = reformulated_query.split("\n")
            language = language.split(":")[1].strip()
        else:
            reformulated_query = reformulated_query.split("\n")[0]
            language = "English"

        sources, scores = text_embedder.retrieve_faiss(
            reformulated_query,
            k_total=k_total,
            threshold=threshold,
        )

        if CFG_APP.DEBUG == True:
            print("Scores : \n", scores)

        if len(sources) > 0:
            # This branch duplicates the non-HyDE branch above, differing
            # only in Hyde=True for the citations header and the debug text.
            docs_string = []
            docs_html = []
            citations = {}

            num_tokens = num_tokens_from_string(CFG_APP.SOURCES_PROMPT, CFG_APP.MODEL_NAME)

            num_doc = 1

            for i, data in enumerate(sources, 1):
                meta_doc = retrieve_doc_metadata(doc_metadata, data["meta"]["document_id"])
                doc_content = f"π Doc {i}: \n{data['content']}"
                num_tokens_doc = num_tokens_from_string(doc_content, CFG_APP.MODEL_NAME)
                if num_tokens + num_tokens_doc > CFG_APP.MAX_TOKENS_API:
                    break
                num_tokens += num_tokens_doc
                docs_string.append(doc_content)

                if meta_doc['short_name'] in citations.keys():
                    citations[meta_doc['short_name']][1] += f', {i}'
                else:
                    citations[meta_doc['short_name']] = [num_doc, f'{i}']
                    num_doc += 1

                meta_doc["num_doc"] = citations[meta_doc['short_name']][0]

                docs_html.append(make_html_source(data, meta_doc, i))

                url_doc = f'<a href="{meta_doc["url"]}#page={data["meta"]["page_number"]}" target="_blank" class="pdf-link">'
                docs_url[i] = url_doc

            html_cit = [make_citations_source(citations, reformulated_query, Hyde=True)]

            docs_string = "\n\n".join([f"Query used for retrieval:\n{reformulated_query}"] + docs_string)

            docs_html = "\n\n".join(html_cit + docs_html)

            messages.append(
                {
                    "role": "system",
                    "content": f"{CFG_APP.SOURCES_PROMPT}\n\n{docs_string}\n\nAnswer in {language}:",
                }
            )

            if CFG_APP.DEBUG == True:
                print(f" π¨βπ» question asked by the user : {query}")
                print(f" π time : {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

                print(" π messages sent to the API :")
                api_messages = [
                    {"role": "system", "content": CFG_APP.INIT_PROMPT},
                    {"role": "user", "content": reformulated_query},
                    {
                        "role": "system",
                        "content": f"{CFG_APP.SOURCES_PROMPT}\n\nVery important : Answer in {language}.\n\n{docs_string}:",
                    },
                ]
                for message in api_messages:
                    print(
                        f"length : {len(message['content'])}, content : {message['content']}"
                    )

            response = openai.ChatCompletion.create(
                model=CFG_APP.MODEL_NAME,
                messages=[
                    {"role": "system", "content": CFG_APP.INIT_PROMPT},
                    {"role": "user", "content": reformulated_query},
                    {
                        "role": "system",
                        "content": f"{CFG_APP.SOURCES_PROMPT}\n\nVery important : Answer in {language}.\n\n{docs_string}:",
                    },
                ],
                temperature=0,
                stream=True,
                max_tokens=CFG_APP.MAX_TOKENS_ANSWER,
            )
            complete_response = ""
            messages.pop()
            messages.append({"role": "assistant", "content": complete_response})
            for chunk in response:
                chunk_message = chunk["choices"][0]["delta"].get("content")
                if chunk_message:
                    complete_response += chunk_message
                    complete_response = preprocess_message(complete_response, docs_url)
                    messages[-1]["content"] = complete_response
                    gradio_format = make_pairs([a["content"] for a in messages[1:]])
                    yield gradio_format, messages, docs_html

        else:
            # Even HyDE found nothing: yield a static warning instead of
            # calling the answer model.
            docs_string = "β οΈ No relevant passages found in this report"
            complete_response = "**β οΈ No relevant passages found in this report, you may want to ask a more specific question.**"
            messages.append({"role": "assistant", "content": complete_response})
            gradio_format = make_pairs([a["content"] for a in messages[1:]])
            yield gradio_format, messages, docs_string
|
|
|
|
|
|