# File size: 13,151 Bytes
# 78d71d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
import re
import os
from typing import TypeVar, List
import pandas as pd


# Model packages
import torch.cuda

# Alternative model sources
#from dataclasses import asdict, dataclass

# Langchain functions
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.docstore.document import Document

# For keyword extraction (not currently used)
#import nltk
#nltk.download('wordnet')
from nltk.corpus import stopwords
from nltk.tokenize import RegexpTokenizer
from nltk.stem import WordNetLemmatizer

# For Name Entity Recognition model
#from span_marker import SpanMarkerModel # Not currently used


import gradio as gr

# Ask PyTorch to release any unoccupied cached GPU memory as soon as this module is imported
torch.cuda.empty_cache()

# Type alias intended for annotating pandas DataFrame arguments and return values
PandasDataFrame = TypeVar('pd.core.frame.DataFrame')

embeddings = None  # Module-level placeholder, starts as None
vectorstore = None # Module-level placeholder, starts as None
model_type = None # Module-level placeholder, starts as None

max_memory_length = 0 # How long should the memory of the conversation last?

full_text = "" # Dummy source text (full text), defined only so that the highlight function can load

model = [] # Empty list placeholder so that the model functions can run
tokenizer = [] # Empty list placeholder so that the model functions can run

## Highlight text constants (used as the default arguments of highlight_found_text)
hlt_chunk_size = 12 # Chunk size passed to the text splitter
hlt_strat = [" ", ". ", "! ", "? ", ": ", "\n\n", "\n", ", "] # Separators passed to the text splitter
hlt_overlap = 4 # Chunk overlap passed to the text splitter

## Initialise NER model ##
# Left as an empty list; the SpanMarker model load below is disabled
ner_model = []#SpanMarkerModel.from_pretrained("tomaarsen/span-marker-mbert-base-multinerd") # Not currently used


# gpu_layers is kept at 0 on every device: even when CUDA is available, GPU
# offloading stays disabled because of persistent bugs in the CUDA implementation.
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
gpu_layers = 0
print("Running on device:", torch_device)

# Fixed CPU thread count (torch.get_num_threads() is the disabled alternative)
threads = 6
print("CPU threads:", threads)

# Vectorstore funcs

# Prompt functions

def write_out_metadata_as_string(metadata_in):
    """
    Render each metadata dict as one 'key: value' string.

    Parameters:
        metadata_in: Iterable of metadata dicts.

    Returns:
        list: One string per dict, made of 'key: value' pairs joined by two
        spaces. The 'page_section' key is left out.
    """
    metadata_strings = []
    for entry in metadata_in:
        pairs = [f'{key}: {value}' for key, value in entry.items() if key != 'page_section']
        metadata_strings.append('  '.join(pairs))
    return metadata_strings


def determine_file_type(file_path):
    """
    Return the lower-cased extension of a file path.

    Parameters:
        file_path (str): Path to the file.

    Returns:
        str: The extension including its leading dot (e.g. '.pdf', '.docx',
        '.txt', '.html'), or an empty string when the path has no extension.
    """
    _, extension = os.path.splitext(file_path)
    return extension.lower()


def create_doc_df(docs_keep_out):
    """
    Build a DataFrame describing the 'winning' passages.

    Parameters:
        docs_keep_out: Iterable of (document, score) pairs. Each document must
            expose `page_content` and a `metadata` dict containing 'source'.
            Unless the source is a .csv or .xlsx file, the metadata must also
            contain 'page_section'.

    Returns:
        pd.DataFrame: One row per passage with the columns 'page_content',
        'metadata', 'page_section', 'meta_url', 'score' and 'full_url'
        ('https://' prepended to 'meta_url'). 'page_section' is an empty
        string for .csv/.xlsx sources.

    Raises:
        KeyError: If 'source' is missing, or 'page_section' is missing for a
            source that is not .csv/.xlsx.
    """
    # Tabular sources carry no page section, so the column is left blank for them
    tabular_extensions = (".csv", ".xlsx")

    rows = []
    for item in docs_keep_out:
        doc, score = item[0], item[1]
        source = doc.metadata['source']

        if determine_file_type(source) in tabular_extensions:
            section = ""
        else:
            section = doc.metadata['page_section']

        rows.append((doc.page_content, doc.metadata, section, source, score))

    # Create df from 'winning' passages
    doc_df = pd.DataFrame(
        rows,
        columns=['page_content', 'metadata', 'page_section', 'meta_url', 'score'],
    )
    doc_df['full_url'] = "https://" + doc_df['meta_url']

    return doc_df


