import re
import os
from typing import TypeVar, List

import pandas as pd
import torch.cuda
import gradio as gr

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.docstore.document import Document

# KeyBERT is used in keybert_keywords() below
from keybert import KeyBERT

from nltk.corpus import stopwords
from nltk.tokenize import RegexpTokenizer
from nltk.stem import WordNetLemmatizer

torch.cuda.empty_cache()

PandasDataFrame = TypeVar('pd.core.frame.DataFrame')
# Module-level placeholders, populated elsewhere in the app
embeddings = None
vectorstore = None
model_type = None

max_memory_length = 0

full_text = ""

model = []
tokenizer = []

# Highlighting parameters: the search text is split into small character chunks
# with overlap so that matches survive minor formatting differences
hlt_chunk_size = 12
hlt_strat = [" ", ". ", "! ", "? ", ": ", "\n\n", "\n", ", "]
hlt_overlap = 4

ner_model = []

if torch.cuda.is_available():
    torch_device = "cuda"
    gpu_layers = 0
else:
    torch_device = "cpu"
    gpu_layers = 0

print("Running on device:", torch_device)
threads = 6
print("CPU threads:", threads)


def write_out_metadata_as_string(metadata_in):
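    """Return a list where each metadata dict is flattened to a 'key: value' string, skipping 'page_section'."""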
    metadata_string = [' '.join(f'{k}: {v}' for k, v in d.items() if k != 'page_section') for d in metadata_in]
    return metadata_string


def determine_file_type(file_path):
    """
    Determine the file type based on its extension.

    Parameters:
    file_path (str): Path to the file.

    Returns:
    str: File extension (e.g., '.pdf', '.docx', '.txt', '.html').
    """
    return os.path.splitext(file_path)[1].lower()


def create_doc_df(docs_keep_out):
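    """Convert a list of (Document, score) tuples into a DataFrame of page content, metadata, page section, source URL and score."""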
    content = []
    meta = []
    meta_url = []
    page_section = []
    score = []

    for item in docs_keep_out:
        content.append(item[0].page_content)
        meta.append(item[0].metadata)
        meta_url.append(item[0].metadata['source'])

        file_extension = determine_file_type(item[0].metadata['source'])
        if file_extension not in (".csv", ".xlsx"):
            page_section.append(item[0].metadata['page_section'])
        else:
            page_section.append("")
        score.append(item[1])

    doc_df = pd.DataFrame(list(zip(content, meta, page_section, meta_url, score)),
                          columns=['page_content', 'metadata', 'page_section', 'meta_url', 'score'])

    doc_df['full_url'] = "https://" + doc_df['meta_url']

    return doc_df


def get_expanded_passages(vectorstore, docs, width):
    """
    Extracts expanded passages based on given documents and a width for context.

    Parameters:
    - vectorstore: The primary data source.
    - docs: List of documents to be expanded.
    - width: Number of additional documents after a matched document to include for context.

    Returns:
    - expanded_docs: List of expanded Document objects.
    - doc_df: DataFrame representation of expanded_docs.
    """

    from collections import defaultdict

    def get_docs_from_vstore(vectorstore):
        # Pull all (id, Document) pairs out of the vector store's docstore
        vector = vectorstore.docstore._dict
        return list(vector.items())

    def extract_details(docs_list):
        # Join the window of documents into one passage and keep the first and last metadata
        docs_list_out = [tup[1] for tup in docs_list]
        content = [doc.page_content for doc in docs_list_out]
        meta = [doc.metadata for doc in docs_list_out]
        return ''.join(content), meta[0], meta[-1]

    def get_parent_content_and_meta(vstore_docs, width, target):
        # Take the window from the target document up to `width` documents ahead
        target_range = range(max(0, target), min(len(vstore_docs), target + width + 1))
        parent_vstore_out = [vstore_docs[i] for i in target_range]

        # extract_details already summarises the whole window, so call it once
        content_str, meta_first, meta_last = extract_details(parent_vstore_out)
        return [content_str], [meta_first], [meta_last]

    def merge_dicts_except_source(d1, d2):
        # Combine first and last metadata values into "x to y" ranges, keeping the source unchanged
        merged = {}
        for key in d1:
            if key != "source":
                merged[key] = str(d1[key]) + " to " + str(d2[key])
            else:
                merged[key] = d1[key]
        return merged

    def merge_two_lists_of_dicts(list1, list2):
        return [merge_dicts_except_source(d1, d2) for d1, d2 in zip(list1, list2)]

    vstore_docs = get_docs_from_vstore(vectorstore)
    # Keep only the docstore entries that share a source with the search results
    doc_sources = {doc.metadata['source'] for doc, _ in docs}
    vstore_docs = [(k, v) for k, v in vstore_docs if v.metadata.get('source') in doc_sources]

    # Group the entries by source so passages are only expanded within the same file
    vstore_by_source = defaultdict(list)
    for k, v in vstore_docs:
        vstore_by_source[v.metadata['source']].append((k, v))

    expanded_docs = []
    for doc, score in docs:
        search_source = doc.metadata['source']

        # Locate the document's section within its source; fall back to -1 (expand from the start) if not found
        search_section = doc.metadata['page_section']
        parent_vstore_meta_section = [doc.metadata['page_section'] for _, doc in vstore_by_source[search_source]]
        search_index = parent_vstore_meta_section.index(search_section) if search_section in parent_vstore_meta_section else -1

        content_str, meta_first, meta_last = get_parent_content_and_meta(vstore_by_source[search_source], width, search_index)
        meta_full = merge_two_lists_of_dicts(meta_first, meta_last)

        expanded_doc = (Document(page_content=content_str[0], metadata=meta_full[0]), score)
        expanded_docs.append(expanded_doc)

    doc_df = create_doc_df(expanded_docs)

    return expanded_docs, doc_df


def highlight_found_text(search_text: str, full_text: str, hlt_chunk_size: int = hlt_chunk_size, hlt_strat: List = hlt_strat, hlt_overlap: int = hlt_overlap) -> str:
    """
    Highlights occurrences of search_text within full_text.

    Parameters:
    - search_text (str): The text to be searched for within full_text.
    - full_text (str): The text within which search_text occurrences will be highlighted.

    Returns:
    - str: A string with occurrences of search_text highlighted.

    Example:
    >>> highlight_found_text("world", "Hello, world! This is a test. Another world awaits.")
    'Hello, <mark style="color:black;">world</mark>! This is a test. Another <mark style="color:black;">world</mark> awaits.'
    """

    def extract_text_from_input(text, i=0):
        # Accept either a plain string or a Gradio chat history list; normalise non-breaking spaces
        if isinstance(text, str):
            return text.replace("\u00a0", " ").strip()
        elif isinstance(text, list):
            return text[i][0].replace("\u00a0", " ").strip()
        else:
            return ""

    def extract_search_text_from_input(text):
        # For chat history input, take the latest response as the text to search for
        if isinstance(text, str):
            return text.replace("\u00a0", " ").strip()
        elif isinstance(text, list):
            return text[-1][1].replace("\u00a0", " ").strip()
        else:
            return ""

    full_text = extract_text_from_input(full_text)
    search_text = extract_search_text_from_input(search_text)

    # Split the search text into small overlapping chunks so partial matches can still be located
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=hlt_chunk_size,
        separators=hlt_strat,
        chunk_overlap=hlt_overlap,
    )
    sections = text_splitter.split_text(search_text)

    # Record the start/end position of every chunk found in the full text
    found_positions = {}
    for x in sections:
        text_start_pos = 0
        while text_start_pos != -1:
            text_start_pos = full_text.find(x, text_start_pos)
            if text_start_pos != -1:
                found_positions[text_start_pos] = text_start_pos + len(x)
                text_start_pos += 1

    # Merge positions that are close together (within 10 characters) into single spans
    sorted_starts = sorted(found_positions.keys())
    combined_positions = []
    if sorted_starts:
        current_start, current_end = sorted_starts[0], found_positions[sorted_starts[0]]
        for start in sorted_starts[1:]:
            if start <= (current_end + 10):
                current_end = max(current_end, found_positions[start])
            else:
                combined_positions.append((current_start, current_end))
                current_start, current_end = start, found_positions[start]
        combined_positions.append((current_start, current_end))

    # Wrap spans longer than 15 characters in <mark> tags, leaving the rest of the text untouched
    pos_tokens = []
    prev_end = 0
    for start, end in combined_positions:
        if end - start > 15:
            pos_tokens.append(full_text[prev_end:start])
            pos_tokens.append('<mark style="color:black;">' + full_text[start:end] + '</mark>')
            prev_end = end
    pos_tokens.append(full_text[prev_end:])

    return "".join(pos_tokens)


def clear_chat(chat_history_state, sources, chat_message, current_topic):
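    """Reset the chat history, sources, message box and current topic, returning the cleared values."""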
    chat_history_state = []
    sources = ''
    chat_message = ''
    current_topic = ''

    return chat_history_state, sources, chat_message, current_topic


def remove_q_stopwords(question):
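    """Lower-case a question, strip digits and stopwords, and remove duplicate words while keeping their original order."""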
    text = question.lower()

    # Remove digits before tokenising
    text = re.sub('[0-9]', '', text)

    tokenizer = RegexpTokenizer(r'\w+')
    text_tokens = tokenizer.tokenize(text)

    # stopwords is the NLTK corpus reader, so build the actual English stopword set first
    stop_words = set(stopwords.words('english'))
    tokens_without_sw = [word for word in text_tokens if word not in stop_words]

    # Remove duplicate words while preserving their original order
    ordered_tokens = set()
    result = []
    for word in tokens_without_sw:
        if word not in ordered_tokens:
            ordered_tokens.add(word)
            result.append(word)

    new_question_keywords = ' '.join(result)
    return new_question_keywords


def remove_q_ner_extractor(question):
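    """Run the loaded NER model over the question and return the unique entity spans as lower-cased keywords."""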
    # ner_model is expected to be loaded elsewhere and assigned to the module-level variable above
    predict_out = ner_model.predict(question)

    # Keep only the recognised entity spans from each prediction
    predict_tokens = [' '.join(v for k, v in d.items() if k == 'span') for d in predict_out]

    # Remove duplicate spans while preserving their original order
    ordered_tokens = set()
    result = []
    for word in predict_tokens:
        if word not in ordered_tokens:
            ordered_tokens.add(word)
            result.append(word)

    new_question_keywords = ' '.join(result).lower()
    return new_question_keywords


def apply_lemmatize(text, wnl=WordNetLemmatizer()):
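    """Tokenise the text (dropping digits) and lemmatise every word longer than three characters."""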
    def prep_for_lemma(text):
        # Remove digits, then tokenise on word characters
        text = re.sub('[0-9]', '', text)

        tokenizer = RegexpTokenizer(r'\w+')
        text_tokens = tokenizer.tokenize(text)

        return text_tokens

    tokens = prep_for_lemma(text)

    def lem_word(word):
        # Only lemmatise longer words; short words are left unchanged
        if len(word) > 3:
            out_word = wnl.lemmatize(word)
        else:
            out_word = word
        return out_word

    return [lem_word(token) for token in tokens]


def keybert_keywords(text, n, kw_model):
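    """Lemmatise the text, then return the top n single-word keywords extracted with KeyBERT using the given model."""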
    tokens_lemma = apply_lemmatize(text)
    lemmatised_text = ' '.join(tokens_lemma)

    keywords_text = KeyBERT(model=kw_model).extract_keywords(lemmatised_text, stop_words='english', top_n=n,
                                                             keyphrase_ngram_range=(1, 1))
    keywords_list = [item[0] for item in keywords_text]

    return keywords_list


def turn_off_interactivity(user_message, history):
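    """Clear and disable the user input box and append the new user message to the chat history."""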
    return gr.update(value="", interactive=False), history + [[user_message, None]]


def restore_interactivity():
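    """Re-enable the user input box."""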
    return gr.update(interactive=True)


def update_message(dropdown_value):
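    """Set the message box text to the selected dropdown value."""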
    return gr.update(value=dropdown_value)


def hide_block():
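    """Hide the radio button block."""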
    return gr.update(visible=False)