import streamlit as st
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer
import lancedb
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import re
import nltk
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
# from google_drive_downloader import GoogleDriveDownloader as gdd
# Download NLTK resources if not already downloaded
nltk.download('stopwords')
nltk.download('wordnet')
nltk.download('omw-1.4')
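# nltk.download is idempotent: resources already on disk are not re-fetched.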
# --------------------------- Dynamic Download of Large Files --------------------------- #
import gdown
import zipfile
import os
# Function to download and extract folder
def download_and_extract_gdrive(file_id, destination, extract_to):
    # Download the zip file
    gdown.download(f"https://drive.google.com/uc?id={file_id}", destination, quiet=False)
    # Extract the zip file
    with zipfile.ZipFile(destination, 'r') as zip_ref:
        zip_ref.extractall(extract_to)
    os.remove(destination)  # Clean up the downloaded zip file
# Download and extract the LanceDB tables and the fine-tuned model.
# Skip the download when the extracted folders already exist, since Streamlit
# re-runs this entire script on every user interaction.
# (Assumes each archive extracts to a folder matching the paths used below.)
if not (os.path.isdir("./lancedb_directory_main") and os.path.isdir("./finetuned_all_minilm_l6_v2")):
    st.info("Downloading and setting up necessary data. This might take a while...")
    download_and_extract_gdrive(
        file_id="1Qnb8bs_NXWlhDwGoswOgsp2DiLBMbfSY",  # Replace with the actual Google Drive file ID
        # Download as .zip so the archive cannot collide with the extracted folder of the same name
        destination="lancedb_directory_main.zip",
        extract_to="./"
    )
    download_and_extract_gdrive(
        file_id="1_9VVuN_P3zsTBYzg0lAeh4ghd9zhXS3w",  # Replace with the actual Google Drive file ID
        destination="finetuned_all_minilm_l6_v2.zip",
        extract_to="./"
    )
# --------------------------- Load the LanceDB Tables and Models --------------------------- #
# Connect to LanceDB
DB_PATH = "./lancedb_directory_main"
TABLE_NAME_1 = "enhanced_papers_pretrained_1"
TABLE_NAME_2 = "enhanced_papers_pretrained_2"
TABLE_NAME_3 = "enhanced_papers_finetuned"
db = lancedb.connect(DB_PATH)
table1 = db.open_table(TABLE_NAME_1)
table2 = db.open_table(TABLE_NAME_2)
table3 = db.open_table(TABLE_NAME_3)
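# Each table holds one embedding per paper plus the metadata columns used
# below: title, categories, authors, abstract, and update_date.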
# Load the SentenceTransformer models.
# Cached with st.cache_resource so the models are loaded once per server
# process instead of on every Streamlit rerun.
@st.cache_resource
def load_embedding_models():
    return {
        "all-MiniLM-L6-v2": SentenceTransformer('all-MiniLM-L6-v2'),
        "allenai-specter": SentenceTransformer('allenai-specter'),
        "finetuned_all_minilm_l6_v2": SentenceTransformer('./finetuned_all_minilm_l6_v2')
    }

embedding_models = load_embedding_models()
model_tables = {
    "all-MiniLM-L6-v2": table1,
    "allenai-specter": table2,
    "finetuned_all_minilm_l6_v2": table3
}

# Load the tokenizer and summarization model for RAG-based explanations (also cached)
MODEL_NAME = "google/flan-t5-large"

@st.cache_resource
def load_rag_pipeline():
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    rag_model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
    # device=0 selects the first GPU; -1 falls back to CPU
    return pipeline(
        "text2text-generation",
        model=rag_model,
        tokenizer=tokenizer,
        device=0 if torch.cuda.is_available() else -1
    )

rag_pipeline = load_rag_pipeline()
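# NOTE: flan-t5-large has roughly 780M parameters (~3 GB of fp32 weights);
# on CPU, generating one explanation can take tens of seconds.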
# --------------------------- Streamlit UI Components --------------------------- #
st.title("Research Paper Recommendation System with RAG-based Explanations")
# Initialize stopwords and lemmatizer
stop_words = set(stopwords.words('english'))
lemmatizer = WordNetLemmatizer()
# Function to clean text
def clean_text(text):
    if pd.isnull(text):
        return ""
    # Lowercasing
    text = text.lower()
    # Remove special characters and punctuation
    text = re.sub(r'[^a-z0-9\s]', '', text)
    # Remove extra whitespace and newlines
    text = re.sub(r'\s+', ' ', text).strip()
    # Tokenize and remove stopwords, then lemmatize
    tokens = text.split()
    tokens = [lemmatizer.lemmatize(word) for word in tokens if word not in stop_words]
    return ' '.join(tokens)
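# Example:
#   clean_text("The QUICK brown foxes, jumped!")  ->  "quick brown fox jumped"
# ("the" is a stopword, "foxes" is lemmatized to "fox", punctuation is dropped)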
# Input abstract from the user
user_abstract = st.text_area("Enter the abstract of your paper:", height=200)
# Preprocess the user input abstract
user_abstract = clean_text(user_abstract)
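# NOTE: this assumes the paper embeddings stored in LanceDB were built from
# abstracts cleaned with this same clean_text pipeline; mismatched
# preprocessing between query and corpus would degrade the similarity search.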
# Number of recommendations slider
k = st.slider("Select the number of recommendations (k):", min_value=1, max_value=20, value=5)
# Model selection dropdown
selected_model_name = st.sidebar.selectbox("Select the embedding model:", list(embedding_models.keys()))
# Fetch unique metadata values for filters
def get_unique_values(table, column):
    df = table.to_pandas()
    return sorted(df[column].dropna().unique())
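# NOTE: to_pandas() materializes the entire table just to list filter values;
# acceptable for a small corpus, but worth caching or replacing with a
# column projection for larger ones.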
table = model_tables[selected_model_name]
categories = get_unique_values(table, 'categories')
authors = get_unique_values(table, 'authors')
# Metadata filters
st.sidebar.header("Filter Recommendations by Metadata")
filter_category = st.sidebar.selectbox("Filter by Category (optional):", [""] + categories)
filter_author = st.sidebar.selectbox("Filter by Author (optional):", [""] + authors)
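# Category filtering is an exact match; author filtering is a substring match
# (see the LIKE predicate in get_recommendations below).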
# --------------------------- Helper Functions --------------------------- #
def generate_explanation(user_abstract, recommended_title, recommended_authors, recommended_abstract, max_output_length=200):
    prompt = (
        f"User's Input:\n{user_abstract}\n\n"
        f"Recommended Paper:\n"
        f"Title: {recommended_title}\n"
        f"Authors: {recommended_authors}\n"
        f"Abstract: {recommended_abstract}\n\n"
        "Explain briefly how the recommended paper is relevant to the user's input."
    )
    try:
        # truncation=True caps the prompt at the tokenizer's maximum input length,
        # so an oversized abstract is cut off rather than raising an error
        explanation = rag_pipeline(
            prompt,
            max_length=max_output_length,
            min_length=50,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            truncation=True
        )[0]['generated_text']
        return explanation
    except Exception as e:
        return f"Error during generation: {e}"
def post_process_explanation(text):
    sentences = list(dict.fromkeys(text.split('. ')))
    return '. '.join(sentences).strip()
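# Deduplicates repeated sentences while preserving order. Example:
#   post_process_explanation("A is relevant. A is relevant. It extends B.")
#   -> "A is relevant. It extends B."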
def get_recommendations(table, embedding_model, model_name):
    with st.spinner(f"Generating embedding for your abstract using {model_name}..."):
        user_embedding = embedding_model.encode(user_abstract, convert_to_tensor=True).cpu().numpy()
    # Perform the similarity search; filters use LanceDB's SQL-style predicates.
    # The two filters are combined into a single predicate because chaining
    # .where() calls can replace, rather than AND, the earlier filter.
    query = table.search(user_embedding).metric("cosine").limit(k)
    conditions = []
    if filter_category:
        conditions.append(f"categories = '{filter_category}'")
    if filter_author:
        conditions.append(f"authors LIKE '%{filter_author}%'")
    if conditions:
        query = query.where(" AND ".join(conditions))
    return query.to_pandas()
# --------------------------- Main Logic for Recommendations --------------------------- #
if st.button("Get Recommendations"):
    if not user_abstract:
        st.error("Please enter an abstract to proceed.")
    else:
        embedding_model = embedding_models[selected_model_name]
        table = model_tables[selected_model_name]
        st.header(f"Recommendations using {selected_model_name}")
        recommendations = get_recommendations(table, embedding_model, selected_model_name)
        if recommendations.empty:
            st.warning(f"No recommendations found for {selected_model_name} based on the current filters.")
        else:
            st.success(f"Top {len(recommendations)} Recommendations from {selected_model_name}:")
            for idx, row in recommendations.iterrows():
                st.write(f"### {idx + 1}. {row['title']}")
                st.write(f"**Category:** {row['categories']}")
                st.write(f"**Authors:** {row['authors']}")
                st.write(f"**Abstract:** {row['abstract']}")
                st.write(f"**Last Updated:** {row['update_date']}")
                st.write("---")
                explanation = generate_explanation(user_abstract, row['title'], row['authors'], row['abstract'])
                explanation = post_process_explanation(explanation)
                st.write(f"**Explanation:** {explanation}")
                st.write("---")