def get_expanded_passages(vectorstore, docs, width):

    """
    Extracts expanded passages based on given documents and a width for context.

    For every retrieved passage, the passages stored in the vectorstore under
    the same 'source' are searched for the first one with the same
    'page_section'. That passage and up to `width` passages stored after it are
    concatenated into one expanded passage whose metadata records the range
    covered ("first to last" for every key except 'source').

    A retrieved passage is returned unchanged (no expansion) when it has no
    'page_section' metadata, or when its section cannot be found among the
    stored passages of its source.

    Parameters:
    - vectorstore: The primary data source. Its passages are read from
      `vectorstore.docstore._dict`.
    - docs: List of (document, score) pairs to be expanded.
    - width: Number of following passages to append to each found passage.
      Negative values are treated as 0.

    Returns:
    - expanded_docs: List of (Document, score) pairs, in the order of `docs`.
    - doc_df: DataFrame representation of expanded_docs (built by create_doc_df).
    """

    from collections import defaultdict

    # Marks stored passages without a 'page_section' so they can never match a real section
    missing = object()

    def merge_dicts_except_source(d1, d2):
        # 'source' is shared by all merged passages; every other field becomes a
        # "first to last" range. Keys absent from d2 fall back to d1's own value.
        merged = {}
        for key, value in d1.items():
            if key == "source":
                merged[key] = value
            else:
                merged[key] = str(value) + " to " + str(d2.get(key, value))
        return merged

    # Step 1: Keep only the stored passages that share a source with a retrieved
    # passage, grouped by source in docstore order
    doc_sources = {doc.metadata['source'] for doc, _ in docs}
    vstore_by_source = defaultdict(list)
    for stored_doc in vectorstore.docstore._dict.values():
        stored_source = stored_doc.metadata.get('source')
        if stored_source in doc_sources:
            vstore_by_source[stored_source].append(stored_doc)

    # Step 2: Expand each retrieved passage with the passages that follow it
    extra_passages = max(width, 0)
    expanded_docs = []
    for doc, score in docs:
        siblings = vstore_by_source[doc.metadata['source']]
        sibling_sections = [sibling.metadata.get('page_section', missing) for sibling in siblings]
        search_section = doc.metadata.get('page_section', missing)

        if search_section is missing or search_section not in sibling_sections:
            # Nothing to anchor the expansion on (e.g. csv/xlsx rows): keep the passage as found
            expanded_docs.append((doc, score))
            continue

        # Only passages AFTER the found passage are added as extra context
        start = sibling_sections.index(search_section)
        window = siblings[start:start + extra_passages + 1]

        content_str = ''.join(passage.page_content for passage in window)
        meta_full = merge_dicts_except_source(window[0].metadata, window[-1].metadata)

        expanded_docs.append((Document(page_content=content_str, metadata=meta_full), score))

    doc_df = create_doc_df(expanded_docs)

    return expanded_docs, doc_df

