# Commit 2af09b8 by arpannookala: "change"
# (Hugging Face Spaces file-viewer header converted to a comment — the raw
# lines were not valid Python and broke the script.)
import streamlit as st
import pandas as pd
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
import lancedb
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import time
import re
import nltk
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
# from google_drive_downloader import GoogleDriveDownloader as gdd
# Download NLTK resources if not already downloaded
# (nltk.download is a no-op when the corpus is already present locally)
nltk.download('stopwords')  # English stop-word list consumed by clean_text
nltk.download('wordnet')    # lemma dictionary backing WordNetLemmatizer
nltk.download('omw-1.4')    # Open Multilingual WordNet data WordNet needs at lookup time
# --------------------------- Dynamic Download of Large Files --------------------------- #
import gdown
import zipfile
import os
# Function to download and extract folder
def download_and_extract_gdrive(file_id, destination, extract_to):
    """Download a zip archive from Google Drive and unpack it locally.

    Parameters
    ----------
    file_id : str
        Google Drive file id of the zip archive to fetch.
    destination : str
        Local path the archive is written to; deleted after extraction.
    extract_to : str
        Directory the archive contents are extracted into.

    Raises
    ------
    RuntimeError
        If the download fails (gdown.download returns None on failure,
        which previously caused an obscure crash inside ZipFile instead).
    """
    result = gdown.download(f"https://drive.google.com/uc?id={file_id}", destination, quiet=False)
    if result is None:
        raise RuntimeError(f"Failed to download Google Drive file {file_id}")
    # Extract the zip file, then remove the archive to save disk space.
    with zipfile.ZipFile(destination, 'r') as zip_ref:
        zip_ref.extractall(extract_to)
    os.remove(destination)  # Clean up the downloaded zip file
# Download and extract LanceDB and the fine-tuned model.
# Two fixes over the original:
#   1. `destination` now carries a ".zip" suffix — previously the archive was
#      saved under the SAME name as the directory it extracts to, so
#      extraction collided with the just-downloaded file.
#   2. Existence guards — Streamlit re-executes this script on every user
#      interaction, and without the guards the large archives would be
#      re-downloaded on each rerun.
if not (os.path.exists("./lancedb_directory_main") and os.path.exists("./finetuned_all_minilm_l6_v2")):
    st.info("Downloading and setting up necessary data. This might take a while...")
if not os.path.exists("./lancedb_directory_main"):
    download_and_extract_gdrive(
        file_id="1Qnb8bs_NXWlhDwGoswOgsp2DiLBMbfSY",
        destination="lancedb_directory_main.zip",
        extract_to="./"
    )
if not os.path.exists("./finetuned_all_minilm_l6_v2"):
    download_and_extract_gdrive(
        file_id="1_9VVuN_P3zsTBYzg0lAeh4ghd9zhXS3w",
        destination="finetuned_all_minilm_l6_v2.zip",
        extract_to="./"
    )
# --------------------------- Load the LanceDB Table and Models --------------------------- #
DB_PATH = "./lancedb_directory_main"
TABLE_NAME_1 = "enhanced_papers_pretrained_1"
TABLE_NAME_2 = "enhanced_papers_pretrained_2"
TABLE_NAME_3 = "enhanced_papers_finetuned"

# Streamlit reruns this whole script on every widget interaction, so the DB
# connection, the three SentenceTransformer models, and the flan-t5 pipeline
# are wrapped in @st.cache_resource loaders: they are built once per process
# and reused across reruns instead of being reloaded each time.

@st.cache_resource(show_spinner=False)
def _connect_tables():
    """Open the LanceDB connection and the three paper tables once."""
    conn = lancedb.connect(DB_PATH)
    return (conn,
            conn.open_table(TABLE_NAME_1),
            conn.open_table(TABLE_NAME_2),
            conn.open_table(TABLE_NAME_3))

db, table1, table2, table3 = _connect_tables()

