File size: 8,806 Bytes
15761f9 e9b803a 15761f9 e9b803a e610b64 1c201ca e610b64 a61b64b e610b64 0909f8b 1c201ca 0909f8b e610b64 1c201ca e9b803a 0909f8b 1c201ca e9b803a 1c201ca 0909f8b 1c201ca e610b64 a61b64b 0909f8b e610b64 a61b64b e610b64 1c201ca a61b64b 0909f8b 1c201ca 0909f8b 15761f9 6985019 15761f9 6985019 15761f9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 |
import streamlit as st
import pandas as pd
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
import lancedb
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import time
import re
import nltk
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
# Fetch the NLTK resources used for text cleaning: the English stopword list
# and the WordNet data (wordnet, omw-1.4) used by WordNetLemmatizer.
# nltk.download skips packages that are already downloaded.
for _nltk_resource in ('stopwords', 'wordnet', 'omw-1.4'):
    nltk.download(_nltk_resource)
# --------------------------- Dynamic Download of Large Files --------------------------- #
import gdown
import zipfile
import os
import shutil
# Function to download and extract folder
def download_and_extract_gdrive_finetuned(file_id, destination, extract_to):
    """Download a zip archive from Google Drive and extract it directly into a directory.

    Args:
        file_id: Google Drive file ID of the zip archive.
        destination: Local path the archive is downloaded to; it is always
            deleted afterwards, including when extraction fails.
        extract_to: Directory the archive's entries are extracted into.

    Raises:
        RuntimeError: If gdown reports that nothing was downloaded.
    """
    try:
        downloaded = gdown.download(f"https://drive.google.com/uc?id={file_id}", destination, quiet=False)
        if downloaded is None:
            # Some gdown versions signal failure by returning None instead of
            # raising; fail here with a clear message rather than in zipfile.
            raise RuntimeError(f"Failed to download Google Drive file {file_id!r} to {destination!r}")
        with zipfile.ZipFile(destination, 'r') as zip_ref:
            zip_ref.extractall(extract_to)
    finally:
        # Clean up the downloaded archive even if extraction failed, so a
        # corrupt or partial download is not left behind.
        if os.path.isfile(destination):
            os.remove(destination)
def download_and_extract_gdrive(file_id, destination, extract_to):
    """Download a zip archive from Google Drive and move its top-level entries into a directory.

    The archive is extracted into a fresh temporary directory created inside
    ``extract_to``, then each top-level entry is moved into ``extract_to``.

    Args:
        file_id: Google Drive file ID of the zip archive.
        destination: Local path the archive is downloaded to; it is always
            deleted afterwards, including on failure.
        extract_to: Directory that receives the archive's top-level entries
            (created if missing).

    Raises:
        RuntimeError: If gdown reports that nothing was downloaded.
    """
    import tempfile  # only needed by this helper

    os.makedirs(extract_to, exist_ok=True)
    try:
        downloaded = gdown.download(f"https://drive.google.com/uc?id={file_id}", destination, quiet=False)
        if downloaded is None:
            # Some gdown versions signal failure by returning None instead of raising.
            raise RuntimeError(f"Failed to download Google Drive file {file_id!r} to {destination!r}")
        # A uniquely named temp dir per call: a fixed name (e.g. "./temp_extract")
        # would pick up leftovers from an earlier failed run or another session.
        # Creating it inside extract_to keeps the later moves on one filesystem.
        with tempfile.TemporaryDirectory(dir=extract_to) as temp_dir:
            with zipfile.ZipFile(destination, 'r') as zip_ref:
                zip_ref.extractall(temp_dir)
            # The archive is no longer needed; removing it now also frees its
            # path in case an extracted entry has the same name.
            os.remove(destination)
            for item in os.listdir(temp_dir):
                # NOTE: if extract_to/<item> already exists as a directory,
                # shutil.move places the entry *inside* it instead of replacing it.
                shutil.move(os.path.join(temp_dir, item), os.path.join(extract_to, item))
    finally:
        if os.path.isfile(destination):
            os.remove(destination)