def highlight_found_text(search_text: str, full_text: str, hlt_chunk_size:int=hlt_chunk_size, hlt_strat:List=hlt_strat, hlt_overlap:int=hlt_overlap) -> str:
    """
    Highlights the parts of full_text that also occur in search_text.

    search_text is split into small overlapping chunks, every occurrence of each
    chunk in full_text is located, matches lying within 10 characters of each
    other are merged, and each merged span longer than 15 characters is wrapped
    in '<mark style="color:black;">...</mark>'. Shorter spans are left unmarked,
    so an isolated short word such as "world" is NOT highlighted on its own.

    Parameters:
    - search_text (str or list): The text to be searched for within full_text.
      For a list, the second element of its last item is used.
    - full_text (str or list): The text within which matches will be highlighted.
      For a list, the first element of its first item is used.
    - hlt_chunk_size (int): Chunk size given to the text splitter.
    - hlt_strat (List): Separators given to the text splitter.
    - hlt_overlap (int): Chunk overlap given to the text splitter.

    Returns:
    - str: full_text (double spaces collapsed once, stripped) with the matched
      spans wrapped in <mark> tags. Inputs that are neither str nor list, empty
      lists, and non-string list entries are treated as empty text.
    """

    merge_gap = 10        # Matches at most this many characters apart are merged into one span
    min_span_length = 15  # Spans must be longer than this to be highlighted; avoids marking single words like 'and'

    def clean(value):
        # Non-string values (e.g. a None placeholder in a chat history) contribute no text
        if not isinstance(value, str):
            return ""
        return value.replace("  ", " ").strip()

    def extract_text_from_input(text, i=0):
        if isinstance(text, list):
            return clean(text[i][0]) if text else ""
        return clean(text)

    def extract_search_text_from_input(text):
        if isinstance(text, list):
            return clean(text[-1][1]) if text else ""
        return clean(text)

    full_text = extract_text_from_input(full_text)
    search_text = extract_search_text_from_input(search_text)

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=hlt_chunk_size,
        separators=hlt_strat,
        chunk_overlap=hlt_overlap,
    )
    sections = text_splitter.split_text(search_text)

    # Map every match start to the furthest match end seen at that start
    found_positions = {}
    for section in sections:
        if not section:
            # An empty chunk would match at every index of full_text
            continue
        text_start_pos = full_text.find(section)
        while text_start_pos != -1:
            text_end_pos = text_start_pos + len(section)
            # Keep the longest match so a shorter chunk cannot truncate an earlier, longer one
            found_positions[text_start_pos] = max(found_positions.get(text_start_pos, text_end_pos), text_end_pos)
            text_start_pos = full_text.find(section, text_start_pos + 1)

    # Combine overlapping or adjacent positions
    combined_positions = []
    sorted_starts = sorted(found_positions)
    if sorted_starts:
        current_start = sorted_starts[0]
        current_end = found_positions[current_start]
        for start in sorted_starts[1:]:
            if start <= (current_end + merge_gap):
                current_end = max(current_end, found_positions[start])
            else:
                combined_positions.append((current_start, current_end))
                current_start, current_end = start, found_positions[start]
        combined_positions.append((current_start, current_end))

    # Construct pos_tokens: untouched text interleaved with highlighted spans
    pos_tokens = []
    prev_end = 0
    for start, end in combined_positions:
        if end - start > min_span_length:
            pos_tokens.append(full_text[prev_end:start])
            pos_tokens.append('<mark style="color:black;">' + full_text[start:end] + '</mark>')
            prev_end = end
    pos_tokens.append(full_text[prev_end:])

    return "".join(pos_tokens)


# Chat history functions

def clear_chat(chat_history_state, sources, chat_message, current_topic):
    """
    Return blank values for the chat history, sources, message and topic.

    The incoming arguments are ignored; they only fix the number and order of
    the values returned.

    Returns:
        tuple: (empty list, '', '', '') for chat_history_state, sources,
        chat_message and current_topic respectively.
    """
    return [], '', '', ''


# Keyword functions

def remove_q_stopwords(question): # Remove stopwords from question. Not used at the moment
    """
    Reduce a question to its distinct non-stopword keywords.

    The question is lower-cased, digits are removed, the text is split into
    word tokens, English stopwords are dropped, and repeated words are removed
    while keeping the order of first appearance.

    Parameters:
        question (str): The question text.

    Returns:
        str: The remaining keywords joined by single spaces.

    Raises:
        LookupError: Raised by NLTK if the 'stopwords' corpus is not installed.
    """
    # Prepare keywords from question: lower-case, then remove numbers
    text = re.sub(r'[0-9]', '', question.lower())

    text_tokens = RegexpTokenizer(r'\w+').tokenize(text)

    # `stopwords` is the NLTK corpus reader, not a word list: fetch the English
    # words explicitly, and hold them in a set for fast membership tests
    english_stopwords = set(stopwords.words('english'))
    tokens_without_sw = [word for word in text_tokens if word not in english_stopwords]

    # Remove duplicate words while preserving order
    new_question_keywords = ' '.join(dict.fromkeys(tokens_without_sw))
    return new_question_keywords

