import os import streamlit as st import faiss import pandas as pd import numpy as np import requests import torch import re import time from sentence_transformers import SentenceTransformer, CrossEncoder from huggingface_hub import hf_hub_download from langdetect import detect HF_REPO = "AKKI-AFK/deepshelf-data" books_file = hf_hub_download(repo_id=HF_REPO, filename="booksummaries.txt", repo_type="dataset") faiss_file = hf_hub_download(repo_id=HF_REPO, filename="faiss_index.bin", repo_type="dataset") df = pd.read_csv(books_file, delimiter="\t") index = faiss.read_index(faiss_file) encoder = SentenceTransformer("sentence-transformers/paraphrase-mpnet-base-v2") cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L6-v2") request_times = [] # Track request timestamps @st.cache_data def recommend_books(query): query = sanitize_input(query) if len(query) > 200: st.warning("⚠️ Query is too long. Please keep it under 200 characters.") return [] if len(query) < 3: st.warning("⚠️ Query is too short. Please provide more details.") return [] try: lang = detect(query) if lang != "en": st.warning("⚠️ Non-English query detected. Results may not be accurate.") except: st.warning("⚠️ Could not detect language. Ensure proper input.") search_vector = encoder.encode(query) search_vector = np.array([search_vector]) faiss.normalize_L2(search_vector) distances, ann = index.search(search_vector, k=50) results = pd.DataFrame({'distances': distances[0], 'ann': ann[0]}) merge = pd.merge(results, df, left_on='ann', right_index=True) merge['Query'] = query pairs = list(zip(merge['Query'], merge['summary'])) scores = cross_encoder.predict(pairs) # Running synchronously merge['score'] = scores df_sorted = merge.iloc[merge["score"].argsort()][::-1] return df_sorted[["title", "summary"]][:5].to_dict(orient="records") def sanitize_input(text): """Sanitize input by removing special characters and excessive spaces.""" text = re.sub(r'[^\w\s]', '', text) # Remove special characters text = re.sub(r'\s+', ' ', text).strip() # Normalize spaces return text def rate_limit(): """Rate-limiting function to prevent excessive queries.""" global request_times current_time = time.time() request_times = [t for t in request_times if current_time - t < 10] # Keep only recent requests within 10 seconds if len(request_times) >= 5: st.error("⚠️ Too many requests. Please wait a few seconds before trying again.") return False request_times.append(current_time) return True st.set_page_config(page_title="DeepShelf", page_icon="📚", layout="wide") st.markdown(""" """, unsafe_allow_html=True) st.markdown('