|
import streamlit as st |
|
import pandas as pd |
|
import numpy as np |
|
import torch |
|
from sentence_transformers import SentenceTransformer |
|
import lancedb |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline |
|
import time |
|
import re |
|
import nltk |
|
from nltk.corpus import stopwords |
|
from nltk.stem import WordNetLemmatizer |
|
|
|
|
|
# Make sure the NLTK corpora used by clean_text() are available.
# Streamlit re-runs this script on every widget interaction, and nltk.download()
# contacts the NLTK index each time it is called, so only fetch what is missing.
for _nltk_resource_path, _nltk_package in (
    ("corpora/stopwords", "stopwords"),
    ("corpora/wordnet", "wordnet"),
    ("corpora/omw-1.4", "omw-1.4"),
):
    try:
        nltk.data.find(_nltk_resource_path)
    except LookupError:
        nltk.download(_nltk_package, quiet=True)
|
|
|
|
|
|
|
import gdown |
|
import zipfile |
|
import os |
|
|
|
|
|
def download_and_extract_gdrive(file_id, destination, extract_to):
    """Download a Google Drive file, extract it as a zip archive, then delete the archive.

    Args:
        file_id: Google Drive file id (the ``id=`` part of the share URL).
        destination: Local path the archive is downloaded to.
        extract_to: Directory the archive contents are extracted into.

    Raises:
        RuntimeError: if gdown reports a failed download (returns ``None``).
        zipfile.BadZipFile: if the downloaded file is not a valid zip archive.
    """
    # gdown returns the path it actually wrote to (which can differ from
    # ``destination``, e.g. when ``destination`` is an existing directory),
    # or None when the download failed.
    archive_path = gdown.download(
        f"https://drive.google.com/uc?id={file_id}", destination, quiet=False
    )
    if archive_path is None:
        raise RuntimeError(f"Failed to download Google Drive file {file_id!r}")

    try:
        with zipfile.ZipFile(archive_path, "r") as zip_ref:
            zip_ref.extractall(extract_to)
    finally:
        # Always drop the archive, including a corrupt one, so a retry starts clean.
        os.remove(archive_path)
|
|
|
|
|
# (Google Drive file id, directory the archive is expected to extract into).
# NOTE(review): assumes each archive contains a top-level folder with this name,
# which is what DB_PATH and the fine-tuned model path below rely on — confirm.
_GDRIVE_ASSETS = (
    ("1Qnb8bs_NXWlhDwGoswOgsp2DiLBMbfSY", "lancedb_directory_main"),
    ("1_9VVuN_P3zsTBYzg0lAeh4ghd9zhXS3w", "finetuned_all_minilm_l6_v2"),
)

# Streamlit re-executes the whole script on every interaction, so only fetch
# assets whose extracted directory is not already present.
_missing_assets = [
    (file_id, target_dir)
    for file_id, target_dir in _GDRIVE_ASSETS
    if not os.path.isdir(target_dir)
]

if _missing_assets:
    st.info("Downloading and setting up necessary data. This might take a while...")
    for _file_id, _target_dir in _missing_assets:
        download_and_extract_gdrive(
            file_id=_file_id,
            # Give the archive its own name so extracting "<target_dir>/..." into
            # "./" does not collide with a file called "<target_dir>".
            destination=f"{_target_dir}.zip",
            extract_to="./",
        )
|
|
|
|
|
|
|
# Location of the LanceDB database and the names of its three tables.
DB_PATH = "./lancedb_directory_main"
TABLE_NAME_1 = "enhanced_papers_pretrained_1"
TABLE_NAME_2 = "enhanced_papers_pretrained_2"
TABLE_NAME_3 = "enhanced_papers_finetuned"

# Connect once, then open the tables in order (1, 2, 3).
db = lancedb.connect(DB_PATH)
table1, table2, table3 = (
    db.open_table(table_name)
    for table_name in (TABLE_NAME_1, TABLE_NAME_2, TABLE_NAME_3)
)
|
|
|
|
|
@st.cache_resource
def _load_embedding_models():
    """Load every SentenceTransformer once per server process.

    ``st.cache_resource`` keeps the returned dict alive across reruns and sessions,
    so the models are not re-instantiated on every widget interaction.
    """
    return {
        "all-MiniLM-L6-v2": SentenceTransformer('all-MiniLM-L6-v2'),
        "allenai-specter": SentenceTransformer('allenai-specter'),
        "finetuned_all_minilm_l6_v2": SentenceTransformer('./finetuned_all_minilm_l6_v2'),
    }


# Display name -> loaded encoder. The keys must match model_tables below.
embedding_models = _load_embedding_models()

# Display name -> LanceDB table searched with that encoder's query embeddings
# (presumably the table's vectors were produced by the same encoder — confirm).
model_tables = {
    "all-MiniLM-L6-v2": table1,
    "allenai-specter": table2,
    "finetuned_all_minilm_l6_v2": table3,
}
|
|
|
|
|
MODEL_NAME = "google/flan-t5-large"