# Download and extract LanceDB and fine-tuned model.
# Streamlit re-executes this whole script on every widget interaction, so only
# fetch an archive when the directory the rest of the app loads from is missing.
_lancedb_missing = not os.path.isdir("./lancedb_directory_main")
_finetuned_missing = not os.path.isdir("./finetuned_all_minilm_l6_v2")
if _lancedb_missing or _finetuned_missing:
    st.info("Downloading and setting up necessary data. This might take a while...")
if _lancedb_missing:
    download_and_extract_gdrive(
        file_id="1Qnb8bs_NXWlhDwGoswOgsp2DiLBMbfSY",  # Google Drive file ID of the LanceDB archive
        # The ".zip" suffix keeps the temporary archive path distinct from the
        # extracted "lancedb_directory_main" directory.
        destination="lancedb_directory_main.zip",
        extract_to="./",
    )
if _finetuned_missing:
    download_and_extract_gdrive_finetuned(
        file_id="1_9VVuN_P3zsTBYzg0lAeh4ghd9zhXS3w",  # Google Drive file ID of the fine-tuned model archive
        # Same reason: the archive must not share a path with the extracted model directory.
        destination="finetuned_all_minilm_l6_v2.zip",
        extract_to="./",
    )
# --------------------------- Load the LanceDB Table and Models --------------------------- #
DB_PATH = "./lancedb_directory_main"
TABLE_NAME_1 = "enhanced_papers_pretrained_1"
TABLE_NAME_2 = "enhanced_papers_pretrained_2"
TABLE_NAME_3 = "enhanced_papers_finetuned"

# Validate the extracted database before connecting. LanceDB stores each table
# as a "<table name>.lance" directory, so check for existence; os.path.isfile()
# would always report these as missing.
expected_files = [f"{name}.lance" for name in (TABLE_NAME_1, TABLE_NAME_2, TABLE_NAME_3)]
for file in expected_files:
    file_path = os.path.join(DB_PATH, file)
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Expected file is missing: {file_path}")

# Connect to LanceDB and open one table per embedding model.
db = lancedb.connect(DB_PATH)
table1 = db.open_table(TABLE_NAME_1)
table2 = db.open_table(TABLE_NAME_2)
table3 = db.open_table(TABLE_NAME_3)


@st.cache_resource(show_spinner="Loading embedding models...")
def _load_embedding_models():
    """Load the SentenceTransformer encoders once and reuse them across script reruns."""
    return {
        "all-MiniLM-L6-v2": SentenceTransformer('all-MiniLM-L6-v2'),
        "allenai-specter": SentenceTransformer('allenai-specter'),
        "finetuned_all_minilm_l6_v2": SentenceTransformer('./finetuned_all_minilm_l6_v2'),
    }


# Streamlit re-executes this script on every widget interaction; the cached
# loaders keep the models from being reloaded each time.
embedding_models = _load_embedding_models()
# Each embedding model is paired with the table holding vectors it produced.
model_tables = {
    "all-MiniLM-L6-v2": table1,
    "allenai-specter": table2,
    "finetuned_all_minilm_l6_v2": table3,
}

# Tokenizer and seq2seq model used to generate the RAG-based explanations.
MODEL_NAME = "google/flan-t5-large"


@st.cache_resource(show_spinner="Loading the explanation model...")
def _load_rag_components(model_name):
    """Load tokenizer, seq2seq model and text2text pipeline for *model_name* once."""
    tok = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pipe = pipeline("text2text-generation", model=model, tokenizer=tok, device=device)
    return tok, model, pipe


tokenizer, rag_model, rag_pipeline = _load_rag_components(MODEL_NAME)
# --------------------------- Streamlit UI Components --------------------------- #
st.title("Research Paper Recommendation System with RAG-based Explanations")
# Shared text-normalisation helpers used by clean_text(): a WordNet lemmatizer
# and the English stopword set.
lemmatizer = WordNetLemmatizer()
stop_words = set(stopwords.words('english'))
def clean_text(text):
    """Normalise free text before embedding.

    Null-like input (per ``pd.isnull``) becomes an empty string. Otherwise the
    text is lower-cased, every character other than a-z, 0-9 and whitespace is
    dropped, whitespace runs are collapsed, English stopwords are removed and
    the remaining tokens are lemmatised with WordNet.

    Args:
        text: The raw text (or a null value).

    Returns:
        The cleaned tokens joined by single spaces.
    """
    if pd.isnull(text):
        return ""
    normalized = re.sub(r'[^a-z0-9\s]', '', text.lower())
    normalized = re.sub(r'\s+', ' ', normalized).strip()
    kept = (lemmatizer.lemmatize(tok) for tok in normalized.split() if tok not in stop_words)
    return ' '.join(kept)
