File size: 7,618 Bytes
15761f9
 
 
 
 
 
 
 
 
 
 
 
e9b803a
15761f9
 
 
 
 
 
 
e9b803a
e610b64
 
6db8e2c
e610b64
6db8e2c
a61b64b
 
 
 
 
 
 
 
 
 
0909f8b
e610b64
a61b64b
2af09b8
a61b64b
e610b64
1c201ca
6db8e2c
a61b64b
2af09b8
a61b64b
 
 
6db8e2c
15761f9
6985019
15761f9
 
 
6db8e2c
15761f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import streamlit as st
import pandas as pd
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
import lancedb
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import time
import re
import nltk
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
# from google_drive_downloader import GoogleDriveDownloader as gdd
# Download NLTK resources only when they are missing locally.
# nltk.download() checks the remote package index on every call, and Streamlit
# re-executes this whole script on every widget interaction. Probing the local
# data path first avoids a network round-trip on each rerun.
for _nltk_package, _nltk_path in (
    ('stopwords', 'corpora/stopwords'),
    ('wordnet', 'corpora/wordnet'),
    ('omw-1.4', 'corpora/omw-1.4'),
):
    try:
        nltk.data.find(_nltk_path)
    except LookupError:
        nltk.download(_nltk_package)


# --------------------------- Dynamic Download of Large Files --------------------------- #
import gdown
import zipfile
import os

# Function to download and extract folder
def download_and_extract_gdrive(file_id, destination, extract_to):
    """Download a zip archive from Google Drive and extract it.

    Args:
        file_id: Google Drive file ID of the zip archive.
        destination: Local path the archive is downloaded to. The archive is
            deleted afterwards, whether or not extraction succeeds.
        extract_to: Directory the archive's contents are extracted into.

    Raises:
        RuntimeError: If gdown reports a failed download by returning None.
        zipfile.BadZipFile: If the downloaded file is not a valid zip archive.
    """
    downloaded = gdown.download(f"https://drive.google.com/uc?id={file_id}", destination, quiet=False)
    if downloaded is None:
        # Without this check the failure surfaced as an obscure error from
        # zipfile.ZipFile(None).
        raise RuntimeError(f"Failed to download Google Drive file {file_id!r} to {destination!r}")

    try:
        with zipfile.ZipFile(downloaded, 'r') as zip_ref:
            zip_ref.extractall(extract_to)
    finally:
        # Clean up the downloaded archive even when extraction fails
        # (e.g. a corrupt or non-zip download), so no stray file is left behind.
        os.remove(downloaded)

# Download and extract the LanceDB database and the fine-tuned model.
# Streamlit re-executes this script on every widget interaction, so an archive
# is fetched only when its extracted directory is not already present.
# Each archive is downloaded to a path equal to the directory the app later
# loads from (DB_PATH and the SentenceTransformer path below). Re-downloading
# once that directory exists would therefore also collide with it on disk.
# NOTE(review): assumes each zip unpacks into a top-level folder of that name.
_GDRIVE_ARCHIVES = (
    ("1Qnb8bs_NXWlhDwGoswOgsp2DiLBMbfSY", "lancedb_directory_main"),
    ("1_9VVuN_P3zsTBYzg0lAeh4ghd9zhXS3w", "finetuned_all_minilm_l6_v2"),
)

_missing_archives = [
    (_file_id, _target) for _file_id, _target in _GDRIVE_ARCHIVES
    if not os.path.isdir(_target)
]

if _missing_archives:
    st.info("Downloading and setting up necessary data. This might take a while...")
    for _file_id, _target in _missing_archives:
        download_and_extract_gdrive(
            file_id=_file_id,
            destination=_target,
            extract_to="./"
        )
# --------------------------- Load the LanceDB Table and Models --------------------------- #

# Connect to LanceDB
DB_PATH = "./lancedb_directory_main"
TABLE_NAME_1 = "enhanced_papers_pretrained_1"
TABLE_NAME_2 = "enhanced_papers_pretrained_2"
TABLE_NAME_3 = "enhanced_papers_finetuned"


# Streamlit re-executes this whole script on every widget interaction.
# st.cache_resource keeps the DB handles and the three SentenceTransformer
# models in memory across reruns, instead of re-opening the tables and
# re-loading model weights each time.
@st.cache_resource
def _open_lancedb_tables():
    """Connect to DB_PATH and open the three tables.

    Returns:
        tuple: (connection, table TABLE_NAME_1, table TABLE_NAME_2, table TABLE_NAME_3).
    """
    connection = lancedb.connect(DB_PATH)
    return (
        connection,
        connection.open_table(TABLE_NAME_1),
        connection.open_table(TABLE_NAME_2),
        connection.open_table(TABLE_NAME_3),
    )


@st.cache_resource
def _load_embedding_models():
    """Load the SentenceTransformer models, keyed by the name shown in the UI."""
    return {
        "all-MiniLM-L6-v2": SentenceTransformer('all-MiniLM-L6-v2'),
        "allenai-specter": SentenceTransformer('allenai-specter'),
        "finetuned_all_minilm_l6_v2": SentenceTransformer('./finetuned_all_minilm_l6_v2'),
    }


