arpannookala's picture
change app add shutil for lance and finetuned
a61b64b
raw
history blame
8.81 kB
import streamlit as st
import pandas as pd
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
import lancedb
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import time
import re
import nltk
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
# from google_drive_downloader import GoogleDriveDownloader as gdd
# Fetch the NLTK data packages this app relies on (the call is repeated for
# each package, in the same order as before).
for _nltk_resource in ('stopwords', 'wordnet', 'omw-1.4'):
    nltk.download(_nltk_resource)
# --------------------------- Dynamic Download of Large Files --------------------------- #
import gdown
import zipfile
import os
import shutil
# Function to download and extract folder
def download_and_extract_gdrive_finetuned(file_id, destination, extract_to):
    """Download a zip archive from Google Drive and unpack it into ``extract_to``.

    Args:
        file_id: Google Drive file ID of the zip archive.
        destination: Local path the archive is downloaded to. The file is
            deleted again once extraction has finished, whether it succeeded
            or not.
        extract_to: Directory the archive contents are extracted into.

    Raises:
        FileNotFoundError: If no file exists at ``destination`` after the
            download attempt.
        zipfile.BadZipFile: If the downloaded file is not a valid zip archive.

    Note:
        The archive still exists on disk while it is being extracted, so
        ``destination`` must not be the same path as an entry the archive
        creates under ``extract_to``.
    """
    # Download the zip file
    gdown.download(f"https://drive.google.com/uc?id={file_id}", destination, quiet=False)
    # Fail with a clear message instead of an obscure zipfile error later on
    if not os.path.isfile(destination):
        raise FileNotFoundError(f"Download failed, archive not found: {destination}")
    try:
        # Extract the zip file
        with zipfile.ZipFile(destination, 'r') as zip_ref:
            zip_ref.extractall(extract_to)
    finally:
        # Clean up the downloaded zip file even when extraction fails
        os.remove(destination)
def download_and_extract_gdrive(file_id, destination, extract_to):
    """Download a zip archive from Google Drive and move its contents into ``extract_to``.

    The archive is unpacked into a private staging directory first; its
    top-level entries are then moved into ``extract_to`` one by one. An entry
    that already exists in ``extract_to`` is replaced by the freshly
    extracted one.

    Args:
        file_id: Google Drive file ID of the zip archive.
        destination: Local path the archive is downloaded to. The file is
            deleted again once extraction has finished, whether it succeeded
            or not.
        extract_to: Directory that receives the archive's top-level entries.
            It is created when missing.

    Raises:
        FileNotFoundError: If no file exists at ``destination`` after the
            download attempt.
        zipfile.BadZipFile: If the downloaded file is not a valid zip archive.
    """
    import tempfile  # local import: only this helper needs it

    # Download the zip file
    gdown.download(f"https://drive.google.com/uc?id={file_id}", destination, quiet=False)
    # Fail with a clear message instead of an obscure zipfile error later on
    if not os.path.isfile(destination):
        raise FileNotFoundError(f"Download failed, archive not found: {destination}")

    os.makedirs(extract_to, exist_ok=True)
    # A fresh staging directory inside extract_to: it cannot contain stale
    # files from an earlier failed run, and the moves below stay on one
    # filesystem.
    temp_dir = tempfile.mkdtemp(prefix="temp_extract_", dir=extract_to)
    try:
        try:
            with zipfile.ZipFile(destination, 'r') as zip_ref:
                zip_ref.extractall(temp_dir)
        finally:
            # Remove the archive before anything is moved, so it can never
            # clash with an extracted entry that has the same name.
            os.remove(destination)

        for item in os.listdir(temp_dir):
            target = os.path.join(extract_to, item)
            # shutil.move() would nest the item inside an existing directory
            # (or fail on an existing file), so clear the target first.
            if os.path.isdir(target) and not os.path.islink(target):
                shutil.rmtree(target)
            elif os.path.lexists(target):
                os.remove(target)
            shutil.move(os.path.join(temp_dir, item), target)
    finally:
        # Cleanup of the staging directory, also on failure
        shutil.rmtree(temp_dir, ignore_errors=True)
