import os
import re
import warnings

import PyPDF2
import pandas as pd
import torch
import gradio as gr
from typing import Union
from datasets import Dataset
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer

warnings.filterwarnings("ignore")



ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")


q_encoder = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
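# Note: both DPR encoders are BERT-based torch modules and run on CPU by default;
# moving them to a GPU with .to("cuda"), when one is available, would speed up
# embedding large documents considerably.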



def process_pdfs(parent_dir: Union[str, list]):
    """processes the PDF files and returns a DataFrame with one row per page;
    pages longer than 512 characters are split into 512-character chunks"""
    df = pd.DataFrame(columns=["title", "text"])
    if isinstance(parent_dir, str):
        parent_dir = [parent_dir]
    for file_path in parent_dir:
        if not file_path.endswith(".pdf"):  # only PDF files are supported
            raise Exception("only pdf files are supported")
        file_name = file_path.split("/")[-1]
        # the context manager closes the file once all pages are read
        with open(file_path, 'rb') as pdf_file:
            pdf_reader = PyPDF2.PdfReader(pdf_file)
            for i, page in enumerate(pdf_reader.pages):
                # extract the page text and collapse the whitespace
                txt = page.extract_text()
                txt = txt.replace("\n", "")  # strip newlines
                txt = txt.replace("\t", "")  # strip tabs
                txt = re.sub(r" +", " ", txt)  # strip extra spaces
                # 512 matches the maximum sequence length of the
                # "facebook/dpr-ctx_encoder-single-nq-base" model; chunking by
                # 512 characters keeps every chunk safely under that token limit
                while len(txt) > 512:
                    new_data = pd.DataFrame([[f"{file_name}-page-{i}", txt[:512]]], columns=["title", "text"])
                    df = pd.concat([df, new_data], ignore_index=True)
                    txt = txt[512:]
                # keep the remainder (or the whole page if it was short enough)
                new_data = pd.DataFrame([[f"{file_name}-page-{i}", txt]], columns=["title", "text"])
                df = pd.concat([df, new_data], ignore_index=True)
    return df
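
# Example usage (hypothetical path): build the page/chunk table for one file
# df = process_pdfs("docs/example.pdf")
# print(df.head())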

def process(example):
    """embeds a single example of the dataset and returns the embedding"""
    try:
        # truncation guards against chunks that exceed the 512-token limit
        tokens = ctx_tokenizer(example["text"], truncation=True, return_tensors="pt")
        embed = ctx_encoder(**tokens)[0][0].detach().numpy()
        return {'embeddings': embed}
    except Exception as e:
        raise Exception(f"error in process: {e}")

def process_dataset(df):
    """converts the DataFrame to a Dataset, embeds each row, and adds a FAISS index"""
    if len(df) == 0:
        raise Exception("empty pdf files, or can't read text from them")
    ds = Dataset.from_pandas(df)
    ds = ds.map(process)  # compute an embedding for every row
    ds.add_faiss_index(column='embeddings')  # index the embeddings for similarity search
    return ds
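
# Note: add_faiss_index builds an in-memory FAISS index over the 768-dimensional
# DPR embeddings; get_nearest_examples in search() queries it directly.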

def search(query, ds, k=3):
    """embeds the query and returns the k most similar passages from the dataset"""
    try:
        tokens = q_tokenizer(query, return_tensors="pt")
        query_embed = q_encoder(**tokens)[0][0].detach().numpy()
        scores, retrieved_examples = ds.get_nearest_examples("embeddings", query_embed, k=k)
        out = (
            f"**title** : {retrieved_examples['title'][0]},\n"
            f"content: {retrieved_examples['text'][0]}\n\n\n"
            f"**similar resources:** {retrieved_examples['title']}"
        )
    except Exception as e:
        out = f"error in search: {e}"
    return out
    
def predict(query, file_paths, k=3):
    """builds an index over the uploaded PDFs and returns the k passages most similar to the query"""
    try:
        df = process_pdfs(file_paths)
        ds = process_dataset(df)
        out = search(query, ds, k=k)
    except Exception as e:
        out = f"error in predict: {e}"
    return out
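
# Example (hypothetical file): predict("what is dense passage retrieval?", ["docs/dpr.pdf"], k=3)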

with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center'> PDF Search Engine </h1>")
    with gr.Row():
        with gr.Column():
            files = gr.Files(label="Upload PDFs", type="filepath", file_count="multiple")
            query = gr.Text(label="query")
            with gr.Accordion("number of references", open=False):
                k = gr.Number(value=3, show_label=False, precision=0, minimum=1, container=False)
            button = gr.Button("search")
        with gr.Column():
            output = gr.Markdown(label="output")
    button.click(predict, [query, files, k], outputs=output)

demo.launch()
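# demo.launch(share=True) would additionally create a temporary public link.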