@st.cache_resource(show_spinner=False)
def _load_embedding_models():
    """Load the two pretrained and one fine-tuned embedding models once."""
    return {
        "all-MiniLM-L6-v2": SentenceTransformer('all-MiniLM-L6-v2'),
        "allenai-specter": SentenceTransformer('allenai-specter'),
        "finetuned_all_minilm_l6_v2": SentenceTransformer('./finetuned_all_minilm_l6_v2'),
    }

embedding_models = _load_embedding_models()
model_tables = {
    "all-MiniLM-L6-v2": table1,
    "allenai-specter": table2,
    "finetuned_all_minilm_l6_v2": table3,
}

# Tokenizer and summarization model for RAG-based explanations.
MODEL_NAME = "google/flan-t5-large"

@st.cache_resource(show_spinner=False)
def _load_rag_components():
    """Load the tokenizer, seq2seq model, and generation pipeline once."""
    tok = AutoTokenizer.from_pretrained(MODEL_NAME)
    mdl = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
    pipe = pipeline(
        "text2text-generation",
        model=mdl,
        tokenizer=tok,
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    )
    return tok, mdl, pipe

tokenizer, rag_model, rag_pipeline = _load_rag_components()
# --------------------------- Streamlit UI Components --------------------------- #
st.title("Research Paper Recommendation System with RAG-based Explanations")
# Initialize stopwords and lemmatizer (both consumed by clean_text below)
stop_words = set(stopwords.words('english'))  # English stop-word set
lemmatizer = WordNetLemmatizer()  # WordNet-backed lemmatizer
# Function to clean text
def clean_text(text):
    """Normalize free-form abstract text for embedding.

    Lowercases, strips everything except a-z/0-9/whitespace, collapses
    runs of whitespace, removes English stopwords, and lemmatizes the
    remaining tokens. Null/NaN input yields an empty string.
    """
    if pd.isnull(text):
        return ""
    normalized = re.sub(r'[^a-z0-9\s]', '', text.lower())
    normalized = re.sub(r'\s+', ' ', normalized).strip()
    kept = (
        lemmatizer.lemmatize(token)
        for token in normalized.split()
        if token not in stop_words
    )
    return ' '.join(kept)
# Input abstract from the user
user_abstract = st.text_area("Enter the abstract of your paper:", height=200)
# Preprocess the user input abstract.
# NOTE: this runs on every Streamlit rerun, before the "Get Recommendations"
# button is checked; the cleaned text is what gets embedded later.
user_abstract = clean_text(user_abstract)
# Number of recommendations slider
k = st.slider("Select the number of recommendations (k):", min_value=1, max_value=20, value=5)
# Model selection dropdown (keys of embedding_models double as table keys)
selected_model_name = st.sidebar.selectbox("Select the embedding model:", list(embedding_models.keys()))
# Fetch unique metadata values for filters
def get_unique_values(table, column):
    """Return the sorted distinct non-null values of *column* in *table*."""
    series = table.to_pandas()[column]
    return sorted(series.dropna().unique())