db, table1, table2, table3 = _open_lancedb_tables()

# Load the SentenceTransformer models
embedding_models = _load_embedding_models()

# Maps each embedding model name to the table that is searched with its embeddings.
model_tables = {
    "all-MiniLM-L6-v2": table1,
    "allenai-specter": table2,
    "finetuned_all_minilm_l6_v2": table3
}

# Load the tokenizer and text-to-text generation model for RAG-based explanations
MODEL_NAME = "google/flan-t5-large"


# Streamlit re-executes this script on every widget interaction. st.cache_resource
# keeps one loaded copy of this large model per server process, instead of
# reloading it from disk on each rerun.
@st.cache_resource
def _load_rag_components(model_name):
    """Load the tokenizer, seq2seq model and a text2text-generation pipeline.

    The pipeline runs on CUDA when available, otherwise on CPU.

    Returns:
        tuple: (tokenizer, model, pipeline).
    """
    rag_tokenizer = AutoTokenizer.from_pretrained(model_name)
    seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    generator = pipeline("text2text-generation", model=seq2seq_model, tokenizer=rag_tokenizer, device=device)
    return rag_tokenizer, seq2seq_model, generator


tokenizer, rag_model, rag_pipeline = _load_rag_components(MODEL_NAME)

# --------------------------- Streamlit UI Components --------------------------- #

st.title("Research Paper Recommendation System with RAG-based Explanations")

# NLTK helpers consumed by clean_text: a WordNet lemmatizer, and the English
# stopword list held in a set for fast membership checks.
lemmatizer = WordNetLemmatizer()
stop_words = set(stopwords.words('english'))

# Function to clean text
def clean_text(text):
    """Normalize free text before it is embedded.

    Steps: lowercase; delete every character that is not a-z, 0-9 or
    whitespace; collapse runs of whitespace into single spaces; drop English
    stopwords; lemmatize each remaining token.

    Null-like input (anything ``pd.isnull`` treats as missing, e.g. None/NaN)
    yields an empty string.
    """
    if pd.isnull(text):
        return ""

    normalized = re.sub(r'[^a-z0-9\s]', '', text.lower())
    normalized = re.sub(r'\s+', ' ', normalized).strip()

    kept_lemmas = [
        lemmatizer.lemmatize(token)
        for token in normalized.split()
        if token not in stop_words
    ]
    return ' '.join(kept_lemmas)

# Abstract typed by the user, normalized with clean_text right away.
user_abstract = clean_text(
    st.text_area("Enter the abstract of your paper:", height=200)
)

# How many recommendations to return (k).
k = st.slider(
    "Select the number of recommendations (k):",
    min_value=1,
    max_value=20,
    value=5,
)

# Sidebar picker for the embedding model; options follow embedding_models' key order.
selected_model_name = st.sidebar.selectbox(
    "Select the embedding model:",
    [*embedding_models],
)

# Fetch unique metadata values for filters
def get_unique_values(table, column):
    """Return the sorted distinct non-null values of ``column`` in ``table``.

    ``table`` only needs a ``to_pandas()`` method (a LanceDB table here).
    Note that the whole table is materialized to read a single column.
    """
    column_values = table.to_pandas()[column]
    distinct_values = column_values.dropna().unique()
    return sorted(distinct_values)

@st.cache_data
def _filter_options(model_name):
    """Return (categories, authors) filter choices for ``model_name``'s table.

    Cached per model name. get_unique_values materializes the whole table via
    ``to_pandas()``, and uncached it ran twice (once per column) on every
    Streamlit rerun, i.e. on every widget interaction.
    """
    source_table = model_tables[model_name]
    return (
        get_unique_values(source_table, 'categories'),
        get_unique_values(source_table, 'authors'),
    )


table = model_tables[selected_model_name]
categories, authors = _filter_options(selected_model_name)

# Metadata filters
st.sidebar.header("Filter Recommendations by Metadata")
filter_category = st.sidebar.selectbox("Filter by Category (optional):", [""] + categories)
filter_author = st.sidebar.selectbox("Filter by Author (optional):", [""] + authors)

# --------------------------- Helper Functions --------------------------- #