# Download and extract LanceDB and fine-tuned model.
# Streamlit re-runs this script on every user interaction, so only fetch what
# is not on disk yet. The archives are saved under a ".zip" name so they
# cannot collide with the directories of the same base name that the rest of
# this script reads from.
_needs_lancedb = not os.path.isdir("./lancedb_directory_main")
_needs_finetuned = not os.path.isdir("./finetuned_all_minilm_l6_v2")
if _needs_lancedb or _needs_finetuned:
    st.info("Downloading and setting up necessary data. This might take a while...")
if _needs_lancedb:
    download_and_extract_gdrive(
        file_id="1Qnb8bs_NXWlhDwGoswOgsp2DiLBMbfSY",  # Google Drive file ID of the LanceDB archive
        destination="lancedb_directory_main.zip",
        extract_to="./"
    )
if _needs_finetuned:
    download_and_extract_gdrive_finetuned(
        file_id="1_9VVuN_P3zsTBYzg0lAeh4ghd9zhXS3w",  # Google Drive file ID of the fine-tuned model archive
        destination="finetuned_all_minilm_l6_v2.zip",
        extract_to="./"
    )
# # --------------------------- Load the LanceDB Table and Models --------------------------- #
# Location of the LanceDB database and the names of its tables
DB_PATH = "./lancedb_directory_main"
TABLE_NAME_1 = "enhanced_papers_pretrained_1"
TABLE_NAME_2 = "enhanced_papers_pretrained_2"
TABLE_NAME_3 = "enhanced_papers_finetuned"
# Validate extracted data. LanceDB stores each table as a "NAME.lance"
# directory, so check for existence rather than for a regular file
# (os.path.isfile() is False for directories).
expected_files = [
    "enhanced_papers_pretrained_1.lance",
    "enhanced_papers_pretrained_2.lance",
    "enhanced_papers_finetuned.lance"
]
for file in expected_files:
    file_path = os.path.join(DB_PATH, file)
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Expected file is missing: {file_path}")
# Connect to LanceDB
os.makedirs(DB_PATH, exist_ok=True)
db = lancedb.connect(DB_PATH)
table1 = db.open_table(TABLE_NAME_1)
table2 = db.open_table(TABLE_NAME_2)
table3 = db.open_table(TABLE_NAME_3)


# Streamlit re-executes the whole script on every interaction; caching the
# heavyweight model objects with st.cache_resource avoids reloading them on
# each rerun.
@st.cache_resource
def _load_embedding_models():
    """Load the SentenceTransformer encoders, keyed by their display name."""
    return {
        "all-MiniLM-L6-v2": SentenceTransformer('all-MiniLM-L6-v2'),
        "allenai-specter": SentenceTransformer('allenai-specter'),
        "finetuned_all_minilm_l6_v2": SentenceTransformer('./finetuned_all_minilm_l6_v2')
    }


@st.cache_resource
def _load_rag_components(model_name):
    """Load tokenizer, seq2seq model and text2text pipeline for ``model_name``."""
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
    loaded_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    # Run on the GPU when one is available, otherwise on the CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    loaded_pipeline = pipeline(
        "text2text-generation",
        model=loaded_model,
        tokenizer=loaded_tokenizer,
        device=device,
    )
    return loaded_tokenizer, loaded_model, loaded_pipeline