# The user's abstract, normalised with clean_text() before it is embedded.
user_abstract = clean_text(st.text_area("Enter the abstract of your paper:", height=200))
# How many papers to recommend.
k = st.slider("Select the number of recommendations (k):", min_value=1, max_value=20, value=5)
# Which embedding model (and therefore which LanceDB table) to query.
selected_model_name = st.sidebar.selectbox("Select the embedding model:", list(embedding_models))
def get_unique_values(table, column):
    """Return the sorted distinct non-null values of *column* in *table*.

    The whole table is loaded via ``table.to_pandas()`` to compute this.

    Args:
        table: Any object exposing ``to_pandas()`` (here, a LanceDB table).
        column: Name of the column to collect values from.

    Returns:
        A sorted list of the column's unique non-null values.
    """
    distinct = table.to_pandas()[column].dropna().unique()
    return sorted(distinct)
table = model_tables[selected_model_name]


@st.cache_data(show_spinner=False)
def _filter_options(model_name):
    """Return ``(categories, authors)`` for the sidebar filters of one model's table.

    Cached per model name: get_unique_values() loads the entire table into
    pandas, and Streamlit re-runs this script on every widget interaction.
    """
    source_table = model_tables[model_name]
    return (
        get_unique_values(source_table, 'categories'),
        get_unique_values(source_table, 'authors'),
    )