def generate_explanation(user_abstract, recommended_title, recommended_authors, recommended_abstract, max_input_length=512, max_output_length=200):
    """Ask the generation pipeline why a recommended paper fits the user's abstract.

    Args:
        user_abstract: The user's query abstract.
        recommended_title: Title of the recommended paper.
        recommended_authors: Authors of the recommended paper.
        recommended_abstract: Abstract of the recommended paper.
        max_input_length: Maximum number of prompt tokens given to the model;
            longer prompts are truncated by the tokenizer.
        max_output_length: ``max_length`` passed to generation.

    Returns:
        The generated explanation. If anything in tokenization or generation
        raises, returns an "Error during generation: ..." string instead.
    """
    # The instruction goes first. Long prompts are truncated from the end (the
    # tokenizer's default right-side truncation), so with two long abstracts a
    # trailing instruction would be cut off and the model would never see it.
    prompt = (
        "Explain briefly, how the recommended paper is relevant to the user's input.\n\n"
        f"User's Input:\n{user_abstract}\n\n"
        f"Recommended Paper:\n"
        f"Title: {recommended_title}\n"
        f"Authors: {recommended_authors}\n"
        f"Abstract: {recommended_abstract}"
    )
    try:
        # Enforce max_input_length, which was previously accepted but ignored:
        # the pipeline's own truncation only honours the model's maximum.
        gen_tokenizer = rag_pipeline.tokenizer
        input_ids = gen_tokenizer(prompt, truncation=True, max_length=max_input_length)["input_ids"]
        prompt = gen_tokenizer.decode(input_ids, skip_special_tokens=True)

        explanation = rag_pipeline(
            prompt,
            max_length=max_output_length,
            # min_length must not exceed max_length, or generation settings are invalid.
            min_length=min(50, max_output_length),
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            truncation=True
        )[0]['generated_text']
        return explanation
    except Exception as e:
        # UI boundary: surface the failure as text instead of crashing the app.
        return f"Error during generation: {e}"

def post_process_explanation(text):
    """Remove repeated sentences from generated text, keeping first occurrences.

    Sentences are delimited by ". ". The pieces are rejoined with the same
    delimiter and surrounding whitespace is stripped.
    """
    seen = set()
    unique_sentences = []
    for sentence in text.split('. '):
        if sentence in seen:
            continue
        seen.add(sentence)
        unique_sentences.append(sentence)
    return '. '.join(unique_sentences).strip()

def _sql_string_literal(value):
    """Render ``value`` as a single-quoted SQL string literal, doubling embedded quotes."""
    return "'" + str(value).replace("'", "''") + "'"


def get_recommendations(table, embedding_model, model_name):
    """Return the top-``k`` papers nearest to the user's abstract as a DataFrame.

    Reads the module-level UI state ``user_abstract``, ``k``,
    ``filter_category`` and ``filter_author``.

    Args:
        table: LanceDB table searched with ``embedding_model``'s vectors.
        embedding_model: SentenceTransformer used to embed ``user_abstract``.
        model_name: Display name of the model, shown in the spinner.

    Returns:
        pandas.DataFrame of at most ``k`` matching rows (cosine metric).
    """
    with st.spinner(f"Generating embedding for your abstract using {model_name}..."):
        user_embedding = embedding_model.encode(user_abstract, convert_to_tensor=True).cpu().numpy()

    # Perform similarity search
    query = table.search(user_embedding).metric("cosine").limit(k)

    # Values are quoted as SQL literals so names containing apostrophes
    # (e.g. "O'Brien") do not break the filter expression.
    conditions = []
    if filter_category:
        conditions.append(f"categories = {_sql_string_literal(filter_category)}")
    if filter_author:
        conditions.append(f"authors LIKE {_sql_string_literal('%' + filter_author + '%')}")

    if conditions:
        # Combine the conditions into ONE where(): calling where() twice makes
        # the second filter replace the first instead of AND-ing with it.
        # prefilter=True applies the filter before the top-k cut, so a filter
        # does not shrink (or empty) the k results.
        query = query.where(" AND ".join(conditions), prefilter=True)

    return query.to_pandas()

# --------------------------- Main Logic for Recommendations --------------------------- #

if st.button("Get Recommendations"):
    if not user_abstract:
        st.error("Please enter an abstract to proceed.")
    else:
        embedding_model = embedding_models[selected_model_name]
        table = model_tables[selected_model_name]

        st.header(f"Recommendations using {selected_model_name}")
        recommendations = get_recommendations(table, embedding_model, selected_model_name)

        if recommendations.empty:
            st.warning(f"No recommendations found for {selected_model_name} based on the current filters.")
        else:
            st.success(f"Top {len(recommendations)} Recommendations from {selected_model_name}:")

            # Number results by rank rather than by DataFrame index label, so
            # the display does not depend on the index returned by to_pandas().
            for rank, (_, row) in enumerate(recommendations.iterrows(), start=1):
                st.write(f"### {rank}. {row['title']}")
                st.write(f"**Category:** {row['categories']}")
                st.write(f"**Authors:** {row['authors']}")
                st.write(f"**Abstract:** {row['abstract']}")
                st.write(f"**Last Updated:** {row['update_date']}")
                st.write("---")

                # NOTE(review): user_abstract is the clean_text() output
                # (lowercased, stopwords removed, lemmatized), so the model
                # explains relevance against that normalized text, not the raw input.
                # Generation runs once per paper; the spinner shows progress
                # instead of a silently blocked page.
                with st.spinner(f"Generating explanation for recommendation {rank}..."):
                    explanation = generate_explanation(user_abstract, row['title'], row['authors'], row['abstract'])
                explanation = post_process_explanation(explanation)
                st.write(f"**Explanation:** {explanation}")
                st.write("---")