@st.cache_resource
def _load_rag_components(model_name):
    """Load tokenizer, seq2seq model and generation pipeline once per server process.

    Returns:
        ``(tokenizer, model, pipeline)``. They are cached by ``st.cache_resource`` so
        the large checkpoint is not reloaded on every Streamlit rerun.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    generator = pipeline("text2text-generation", model=model, tokenizer=tok, device=device)
    return tok, model, generator


tokenizer, rag_model, rag_pipeline = _load_rag_components(MODEL_NAME)
|
|
|
|
|
|
|
# Page header.
st.title("Research Paper Recommendation System with RAG-based Explanations")

# Text-normalisation resources consumed by clean_text().
_english_stopword_list = stopwords.words('english')
stop_words = set(_english_stopword_list)
lemmatizer = WordNetLemmatizer()
|
|
|
|
|
def clean_text(text):
    """Normalise free text before embedding it.

    Steps: lowercase, drop every character other than ``a-z``, ``0-9`` and
    whitespace, collapse whitespace runs, remove English stopwords, and lemmatize
    the remaining words. Null-like input (per ``pd.isnull``) yields ``""``.
    """
    if pd.isnull(text):
        return ""

    lowered = text.lower()
    alnum_only = re.sub(r'[^a-z0-9\s]', '', lowered)
    collapsed = re.sub(r'\s+', ' ', alnum_only).strip()

    kept_words = []
    for word in collapsed.split():
        if word in stop_words:
            continue
        kept_words.append(lemmatizer.lemmatize(word))
    return ' '.join(kept_words)
|
|
|
|
|
# Raw abstract typed by the user.
raw_user_abstract = st.text_area("Enter the abstract of your paper:", height=200)

# Normalised abstract; this is what gets embedded and passed on downstream.
user_abstract = clean_text(raw_user_abstract)

# Number of recommendations to retrieve.
k = st.slider(
    "Select the number of recommendations (k):",
    min_value=1,
    max_value=20,
    value=5,
)

# Which embedding model (and therefore which LanceDB table) to query.
selected_model_name = st.sidebar.selectbox(
    "Select the embedding model:",
    list(embedding_models),
)
|
|
|
|
|
def get_unique_values(table, column):
    """Return the distinct non-null values of *column* in *table*, sorted ascending.

    *table* only needs a ``to_pandas()`` method returning a DataFrame that
    contains *column*.
    """
    column_values = table.to_pandas()[column]
    distinct_values = column_values.dropna().unique()
    return sorted(distinct_values)
|
|
|
table = model_tables[selected_model_name]


@st.cache_data
def _cached_unique_values(model_name, column):
    """Return get_unique_values() for *model_name*'s table, cached across reruns.

    get_unique_values() materialises the whole table via ``to_pandas()``. Caching
    by (model name, column) avoids repeating that on every widget interaction.
    """
    return get_unique_values(model_tables[model_name], column)


categories = _cached_unique_values(selected_model_name, 'categories')
authors = _cached_unique_values(selected_model_name, 'authors')

st.sidebar.header("Filter Recommendations by Metadata")
# An empty string as the first option means "no filter".
filter_category = st.sidebar.selectbox("Filter by Category (optional):", [""] + categories)
filter_author = st.sidebar.selectbox("Filter by Author (optional):", [""] + authors)
|
|
|
|
|
|
|
# Token caps for the short metadata fields, so a very long title or author list
# cannot crowd the abstracts (or the instruction) out of the encoder input.
_MAX_TITLE_TOKENS = 64
_MAX_AUTHORS_TOKENS = 64
# Slack for small tokenisation differences between the truncated pieces and the
# final joined prompt.
_PROMPT_TOKEN_MARGIN = 8


def _truncate_to_tokens(text, max_tokens):
    """Return ``str(text)`` cut to at most *max_tokens* tokens of the global ``tokenizer``."""
    text = str(text)
    if max_tokens <= 0:
        return ""
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    if len(token_ids) <= max_tokens:
        return text
    return tokenizer.decode(token_ids[:max_tokens], skip_special_tokens=True)


def _build_explanation_prompt(user_abstract, title, authors, abstract):
    """Assemble the generation prompt. The instruction is deliberately the last line."""
    return (
        f"User's Input:\n{user_abstract}\n\n"
        f"Recommended Paper:\n"
        f"Title: {title}\n"
        f"Authors: {authors}\n"
        f"Abstract: {abstract}\n\n"
        "Explain briefly, how the recommended paper is relevant to the user's input"
    )


def generate_explanation(user_abstract, recommended_title, recommended_authors, recommended_abstract, max_input_length=512, max_output_length=200):
    """Generate a short explanation of why a recommended paper matches the user's input.

    The instruction sits at the end of the prompt. Right-truncating an over-long
    prompt would therefore drop it, so the prompt is fitted to *max_input_length*
    tokens by shortening its fields instead:
    - the title and authors are capped at a fixed number of tokens;
    - the remaining budget is split between the two abstracts. The user's abstract
      gets at least half, plus whatever the recommended abstract does not need.

    Args:
        user_abstract: Text describing the user's paper.
        recommended_title, recommended_authors, recommended_abstract: Metadata of
            the recommended paper (converted with ``str()``).
        max_input_length: Token budget for the encoder input.
        max_output_length: ``max_length`` passed to the generation pipeline.

    Returns:
        The generated text, or ``"Error during generation: <exc>"`` if anything fails.
        Errors are returned rather than raised because the result is rendered
        directly in the UI.
    """
    try:
        title = _truncate_to_tokens(recommended_title, _MAX_TITLE_TOKENS)
        authors = _truncate_to_tokens(recommended_authors, _MAX_AUTHORS_TOKENS)

        # Tokens consumed by everything except the two abstracts (incl. special tokens).
        overhead = len(tokenizer.encode(_build_explanation_prompt("", title, authors, "")))
        budget = max(max_input_length - overhead - _PROMPT_TOKEN_MARGIN, 0)

        user_len = len(tokenizer.encode(str(user_abstract), add_special_tokens=False))
        rec_len = len(tokenizer.encode(str(recommended_abstract), add_special_tokens=False))
        user_budget = min(user_len, max(budget // 2, budget - rec_len))
        rec_budget = budget - user_budget

        prompt = _build_explanation_prompt(
            _truncate_to_tokens(user_abstract, user_budget),
            title,
            authors,
            _truncate_to_tokens(recommended_abstract, rec_budget),
        )

        explanation = rag_pipeline(
            prompt,
            max_length=max_output_length,
            min_length=50,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            # Safety net only; the budgeting above should already keep the prompt in range.
            truncation=True,
        )[0]['generated_text']
        return explanation
    except Exception as e:
        return f"Error during generation: {e}"
|
|
|
def post_process_explanation(text):
    """Remove repeated sentences from generated text, keeping the first occurrence.

    Sentences are the pieces between ``'. '`` separators. Only exact repeats are
    dropped, and the rejoined result is stripped of surrounding whitespace.
    """
    seen = set()
    kept = []
    for sentence in text.split('. '):
        if sentence in seen:
            continue
        seen.add(sentence)
        kept.append(sentence)
    return '. '.join(kept).strip()
|
|
|
def _sql_string_literal(value):
    """Render *value* as a single-quoted SQL string literal, doubling embedded quotes."""
    return "'" + str(value).replace("'", "''") + "'"


def get_recommendations(table, embedding_model, model_name):
    """Run a cosine vector search for the current abstract and return the hits.

    Reads the module-level ``user_abstract``, ``k``, ``filter_category`` and
    ``filter_author`` values set by the Streamlit widgets.

    Args:
        table: LanceDB table to search.
        embedding_model: Encoder used to embed ``user_abstract``.
        model_name: Display name, used only in the spinner text.

    Returns:
        A DataFrame with at most ``k`` matching rows.
    """
    with st.spinner(f"Generating embedding for your abstract using {model_name}..."):
        user_embedding = embedding_model.encode(user_abstract, convert_to_tensor=True).cpu().numpy()

    query = table.search(user_embedding).metric("cosine").limit(k)

    # The query builder stores a single filter (a second .where() call replaces the
    # first), so every condition is combined into one expression. Values come from
    # the data and may contain quotes (e.g. "O'Brien"), so they are escaped.
    conditions = []
    if filter_category:
        conditions.append(f"categories = {_sql_string_literal(filter_category)}")
    if filter_author:
        conditions.append(f"authors LIKE {_sql_string_literal('%' + filter_author + '%')}")

    if conditions:
        # Filter before the k-NN step so up to k rows survive the filter.
        query = query.where(" AND ".join(conditions), prefilter=True)

    return query.to_pandas()
|
|
|
|
|
|
|
if st.button("Get Recommendations"):
    if not user_abstract:
        st.error("Please enter an abstract to proceed.")
    else:
        embedding_model = embedding_models[selected_model_name]
        table = model_tables[selected_model_name]

        st.header(f"Recommendations using {selected_model_name}")
        recommendations = get_recommendations(table, embedding_model, selected_model_name)

        if recommendations.empty:
            st.warning(f"No recommendations found for {selected_model_name} based on the current filters.")
        else:
            st.success(f"Top {len(recommendations)} Recommendations from {selected_model_name}:")

            # Rank by position (1-based), not by DataFrame index label, so numbering
            # stays 1..n whatever index the result frame carries.
            for rank, (_, row) in enumerate(recommendations.iterrows(), start=1):
                st.write(f"### {rank}. {row['title']}")
                st.write(f"**Category:** {row['categories']}")
                st.write(f"**Authors:** {row['authors']}")
                st.write(f"**Abstract:** {row['abstract']}")
                st.write(f"**Last Updated:** {row['update_date']}")
                st.write("---")

                # Generation runs once per recommendation and can take a while.
                with st.spinner(f"Generating explanation for recommendation {rank}..."):
                    explanation = generate_explanation(user_abstract, row['title'], row['authors'], row['abstract'])
                explanation = post_process_explanation(explanation)
                st.write(f"**Explanation:** {explanation}")
                st.write("---")