categories, authors = _filter_options(selected_model_name)
# Metadata filters
st.sidebar.header("Filter Recommendations by Metadata")
filter_category = st.sidebar.selectbox("Filter by Category (optional):", [""] + categories)
filter_author = st.sidebar.selectbox("Filter by Author (optional):", [""] + authors)
# --------------------------- Helper Functions --------------------------- #
def generate_explanation(user_abstract, recommended_title, recommended_authors, recommended_abstract, max_input_length=512, max_output_length=200):
    """Ask the text2text RAG pipeline why a recommended paper fits the user's abstract.

    The prompt ends with the instruction. The pipeline is called with
    ``truncation=True``, which cuts over-long inputs from the end and would
    silently drop that instruction. To prevent this, both abstracts are trimmed
    beforehand so the whole prompt fits in ``max_input_length`` tokens.

    Args:
        user_abstract: Abstract entered by the user.
        recommended_title: Title of the recommended paper.
        recommended_authors: Authors of the recommended paper.
        recommended_abstract: Abstract of the recommended paper.
        max_input_length: Token budget for the complete prompt.
        max_output_length: Passed to the pipeline as ``max_length`` for generation.

    Returns:
        The generated explanation, or ``"Error during generation: ..."`` if
        anything in tokenisation or generation raises.
    """
    instruction = "Explain briefly, how the recommended paper is relevant to the user's input"

    def build_prompt(user_text, recommended_text):
        return (
            f"User's Input:\n{user_text}\n\n"
            f"Recommended Paper:\n"
            f"Title: {recommended_title}\n"
            f"Authors: {recommended_authors}\n"
            f"Abstract: {recommended_text}\n\n"
            f"{instruction}"
        )

    try:
        tok = rag_pipeline.tokenizer

        def fit(text, ids, keep):
            # Keep the text untouched when it already fits; otherwise decode the kept prefix.
            return text if len(ids) <= keep else tok.decode(ids[:keep], skip_special_tokens=True)

        # Tokens taken by the template, title, authors and instruction (incl. special tokens).
        overhead = len(tok.encode(build_prompt("", "")))
        budget = max(max_input_length - overhead, 0)
        user_ids = tok.encode(str(user_abstract), add_special_tokens=False)
        rec_ids = tok.encode(str(recommended_abstract), add_special_tokens=False)
        # Split the budget evenly, letting either abstract use what the other leaves unused.
        user_keep = min(len(user_ids), max(budget // 2, budget - len(rec_ids)))
        rec_keep = budget - user_keep
        prompt = build_prompt(
            fit(user_abstract, user_ids, user_keep),
            fit(recommended_abstract, rec_ids, rec_keep),
        )
        explanation = rag_pipeline(
            prompt,
            max_length=max_output_length,
            min_length=50,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            # Safety net only: token counts of the pieces are approximate once joined.
            truncation=True
        )[0]['generated_text']
        return explanation
    except Exception as e:
        return f"Error during generation: {e}"
def post_process_explanation(text):
    """Drop repeated sentences from generated text, keeping first occurrences in order.

    Sentences are split on ". ". The terminal period is removed before
    splitting and restored afterwards. Otherwise the last sentence would keep
    its "." and never match an earlier, identical sentence. Empty fragments
    are discarded.

    Args:
        text: The generated explanation.

    Returns:
        The de-duplicated sentences joined with ". ", ending in "." when the
        (stripped) input did.
    """
    body = text.strip()
    ends_with_period = body.endswith('.')
    if ends_with_period:
        body = body[:-1]
    seen = set()
    kept = []
    for sentence in body.split('. '):
        sentence = sentence.strip()
        if sentence and sentence not in seen:
            seen.add(sentence)
            kept.append(sentence)
    result = '. '.join(kept)
    if ends_with_period and kept:
        result += '.'
    return result
def get_recommendations(table, embedding_model, model_name):
    """Return the top-k rows of *table* nearest (cosine) to the user's abstract.

    Reads the module-level ``user_abstract``, ``k``, ``filter_category`` and
    ``filter_author`` set by the Streamlit widgets.

    Args:
        table: LanceDB table holding embeddings produced by *embedding_model*.
        embedding_model: SentenceTransformer used to embed the abstract.
        model_name: Display name shown in the spinner.

    Returns:
        A pandas DataFrame with the matching rows.
    """
    with st.spinner(f"Generating embedding for your abstract using {model_name}..."):
        user_embedding = embedding_model.encode(user_abstract, convert_to_tensor=True).cpu().numpy()

    def sql_literal(value):
        # Double embedded single quotes so values such as "O'Brien" stay valid SQL.
        return str(value).replace("'", "''")

    conditions = []
    if filter_category:
        conditions.append(f"categories = '{sql_literal(filter_category)}'")
    if filter_author:
        conditions.append(f"authors LIKE '%{sql_literal(filter_author)}%'")

    # Perform similarity search
    query = table.search(user_embedding).metric("cosine").limit(k)
    if conditions:
        # where() keeps only the most recent clause, so all filters go into one
        # AND-ed expression. prefilter=True filters before the top-k cut, so up
        # to k matching papers are returned instead of filtering only k results.
        query = query.where(" AND ".join(f"({c})" for c in conditions), prefilter=True)
    return query.to_pandas()
# --------------------------- Main Logic for Recommendations --------------------------- #
if st.button("Get Recommendations"):
    if not user_abstract:
        st.error("Please enter an abstract to proceed.")
    else:
        embedding_model = embedding_models[selected_model_name]
        table = model_tables[selected_model_name]
        st.header(f"Recommendations using {selected_model_name}")
        recommendations = get_recommendations(table, embedding_model, selected_model_name)
        if recommendations.empty:
            st.warning(f"No recommendations found for {selected_model_name} based on the current filters.")
        else:
            st.success(f"Top {len(recommendations)} Recommendations from {selected_model_name}:")
            # Number results by position so the ranking does not depend on the DataFrame index.
            for rank, (_, row) in enumerate(recommendations.iterrows(), start=1):
                st.write(f"### {rank}. {row['title']}")
                st.write(f"**Category:** {row['categories']}")
                st.write(f"**Authors:** {row['authors']}")
                st.write(f"**Abstract:** {row['abstract']}")
                st.write(f"**Last Updated:** {row['update_date']}")
                st.write("---")
                # One LLM generation per result; show progress while it runs.
                with st.spinner("Generating explanation..."):
                    explanation = generate_explanation(user_abstract, row['title'], row['authors'], row['abstract'])
                explanation = post_process_explanation(explanation)
                st.write(f"**Explanation:** {explanation}")
                st.write("---")