table = model_tables[selected_model_name]
# Scan the table ONCE for both metadata columns; the original called
# get_unique_values twice, which materialized the full table with
# to_pandas() once per column.
_metadata_df = table.to_pandas()
categories = sorted(_metadata_df['categories'].dropna().unique())
authors = sorted(_metadata_df['authors'].dropna().unique())
# Metadata filters (empty string = no filter)
st.sidebar.header("Filter Recommendations by Metadata")
filter_category = st.sidebar.selectbox("Filter by Category (optional):", [""] + categories)
filter_author = st.sidebar.selectbox("Filter by Author (optional):", [""] + authors)
# --------------------------- Helper Functions --------------------------- #
def generate_explanation(user_abstract, recommended_title, recommended_authors, recommended_abstract, max_input_length=512, max_output_length=200):
    """Generate a short RAG-style explanation of why a recommended paper
    is relevant to the user's abstract.

    Parameters
    ----------
    user_abstract : str
        The (cleaned) abstract supplied by the user.
    recommended_title, recommended_authors, recommended_abstract : str
        Metadata of the recommended paper, interpolated into the prompt.
    max_input_length : int
        Token budget for the prompt. The original accepted this parameter
        but never used it; it is now applied by truncating the prompt with
        the model tokenizer so long abstracts cannot overflow the encoder.
    max_output_length : int
        Upper bound (in tokens) on the generated explanation.

    Returns
    -------
    str
        The generated explanation, or an error message on failure.
    """
    prompt = (
        f"User's Input:\n{user_abstract}\n\n"
        f"Recommended Paper:\n"
        f"Title: {recommended_title}\n"
        f"Authors: {recommended_authors}\n"
        f"Abstract: {recommended_abstract}\n\n"
        "Explain briefly, how the recommended paper is relevant to the user's input"
    )
    try:
        # Enforce the input budget: encode with truncation, then decode back
        # to text for the pipeline.
        token_ids = tokenizer.encode(prompt, truncation=True, max_length=max_input_length)
        prompt = tokenizer.decode(token_ids, skip_special_tokens=True)
        explanation = rag_pipeline(
            prompt,
            max_length=max_output_length,
            min_length=50,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            truncation=True
        )[0]['generated_text']
        return explanation
    except Exception as e:
        # Best-effort: surface the failure as text so the Streamlit UI
        # degrades gracefully instead of crashing the whole rerun.
        return f"Error during generation: {e}"
def post_process_explanation(text):
    """Drop duplicate sentences (split on '. '), keeping first-occurrence
    order, and strip surrounding whitespace from the result."""
    seen = set()
    unique_sentences = []
    for sentence in text.split('. '):
        if sentence not in seen:
            seen.add(sentence)
            unique_sentences.append(sentence)
    return '. '.join(unique_sentences).strip()
def get_recommendations(table, embedding_model, model_name):
    """Embed the user's abstract and run a cosine-similarity search on the
    given LanceDB table, applying the sidebar metadata filters.

    Reads module-level state: user_abstract, k, filter_category,
    filter_author. Returns the top-k results as a pandas DataFrame.
    """
    with st.spinner(f"Generating embedding for your abstract using {model_name}..."):
        user_embedding = embedding_model.encode(user_abstract, convert_to_tensor=True).cpu().numpy()
    # Perform similarity search
    query = table.search(user_embedding).metric("cosine").limit(k)
    # Escape single quotes so values like "O'Brien" cannot break (or inject
    # into) the SQL-like filter expression string.
    if filter_category:
        safe_category = filter_category.replace("'", "''")
        query = query.where(f"categories == '{safe_category}'")
    if filter_author:
        safe_author = filter_author.replace("'", "''")
        query = query.where(f"authors LIKE '%{safe_author}%'")
    return query.to_pandas()
# --------------------------- Main Logic for Recommendations --------------------------- #
if st.button("Get Recommendations"):
    if not user_abstract:
        st.error("Please enter an abstract to proceed.")
    else:
        embedding_model = embedding_models[selected_model_name]
        table = model_tables[selected_model_name]
        st.header(f"Recommendations using {selected_model_name}")
        recommendations = get_recommendations(table, embedding_model, selected_model_name)
        if recommendations.empty:
            st.warning(f"No recommendations found for {selected_model_name} based on the current filters.")
        else:
            st.success(f"Top {len(recommendations)} Recommendations from {selected_model_name}:")
            # enumerate() gives a reliable 1-based rank; the original used the
            # DataFrame index (idx + 1), which mis-numbers the list whenever
            # the index is not a clean 0..n-1 RangeIndex.
            for rank, (_, row) in enumerate(recommendations.iterrows(), start=1):
                st.write(f"### {rank}. {row['title']}")
                st.write(f"**Category:** {row['categories']}")
                st.write(f"**Authors:** {row['authors']}")
                st.write(f"**Abstract:** {row['abstract']}")
                st.write(f"**Last Updated:** {row['update_date']}")
                st.write("---")
                explanation = generate_explanation(user_abstract, row['title'], row['authors'], row['abstract'])
                explanation = post_process_explanation(explanation)
                st.write(f"**Explanation:** {explanation}")
                st.write("---")