def remove_q_ner_extractor(question):
    """
    Reduce a question to the entity spans predicted by the module-level NER model.

    Calls `ner_model.predict(question)`, takes the 'span' value of each
    prediction (an empty string when a prediction has no 'span' key), drops
    repeated spans while keeping first-seen order, and returns them joined by
    spaces in lower case.

    NOTE(review): ner_model is initialised as an empty list in this file, so
    this presumably only works once a real model has been assigned - confirm.
    """
    predictions = ner_model.predict(question)

    spans = []
    for prediction in predictions:
        span_values = [value for key, value in prediction.items() if key == 'span']
        spans.append(' '.join(span_values))

    # dict keys keep insertion order, which removes duplicates but preserves order
    unique_spans = list(dict.fromkeys(spans))

    return ' '.join(unique_spans).lower()

def apply_lemmatize(text, wnl=WordNetLemmatizer()):
    """
    Tokenise text and lemmatise its longer words.

    Digits are removed from the text (the result is printed), the remainder is
    split into word tokens, and every token longer than three characters is
    passed through `wnl.lemmatize`; shorter tokens are kept as they are.

    Parameters:
        text (str): The text to process.
        wnl: Object providing `lemmatize(word)`; defaults to a WordNetLemmatizer
            created once, when the function is defined.

    Returns:
        list: The processed tokens, in their original order.
    """
    # Remove numbers
    digits_removed = re.sub('[0-9]', '', text)
    print(digits_removed)

    word_tokens = RegexpTokenizer(r'\w+').tokenize(digits_removed)

    # Words of three characters or fewer are left untouched
    return [wnl.lemmatize(token) if len(token) > 3 else token for token in word_tokens]

def keybert_keywords(text, n, kw_model):
    """
    Extract the top single-word keywords from text with KeyBERT.

    The text is lemmatised with apply_lemmatize, re-joined with spaces, and
    passed to `KeyBERT(model=kw_model).extract_keywords` with English stop
    words, `top_n=n` and a keyphrase n-gram range of (1, 1).

    Parameters:
        text (str): The text to extract keywords from.
        n (int): Number of keywords requested.
        kw_model: Model handed to the KeyBERT constructor.

    Returns:
        list: The first element of each item returned by extract_keywords.

    NOTE(review): KeyBERT is not imported in the visible part of this file -
    confirm it is in scope before using this function.
    """
    lemmatised_text = ' '.join(apply_lemmatize(text))

    extractor = KeyBERT(model=kw_model)
    scored_keywords = extractor.extract_keywords(
        lemmatised_text,
        stop_words='english',
        top_n=n,
        keyphrase_ngram_range=(1, 1),
    )

    return [scored_keyword[0] for scored_keyword in scored_keywords]
    
# Gradio functions
def turn_off_interactivity(user_message, history):
    """
    Clear and lock the input component, and add the new message to the history.

    Returns:
        tuple: A Gradio update setting value="" and interactive=False, and a new
        history list with [user_message, None] appended (the passed-in history
        is not modified).
    """
    locked_input = gr.update(value="", interactive=False)
    extended_history = history + [[user_message, None]]
    return locked_input, extended_history

def restore_interactivity():
    """Return a Gradio update that makes the target component interactive again."""
    unlocked_input = gr.update(interactive=True)
    return unlocked_input

def update_message(dropdown_value):
    """
    Return a Gradio update that sets the target component's value.

    Uses the component-agnostic `gr.update`, matching the other Gradio helpers
    in this file, instead of the component-level `gr.Textbox.update`
    classmethod, which is no longer available from Gradio 4 onwards.

    Parameters:
        dropdown_value: The value to place in the target component.
    """
    return gr.update(value=dropdown_value)

def hide_block():
    """
    Return a Gradio update that hides the target component.

    Uses the component-agnostic `gr.update`, matching the other Gradio helpers
    in this file, instead of the component-level `gr.Radio.update` classmethod,
    which is no longer available from Gradio 4 onwards.
    """
    return gr.update(visible=False)