import os
import PyPDF2
import pandas as pd
import warnings
import re
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer
import torch
import gradio as gr
from typing import Union
import numpy as np
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
from dotenv import load_dotenv, find_dotenv

warnings.filterwarnings("ignore")

# Load environment variables
load_dotenv(find_dotenv())
ASTRADB_TOKEN = os.getenv("ASTRADB_TOKEN")
ASTRADB_API_ENDPOINT = os.getenv("ASTRADB_API_ENDPOINT")

# AstraDB connection setup using token and endpoint
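# Note: depending on the Astra deployment, the driver may instead need a secure
# connect bundle, e.g. Cluster(cloud={"secure_connect_bundle": "<path-to-bundle>"}).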
auth_provider = PlainTextAuthProvider(username="token", password=ASTRADB_TOKEN)
cluster = Cluster([ASTRADB_API_ENDPOINT], auth_provider=auth_provider)
session = cluster.connect("your_keyspace_name")
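# The target table (placeholder name your_table_name below) is assumed to already
# exist with a schema along these lines, the vector dimension matching DPR's
# 768-dim output and an SAI index enabling the ANN OF query used later:
#   CREATE TABLE your_keyspace_name.your_table_name (
#       title text PRIMARY KEY,
#       text text,
#       embeddings vector<float, 768>
#   );
#   CREATE CUSTOM INDEX ON your_keyspace_name.your_table_name (embeddings)
#       USING 'StorageAttachedIndex';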

# Load DPR models and tokenizers
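# Both encoders map text to 768-dimensional dense vectors: the context encoder
# embeds document chunks for indexing, the question encoder embeds user queries.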
ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")

q_encoder = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")

def process_pdfs(parent_dir: Union[str, list]):
    """Processes the PDF files and returns a dataframe with the text of each page split into 512-character chunks (one chunk per row)."""
    df = pd.DataFrame(columns=["title", "text"])
    if isinstance(parent_dir, str):
        parent_dir = [parent_dir]
    for file_path in parent_dir:
        if not file_path.lower().endswith(".pdf"):  # Only pdf files are supported
            raise Exception("only pdf files are supported")
        pdfFileObj = open(file_path, 'rb')
        pdfReader = PyPDF2.PdfReader(pdfFileObj)
        num_pages = len(pdfReader.pages)
        file_name = file_path.split("/")[-1]
        for i in range(num_pages):
            pageObj = pdfReader.pages[i]
            txt = pageObj.extract_text().replace("\n", "").replace("\t", "")
            txt = re.sub(r" +", " ", txt)  # Strip extra space
            # Split long pages into 512-character chunks so each chunk stays well
            # within the DPR tokenizer's input limit; keep the final shorter chunk
            # instead of dropping it.
            while len(txt) > 512:
                new_data = pd.DataFrame([[f"{file_name}-page-{i}", txt[:512]]], columns=["title", "text"])
                df = pd.concat([df, new_data], ignore_index=True)
                txt = txt[512:]
            new_data = pd.DataFrame([[f"{file_name}-page-{i}", txt]], columns=["title", "text"])
            df = pd.concat([df, new_data], ignore_index=True)
        pdfFileObj.close()
    return df

def process_dataset(df):
    """Processes the dataframe and stores embeddings in AstraDB."""
    if len(df) == 0:
        raise Exception("empty pdf files, or can't read text from them")

    for _, row in df.iterrows():
        title = row['title']
        text = row['text']
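        # Encode the chunk with the DPR context encoder; the pooled output is a
        # 768-dimensional vector stored alongside the text.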
        tokens = ctx_tokenizer(text, return_tensors="pt")
        embed = ctx_encoder(**tokens)[0][0].detach().numpy().tolist()

        query = "INSERT INTO your_table_name (title, text, embeddings) VALUES (%s, %s, %s)"
        session.execute(query, (title, text, embed))

    return df

def search(query, k=3):
    """Searches the query in the database and returns the k most similar."""
    try:
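        # Embed the query with the DPR question encoder so it lives in the same
        # 768-dimensional space as the stored context embeddings.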
        tokens = q_tokenizer(query, return_tensors="pt")
        query_embed = q_encoder(**tokens)[0][0].detach().numpy().tolist()

        # Perform an approximate-nearest-neighbour (ANN) vector search in AstraDB.
        cql = """
        SELECT title, text, embeddings
        FROM your_table_name
        ORDER BY embeddings ANN OF %s LIMIT %s
        """
        rows = session.execute(cql, (query_embed, k))

        retrieved_examples = []
        for row in rows:
            retrieved_examples.append({
                "title": row.title,
                "text": row.text,
                "embeddings": np.array(row.embeddings)
            })

        out = (
            f"**title** : {retrieved_examples[0]['title']},\n"
            f"content: {retrieved_examples[0]['text']}\n\n\n"
            f"**similar resources:** {[example['title'] for example in retrieved_examples]}"
        )
    except Exception as e:
        out = f"error in search: {e}"
    return out

def predict(query, file_paths, k=3):
    """Predicts the most similar files to the query."""
    try:
        df = process_pdfs(file_paths)
        process_dataset(df)
        out = search(query, k=k)
    except Exception as e:
        out = f"error in predict: {e}"
    return out

# Gradio interface
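# Uploaded PDFs are chunked and embedded on every search, then the query runs
# against the stored vectors: predict() -> process_pdfs() -> process_dataset() -> search().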
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center'> PDF Search Engine </h1>")
    with gr.Row():
        with gr.Column(): 
            files = gr.Files(label="Upload PDFs", type="filepath", file_count="multiple")
            query = gr.Text(label="query")
            with gr.Accordion("number of references", open=False):
                k = gr.Number(value=3, show_label=False, precision=0, minimum=1, container=False)
            button = gr.Button("search")
        with gr.Column():
            output = gr.Markdown(label="output")
    button.click(predict, [query, files, k], outputs=output)

demo.launch()