# Load the SentenceTransformer models
embedding_models = _load_embedding_models()
# Each embedding model searches the table that is listed under the same key
model_tables = {
    "all-MiniLM-L6-v2": table1,
    "allenai-specter": table2,
    "finetuned_all_minilm_l6_v2": table3
}
# Load the tokenizer and summarization model for RAG-based explanations
MODEL_NAME = "google/flan-t5-large"
tokenizer, rag_model, rag_pipeline = _load_rag_components(MODEL_NAME)
# --------------------------- Streamlit UI Components --------------------------- #
st.title("Research Paper Recommendation System with RAG-based Explanations")
# Resources used by clean_text(): a WordNet lemmatizer and the set of
# English stopwords
lemmatizer = WordNetLemmatizer()
stop_words = {*stopwords.words('english')}
# Function to clean text
def clean_text(text):
    """Normalize free text into a space-separated string of lemmatized tokens.

    A missing value (as detected by ``pd.isnull``) yields an empty string.
    Otherwise the text is lower-cased, every character other than ``a-z``,
    ``0-9`` and whitespace is dropped, whitespace runs are collapsed, tokens
    found in ``stop_words`` are removed and the remaining tokens are passed
    through ``lemmatizer.lemmatize``.
    """
    if pd.isnull(text):
        return ""
    lowered = text.lower()
    # Drop punctuation and any other character outside a-z / 0-9 / whitespace
    alnum_only = re.sub(r'[^a-z0-9\s]', '', lowered)
    # Collapse newlines and repeated blanks into single spaces
    collapsed = re.sub(r'\s+', ' ', alnum_only).strip()
    kept_tokens = []
    for word in collapsed.split():
        if word in stop_words:
            continue
        kept_tokens.append(lemmatizer.lemmatize(word))
    return ' '.join(kept_tokens)
# Abstract typed in by the user, run through clean_text() right away
user_abstract = clean_text(
    st.text_area("Enter the abstract of your paper:", height=200)
)
# Slider for the number of recommendations to return
k = st.slider(
    "Select the number of recommendations (k):",
    min_value=1,
    max_value=20,
    value=5,
)
# Sidebar dropdown listing the available embedding models
selected_model_name = st.sidebar.selectbox(
    "Select the embedding model:",
    list(embedding_models),
)
# Fetch unique metadata values for filters
def get_unique_values(table, column):
    """Return the sorted distinct non-null values of ``column`` in ``table``.

    The table is materialized as a DataFrame through ``table.to_pandas()``
    before the column is read.
    """
    column_values = table.to_pandas()[column]
    distinct = column_values.dropna().unique()
    return sorted(distinct)
# Table that belongs to the selected model, and the metadata values it offers
table = model_tables[selected_model_name]
categories, authors = (
    get_unique_values(table, column_name)
    for column_name in ('categories', 'authors')
)
# Metadata filters; the leading empty option stands for "no filter"
st.sidebar.header("Filter Recommendations by Metadata")
filter_category = st.sidebar.selectbox("Filter by Category (optional):", ["", *categories])
filter_author = st.sidebar.selectbox("Filter by Author (optional):", ["", *authors])
# --------------------------- Helper Functions --------------------------- #
def generate_explanation(user_abstract, recommended_title, recommended_authors, recommended_abstract, max_input_length=512, max_output_length=200):
    """Ask the text2text pipeline why a recommended paper fits the user's input.

    Args:
        user_abstract: Text entered by the user.
        recommended_title: Title of the recommended paper.
        recommended_authors: Authors of the recommended paper.
        recommended_abstract: Abstract of the recommended paper.
        max_input_length: Accepted for compatibility; not referenced in the
            body of this function.
        max_output_length: Passed to the pipeline as ``max_length``.

    Returns:
        The ``generated_text`` of the first pipeline result, or a string
        starting with ``"Error during generation: "`` when the pipeline call
        raises.
    """
    prompt_parts = [
        f"User's Input:\n{user_abstract}\n\n",
        "Recommended Paper:\n",
        f"Title: {recommended_title}\n",
        f"Authors: {recommended_authors}\n",
        f"Abstract: {recommended_abstract}\n\n",
        "Explain briefly, how the recommended paper is relevant to the user's input",
    ]
    prompt = "".join(prompt_parts)
    generation_options = dict(
        max_length=max_output_length,
        min_length=50,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        truncation=True,
    )
    try:
        outputs = rag_pipeline(prompt, **generation_options)
        return outputs[0]['generated_text']
    except Exception as e:
        # Best effort: report the failure as text instead of raising
        return f"Error during generation: {e}"
