import streamlit as st
import pandas as pd
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
import lancedb
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import time
import re
import nltk
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
# from google_drive_downloader import GoogleDriveDownloader as gdd
# Download NLTK resources if not already downloaded
nltk.download('stopwords')
nltk.download('wordnet')
nltk.download('omw-1.4')


# --------------------------- Dynamic Download of Large Files --------------------------- #
import gdown
import zipfile
import os
import shutil
# Download a zip archive from Google Drive and extract it directly into extract_to
# (used for the fine-tuned SentenceTransformer model)
def download_and_extract_gdrive_finetuned(file_id, destination, extract_to):
    # Download the zip file
    gdown.download(f"https://drive.google.com/uc?id={file_id}", destination, quiet=False)
    
    # Extract the zip file
    with zipfile.ZipFile(destination, 'r') as zip_ref:
        zip_ref.extractall(extract_to)
    os.remove(destination)  # Clean up the downloaded zip file

# Download a zip archive from Google Drive and extract it through a temporary
# directory before moving the contents into extract_to (used for the LanceDB data)
def download_and_extract_gdrive(file_id, destination, extract_to):
    # Download the zip file
    gdown.download(f"https://drive.google.com/uc?id={file_id}", destination, quiet=False)

    # Use a temporary directory for extraction
    temp_dir = "./temp_extract"
    os.makedirs(temp_dir, exist_ok=True)

    with zipfile.ZipFile(destination, 'r') as zip_ref:
        zip_ref.extractall(temp_dir)

    # Ensure files are moved correctly
    if not os.path.exists(extract_to):
        os.makedirs(extract_to, exist_ok=True)

    for item in os.listdir(temp_dir):
        item_path = os.path.join(temp_dir, item)
        shutil.move(item_path, os.path.join(extract_to, item))

    # Cleanup
    shutil.rmtree(temp_dir)
    os.remove(destination)


# Download and extract the LanceDB data and the fine-tuned model. The checks below skip
# the download when the data is already present, since Streamlit re-runs this script on
# every interaction. The ".zip" suffix keeps the downloaded archive from colliding with
# the extracted directory of the same name.
if not os.path.isdir("./lancedb_directory_main") or not os.path.isdir("./finetuned_all_minilm_l6_v2"):
    st.info("Downloading and setting up necessary data. This might take a while...")

if not os.path.isdir("./lancedb_directory_main"):
    download_and_extract_gdrive(
        file_id="1Qnb8bs_NXWlhDwGoswOgsp2DiLBMbfSY",
        destination="lancedb_directory_main.zip",
        extract_to="./"
    )

if not os.path.isdir("./finetuned_all_minilm_l6_v2"):
    download_and_extract_gdrive_finetuned(
        file_id="1_9VVuN_P3zsTBYzg0lAeh4ghd9zhXS3w",
        destination="finetuned_all_minilm_l6_v2.zip",
        extract_to="./"
    )
# --------------------------- Load the LanceDB Table and Models --------------------------- #
# Validate the extracted LanceDB tables (LanceDB stores each table as a <name>.lance
# directory inside the database directory)
expected_tables = [
    "enhanced_papers_pretrained_1.lance",
    "enhanced_papers_pretrained_2.lance",
    "enhanced_papers_finetuned.lance"
]
for table_dir in expected_tables:
    table_path = os.path.join("./lancedb_directory_main", table_dir)
    if not os.path.exists(table_path):
        raise FileNotFoundError(f"Expected LanceDB table is missing: {table_path}")
# Connect to LanceDB
DB_PATH = "./lancedb_directory_main"
TABLE_NAME_1 = "enhanced_papers_pretrained_1"
TABLE_NAME_2 = "enhanced_papers_pretrained_2"
TABLE_NAME_3 = "enhanced_papers_finetuned"
os.makedirs(DB_PATH, exist_ok=True)
db = lancedb.connect(DB_PATH)
table1 = db.open_table(TABLE_NAME_1)
table2 = db.open_table(TABLE_NAME_2)
table3 = db.open_table(TABLE_NAME_3)

# Load the SentenceTransformer models
embedding_models = {
    "all-MiniLM-L6-v2": SentenceTransformer('all-MiniLM-L6-v2'),
    "allenai-specter": SentenceTransformer('allenai-specter'),
    "finetuned_all_minilm_l6_v2": SentenceTransformer('./finetuned_all_minilm_l6_v2')
}

model_tables = {
    "all-MiniLM-L6-v2": table1,
    "allenai-specter": table2,
    "finetuned_all_minilm_l6_v2": table3
}

# Load the tokenizer and summarization model for RAG-based explanations
MODEL_NAME = "google/flan-t5-large"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
rag_model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
rag_pipeline = pipeline("text2text-generation", model=rag_model, tokenizer=tokenizer, device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
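# Note: Streamlit re-executes this script on every user interaction, so the embedding
# models and the flan-t5-large pipeline above are reloaded each time. A minimal sketch of
# how the pipeline could be cached instead (assumes Streamlit >= 1.18, where
# st.cache_resource is available; not part of the original app):
#
# @st.cache_resource
# def load_rag_pipeline(model_name):
#     tok = AutoTokenizer.from_pretrained(model_name)
#     mdl = AutoModelForSeq2SeqLM.from_pretrained(model_name)
#     return pipeline("text2text-generation", model=mdl, tokenizer=tok,
#                     device=0 if torch.cuda.is_available() else -1)
#
# rag_pipeline = load_rag_pipeline(MODEL_NAME)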

# --------------------------- Streamlit UI Components --------------------------- #

st.title("Research Paper Recommendation System with RAG-based Explanations")

# Initialize stopwords and lemmatizer
stop_words = set(stopwords.words('english'))
lemmatizer = WordNetLemmatizer()

# Function to clean text
def clean_text(text):
    if pd.isnull(text):
        return ""
    # Lowercasing
    text = text.lower()
    # Remove special characters and punctuation
    text = re.sub(r'[^a-z0-9\s]', '', text)
    # Remove extra whitespace and newlines
    text = re.sub(r'\s+', ' ', text).strip()
    # Tokenize and remove stopwords, then lemmatize
    tokens = text.split()
    tokens = [lemmatizer.lemmatize(word) for word in tokens if word not in stop_words]
    return ' '.join(tokens)
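# Illustrative example (assuming the NLTK English stopword list and default noun
# lemmatization): clean_text("The Quick, Brown Foxes are jumping!") should yield
# roughly "quick brown fox jumping".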

# Input abstract from the user
user_abstract = st.text_area("Enter the abstract of your paper:", height=200)

# Preprocess the user input abstract
user_abstract = clean_text(user_abstract)

# Number of recommendations slider
k = st.slider("Select the number of recommendations (k):", min_value=1, max_value=20, value=5)

# Model selection dropdown
selected_model_name = st.sidebar.selectbox("Select the embedding model:", list(embedding_models.keys()))

# Fetch unique metadata values for filters
def get_unique_values(table, column):
    df = table.to_pandas()
    return sorted(df[column].dropna().unique())

table = model_tables[selected_model_name]
categories = get_unique_values(table, 'categories')
authors = get_unique_values(table, 'authors')

# Metadata filters
st.sidebar.header("Filter Recommendations by Metadata")
filter_category = st.sidebar.selectbox("Filter by Category (optional):", [""] + categories)
filter_author = st.sidebar.selectbox("Filter by Author (optional):", [""] + authors)

# --------------------------- Helper Functions --------------------------- #

def generate_explanation(user_abstract, recommended_title, recommended_authors, recommended_abstract, max_input_length=512, max_output_length=200):
    prompt = (
        f"User's Input:\n{user_abstract}\n\n"
        f"Recommended Paper:\n"
        f"Title: {recommended_title}\n"
        f"Authors: {recommended_authors}\n"
        f"Abstract: {recommended_abstract}\n\n"
        "Explain briefly, how the recommended paper is relevant to the user's input"
    )
    try:
        explanation = rag_pipeline(
            prompt,
            max_length=max_output_length,
            min_length=50,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            truncation=True
        )[0]['generated_text']
        return explanation
    except Exception as e:
        return f"Error during generation: {e}"

def post_process_explanation(text):
    sentences = list(dict.fromkeys(text.split('. ')))
    return '. '.join(sentences).strip()
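# Illustrative example: the generator sometimes repeats itself, and this removes exact
# duplicate sentences while preserving order, e.g.
# post_process_explanation("A is relevant. A is relevant. B too.") returns "A is relevant. B too.".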

def get_recommendations(table, embedding_model, model_name):
    with st.spinner(f"Generating embedding for your abstract using {model_name}..."):
        user_embedding = embedding_model.encode(user_abstract, convert_to_tensor=True).cpu().numpy()

    # Perform similarity search
    query = table.search(user_embedding).metric("cosine").limit(k)

    if filter_category:
        query = query.where(f"categories == '{filter_category}'")
    if filter_author:
        query = query.where(f"authors LIKE '%{filter_author}%'")

    return query.to_pandas()

# --------------------------- Main Logic for Recommendations --------------------------- #

if st.button("Get Recommendations"):
    if not user_abstract:
        st.error("Please enter an abstract to proceed.")
    else:
        embedding_model = embedding_models[selected_model_name]
        table = model_tables[selected_model_name]

        st.header(f"Recommendations using {selected_model_name}")
        recommendations = get_recommendations(table, embedding_model, selected_model_name)

        if recommendations.empty:
            st.warning(f"No recommendations found for {selected_model_name} based on the current filters.")
        else:
            st.success(f"Top {len(recommendations)} Recommendations from {selected_model_name}:")

            for idx, row in recommendations.iterrows():
                st.write(f"### {idx + 1}. {row['title']}")
                st.write(f"**Category:** {row['categories']}")
                st.write(f"**Authors:** {row['authors']}")
                st.write(f"**Abstract:** {row['abstract']}")
                st.write(f"**Last Updated:** {row['update_date']}")
                st.write("---")

                explanation = generate_explanation(user_abstract, row['title'], row['authors'], row['abstract'])
                explanation = post_process_explanation(explanation)
                st.write(f"**Explanation:** {explanation}")
                st.write("---")