import PyPDF2
import pandas as pd
import warnings
import re
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer
import torch
import gradio as gr
from typing import Union
from datasets import Dataset
warnings.filterwarnings("ignore")
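# DPR is a bi-encoder: the context encoder embeds passages and the question
# encoder embeds queries into the same 768-dimensional space, so relevance
# can be scored by vector similarity between the two embeddings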
ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
q_encoder = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
def process_pdfs(parent_dir: Union[str, list]):
    """Processes the given PDF file(s) and returns a dataframe with one row
    per <=512-character chunk of page text."""
    # one row per chunk: "title" identifies the source file/page, "text" holds the chunk
    df = pd.DataFrame(columns=["title", "text"])
    if isinstance(parent_dir, str):
        parent_dir = [parent_dir]
    for file_path in parent_dir:
        if not file_path.endswith(".pdf"):  # skip non pdf files
            raise ValueError("only pdf files are supported")
        # open the pdf file and build a reader over its pages
        pdfFileObj = open(file_path, 'rb')
        pdfReader = PyPDF2.PdfReader(pdfFileObj)
        # number of pages in the pdf file
        num_pages = len(pdfReader.pages)
        for i in range(num_pages):
            pageObj = pdfReader.pages[i]
            # extract the page text and normalise whitespace
            txt = pageObj.extract_text()
            txt = txt.replace("\n", "")    # strip newlines
            txt = txt.replace("\t", "")    # strip tabs
            txt = re.sub(r" +", " ", txt)  # collapse repeated spaces
            # chunk to 512 characters: a conservative proxy for the 512-token
            # input limit of "facebook/dpr-ctx_encoder-single-nq-base"
            file_name = file_path.split("/")[-1]
            # emit one row per 512-character chunk
            while len(txt) > 512:
                new_data = pd.DataFrame([[f"{file_name}-page-{i}", txt[:512]]], columns=["title", "text"])
                df = pd.concat([df, new_data], ignore_index=True)
                txt = txt[512:]
            if txt:  # keep the final (or only) chunk instead of dropping it
                new_data = pd.DataFrame([[f"{file_name}-page-{i}", txt]], columns=["title", "text"])
                df = pd.concat([df, new_data], ignore_index=True)
# closing the pdf file object
pdfFileObj.close()
return df
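# illustrative example (hypothetical file name):
#   process_pdfs("paper.pdf") -> DataFrame with rows like
#   ("paper.pdf-page-0", "<first 512 characters of page 0>")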
def process(example):
    """Processes one example of the dataset and returns its DPR embedding."""
    try:
        tokens = ctx_tokenizer(example["text"], truncation=True, max_length=512, return_tensors="pt")
        with torch.no_grad():
            embed = ctx_encoder(**tokens).pooler_output[0].numpy()
        return {'embeddings': embed}
    except Exception as e:
        raise Exception(f"error in process: {e}")
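# each mapped row gains an 'embeddings' field: a vector of shape (768,),
# the hidden size of the DPR encoders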
def process_dataset(df):
    """Processes the dataframe and returns a FAISS-indexed Dataset."""
    if len(df) == 0:
        raise Exception("empty pdf files, or can't read text from them")
ds = Dataset.from_pandas(df)
ds = ds.map(process)
ds.add_faiss_index(column='embeddings') # add faiss index
return ds
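# note: Dataset.add_faiss_index requires the faiss-cpu (or faiss-gpu) package;
# with no string_factory given it builds an exact, in-memory flat index over
# the 'embeddings' column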
def search(query, ds, k=3):
    """Searches the dataset for the query and returns the k most similar chunks."""
    try:
        tokens = q_tokenizer(query, truncation=True, max_length=512, return_tensors="pt")
        with torch.no_grad():
            query_embed = q_encoder(**tokens).pooler_output[0].numpy()
        scores, retrieved_examples = ds.get_nearest_examples("embeddings", query_embed, k=k)
        out = (f"**title** : {retrieved_examples['title'][0]},\n"
               f"content: {retrieved_examples['text'][0]}\n\n\n"
               f"**similar resources:** {retrieved_examples['title']}\n")
    except Exception as e:
        out = f"error in search: {e}"
    return out
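# e.g. search("what is dense passage retrieval?", ds) returns a markdown string
# showing the best-matching chunk plus the titles of all k nearest chunks
# (the query text here is purely illustrative)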
def predict(query, file_paths, k=3):
    """End-to-end pipeline: index the uploaded PDFs, then search them for the query."""
    try:
df = process_pdfs(file_paths)
ds = process_dataset(df)
out = search(query,ds,k=k)
except Exception as e:
out = f"error in predict: {e}"
return out
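# note: predict() re-embeds every uploaded PDF and rebuilds the FAISS index on
# each click; for larger corpora you would cache the Dataset between queries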
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center'> PDF Search Engine </h1>")
    with gr.Row():
        with gr.Column():
            files = gr.Files(label="Upload PDFs", type="filepath", file_count="multiple")
            query = gr.Text(label="query")
            with gr.Accordion("number of references", open=False):
                k = gr.Number(value=3, show_label=False, precision=0, minimum=1, container=False)
            button = gr.Button("search")
        with gr.Column():
            output = gr.Markdown(label="output")
    button.click(predict, [query, files, k], outputs=output)

demo.launch()