def post_process_explanation(text):
    """Remove repeated sentences from ``text`` while keeping their order.

    The text is split on ``'. '``. Two pieces count as the same sentence when
    they are equal after surrounding whitespace and trailing periods are
    ignored, so a repeated final sentence (which still carries its closing
    ``.``) is recognized as a duplicate as well. The first occurrence of each
    sentence is kept.

    Args:
        text: Generated explanation. ``None`` or an empty string is accepted.

    Returns:
        The de-duplicated text with surrounding whitespace stripped. A closing
        period present in the input is preserved.
    """
    if not text:
        return ""
    cleaned = text.strip()
    seen = set()
    kept = []
    for sentence in cleaned.split('. '):
        # Compare without the trailing period: the last piece keeps its "."
        # after the split, the earlier ones do not.
        key = sentence.strip().rstrip('.')
        if key in seen:
            continue
        seen.add(key)
        kept.append(sentence)
    result = '. '.join(kept).strip()
    # Dropping a duplicated last sentence also drops the final period
    if cleaned.endswith('.') and not result.endswith('.'):
        result += '.'
    return result
def get_recommendations(table, embedding_model, model_name):
    """Return the papers in ``table`` most similar to the user's abstract.

    Reads the module-level ``user_abstract``, ``k``, ``filter_category`` and
    ``filter_author`` values.

    Args:
        table: LanceDB table to search.
        embedding_model: Model whose ``encode`` method embeds the abstract.
        model_name: Display name of the model, shown in the spinner text.

    Returns:
        The search result as a pandas DataFrame, limited to ``k`` rows
        (cosine metric).
    """
    with st.spinner(f"Generating embedding for your abstract using {model_name}..."):
        user_embedding = embedding_model.encode(user_abstract, convert_to_tensor=True).cpu().numpy()

    def _sql_literal(value):
        # Double embedded single quotes so a value such as O'Brien cannot
        # terminate the SQL string literal early.
        return str(value).replace("'", "''")

    conditions = []
    if filter_category:
        conditions.append(f"categories = '{_sql_literal(filter_category)}'")
    if filter_author:
        conditions.append(f"authors LIKE '%{_sql_literal(filter_author)}%'")

    # Perform similarity search
    query = table.search(user_embedding).metric("cosine").limit(k)
    if conditions:
        # A later .where() call replaces the earlier filter, so both
        # conditions go into ONE expression. prefilter=True applies the filter
        # before the nearest-neighbour search, so up to k matching rows are
        # returned instead of filtering only the k nearest rows afterwards.
        query = query.where(" AND ".join(conditions), prefilter=True)
    return query.to_pandas()
# --------------------------- Main Logic for Recommendations --------------------------- #
if st.button("Get Recommendations"):
    if user_abstract:
        # Model and table that belong to the sidebar selection
        embedding_model = embedding_models[selected_model_name]
        table = model_tables[selected_model_name]
        st.header(f"Recommendations using {selected_model_name}")
        recommendations = get_recommendations(table, embedding_model, selected_model_name)
        if not recommendations.empty:
            st.success(f"Top {len(recommendations)} Recommendations from {selected_model_name}:")
            for idx, row in recommendations.iterrows():
                # Paper details first, closed by a horizontal rule
                paper_lines = (
                    f"### {idx + 1}. {row['title']}",
                    f"**Category:** {row['categories']}",
                    f"**Authors:** {row['authors']}",
                    f"**Abstract:** {row['abstract']}",
                    f"**Last Updated:** {row['update_date']}",
                    "---",
                )
                for line in paper_lines:
                    st.write(line)
                # Generated explanation, with repeated sentences removed
                explanation = post_process_explanation(
                    generate_explanation(user_abstract, row['title'], row['authors'], row['abstract'])
                )
                st.write(f"**Explanation:** {explanation}")
                st.write("---")
        else:
            st.warning(f"No recommendations found for {selected_model_name} based on the current filters.")
    else:
        st.error("Please enter an abstract to proceed.")