Updated app.py
Added TF-IDF method for better semantic search
app.py
CHANGED
@@ -49,40 +49,108 @@ from smolagents import CodeAgent, HfApiModel, tool
     # print(f"ERROR: {str(e)}")  # Debug errors
     # return [f"Error fetching research papers: {str(e)}"]
 
-from rank_bm25 import BM25Okapi
-import nltk
-
-import os
-import shutil
-
-
-nltk_data_path = os.path.join(nltk.data.path[0], "tokenizers", "punkt")
-if os.path.exists(nltk_data_path):
-    shutil.rmtree(nltk_data_path)  # Remove corrupted version
-
-    print("✅ Removed old NLTK 'punkt' data. Reinstalling...")
-
-# Step 2: Download the correct 'punkt' tokenizer
-nltk.download("punkt_tab")
-
-print("✅ Successfully installed 'punkt'!")
-
-
-@tool  # ✅ Register the function properly as a SmolAgents tool
+#"""------Applied BM25 search for paper retrieval------"""
+# from rank_bm25 import BM25Okapi
+# import nltk
+
+# import os
+# import shutil
+
+# nltk_data_path = os.path.join(nltk.data.path[0], "tokenizers", "punkt")
+# if os.path.exists(nltk_data_path):
+#     shutil.rmtree(nltk_data_path)  # Remove corrupted version
+
+#     print("Removed old NLTK 'punkt' data. Reinstalling...")
+
+# # Step 2: Download the correct 'punkt' tokenizer
+# nltk.download("punkt_tab")
+
+# print("Successfully installed 'punkt'!")
+
+# @tool  # Register the function properly as a SmolAgents tool
+# def fetch_latest_arxiv_papers(keywords: list, num_results: int = 5) -> list:
+#     """Fetches and ranks arXiv papers using BM25 keyword relevance.
+
+#     Args:
+#         keywords: List of keywords for search.
+#         num_results: Number of results to return.
+
+#     Returns:
+#         List of the most relevant papers based on BM25 ranking.
+#     """
+#     try:
+#         print(f"DEBUG: Searching arXiv papers with keywords: {keywords}")
+
+#         # Use a general keyword search (without `ti:` and `abs:`)
+#         query = "+AND+".join([f"all:{kw}" for kw in keywords])
+#         query_encoded = urllib.parse.quote(query)
+#         url = f"http://export.arxiv.org/api/query?search_query={query_encoded}&start=0&max_results=50&sortBy=submittedDate&sortOrder=descending"
+
+#         print(f"DEBUG: Query URL - {url}")
+
+#         feed = feedparser.parse(url)
+#         papers = []
+
+#         # Extract papers from arXiv
+#         for entry in feed.entries:
+#             papers.append({
+#                 "title": entry.title,
+#                 "authors": ", ".join(author.name for author in entry.authors),
+#                 "year": entry.published[:4],
+#                 "abstract": entry.summary,
+#                 "link": entry.link
+#             })
+
+#         if not papers:
+#             return [{"error": "No results found. Try different keywords."}]
+
+#         # Apply BM25 ranking
+#         tokenized_corpus = [nltk.word_tokenize(paper["title"].lower() + " " + paper["abstract"].lower()) for paper in papers]
+#         bm25 = BM25Okapi(tokenized_corpus)
+
+#         tokenized_query = nltk.word_tokenize(" ".join(keywords).lower())
+#         scores = bm25.get_scores(tokenized_query)
+
+#         # Sort papers based on BM25 score
+#         ranked_papers = sorted(zip(papers, scores), key=lambda x: x[1], reverse=True)
+
+#         # Return the most relevant ones
+#         return [paper[0] for paper in ranked_papers[:num_results]]
+
+#     except Exception as e:
+#         print(f"ERROR: {str(e)}")
+#         return [{"error": f"Error fetching research papers: {str(e)}"}]
+
+
+"""------Applied TF-IDF for better semantic search------"""
+import numpy as np
+from sklearn.feature_extraction.text import TfidfVectorizer
+from sklearn.metrics.pairwise import cosine_similarity
+import gradio as gr
+from smolagents import CodeAgent, HfApiModel, tool
+import nltk
+
+nltk.download("stopwords")
+from nltk.corpus import stopwords
+
+@tool  # ✅ Register the function properly as a SmolAgents tool
 def fetch_latest_arxiv_papers(keywords: list, num_results: int = 5) -> list:
-    """Fetches and ranks arXiv papers using BM25 keyword relevance.
+    """Fetches and ranks arXiv papers using TF-IDF and Cosine Similarity.
 
     Args:
         keywords: List of keywords for search.
         num_results: Number of results to return.
 
     Returns:
-        List of the most relevant papers based on BM25 ranking.
+        List of the most relevant papers based on TF-IDF ranking.
     """
     try:
         print(f"DEBUG: Searching arXiv papers with keywords: {keywords}")
 
-        # Use a general keyword search (without `ti:` and `abs:`)
+        # Use a general keyword search
         query = "+AND+".join([f"all:{kw}" for kw in keywords])
         query_encoded = urllib.parse.quote(query)
         url = f"http://export.arxiv.org/api/query?search_query={query_encoded}&start=0&max_results=50&sortBy=submittedDate&sortOrder=descending"
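The hunk above retires the BM25 ranking path (kept as a comment block) and pulls in the TF-IDF imports; the arXiv request itself is unchanged. A minimal standalone sketch of that request path, mirroring the commit's URL construction (the keywords and the max_results value below are illustrative, not from the commit):

import urllib.parse

import feedparser  # parses the Atom feed the arXiv API returns

keywords = ["transformer", "vision"]  # hypothetical example input
query = "+AND+".join([f"all:{kw}" for kw in keywords])
query_encoded = urllib.parse.quote(query)
url = (
    "http://export.arxiv.org/api/query"
    f"?search_query={query_encoded}&start=0&max_results=5"
    "&sortBy=submittedDate&sortOrder=descending"
)

feed = feedparser.parse(url)
for entry in feed.entries:
    # Each entry exposes the fields the tool extracts:
    # title, authors, published date, summary, and link.
    print(entry.published[:4], entry.title)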
@@ -105,17 +173,22 @@ def fetch_latest_arxiv_papers(keywords: list, num_results: int = 5) -> list:
         if not papers:
             return [{"error": "No results found. Try different keywords."}]
 
-        # Apply BM25 ranking
-        tokenized_corpus = [nltk.word_tokenize(paper["title"].lower() + " " + paper["abstract"].lower()) for paper in papers]
-        bm25 = BM25Okapi(tokenized_corpus)
-
-        tokenized_query = nltk.word_tokenize(" ".join(keywords).lower())
-        scores = bm25.get_scores(tokenized_query)
-
-        # Sort papers based on BM25 score
-        ranked_papers = sorted(zip(papers, scores), key=lambda x: x[1], reverse=True)
-
-        # Return the most relevant ones
+        # Prepare TF-IDF vectorization
+        corpus = [paper["title"] + " " + paper["abstract"] for paper in papers]
+        vectorizer = TfidfVectorizer(stop_words=stopwords.words('english'))  # Remove stopwords
+        tfidf_matrix = vectorizer.fit_transform(corpus)
+
+        # Transform the query into a TF-IDF vector
+        query_str = " ".join(keywords)
+        query_vec = vectorizer.transform([query_str])
+
+        # Compute cosine similarity
+        similarity_scores = cosine_similarity(query_vec, tfidf_matrix).flatten()
+
+        # Sort papers based on similarity score
+        ranked_papers = sorted(zip(papers, similarity_scores), key=lambda x: x[1], reverse=True)
+
+        # Return the most relevant papers
         return [paper[0] for paper in ranked_papers[:num_results]]
 
     except Exception as e:
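The replacement ranking is plain TF-IDF with cosine similarity: fit a vectorizer on the fetched "title + abstract" strings, project the keyword query into the same space, and sort by similarity. A self-contained sketch of the technique, with toy strings standing in for fetched papers; sklearn's built-in "english" stopword list stands in for the NLTK list the commit downloads, purely to keep the sketch dependency-free:

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

# Toy corpus standing in for the papers' "title + abstract" strings.
corpus = [
    "Attention mechanisms for neural machine translation",
    "A survey of reinforcement learning for robotics",
    "Self-attention in vision transformers",
]
vectorizer = TfidfVectorizer(stop_words="english")
tfidf_matrix = vectorizer.fit_transform(corpus)

# Project the query into the same vector space, then score by cosine similarity.
query_vec = vectorizer.transform(["attention transformers"])
scores = cosine_similarity(query_vec, tfidf_matrix).flatten()

# Highest similarity first, as in the commit's ranking step.
for doc, score in sorted(zip(corpus, scores), key=lambda x: x[1], reverse=True):
    print(f"{score:.3f}  {doc}")

Worth noting that both BM25 and TF-IDF are lexical scorers, and the IDF weights here are fit on at most the 50 fetched abstracts, so "better semantic search" amounts to better term weighting rather than embedding-based matching.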
@@ -188,11 +261,11 @@ def search_papers(user_input):
     results = fetch_latest_arxiv_papers(keywords, num_results=3)  # Fetch 3 results
     print(f"DEBUG: Results received - {results}")  # Debug function output
 
-    #
+    # Check if the API returned an error
     if isinstance(results, list) and len(results) > 0 and "error" in results[0]:
         return results[0]["error"]  # Return the error message directly
 
-    #
+    # Format results only if valid papers exist
     if isinstance(results, list) and results and isinstance(results[0], dict):
         formatted_results = "\n\n".join([
             f"---\n\n"
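The guard in the last hunk relies on the tool's return contract: either a one-element list holding an {"error": ...} dict, or a list of paper dicts. A small sketch of that handling with a stubbed result in place of a live call; the diff truncates the format string at f"---\n\n", so the fields joined below are illustrative:

# Stub standing in for fetch_latest_arxiv_papers() output.
results = [
    {"title": "An example paper", "authors": "A. Author", "year": "2024",
     "abstract": "An illustrative abstract.", "link": "https://arxiv.org/abs/2401.00000"},
]

# An error comes back as a one-element list holding an "error" dict.
if isinstance(results, list) and len(results) > 0 and "error" in results[0]:
    print(results[0]["error"])
# Format results only when valid paper dicts exist.
elif isinstance(results, list) and results and isinstance(results[0], dict):
    formatted_results = "\n\n".join(
        f"---\n\n{p['title']} ({p['year']})\n{p['authors']}\n{p['link']}"
        for p in results
    )
    print(formatted_results)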