pdx97 committed on
Commit 8267210 · verified · 1 Parent(s): fcfda0a

Updated app.py

Added a TF-IDF + cosine-similarity ranking method for better semantic search
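For context, the core of the change works like this: each paper's title and abstract are turned into a TF-IDF vector, the keyword query is projected into the same vector space, and papers are ranked by cosine similarity to the query. A minimal, self-contained sketch of that ranking step follows; the corpus and keywords are made-up examples, not data from the app, and the sketch uses scikit-learn's built-in English stop-word list rather than the NLTK list used in app.py.

# Minimal sketch of TF-IDF + cosine-similarity ranking (illustrative data only).
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

papers = [
    {"title": "Graph neural networks for molecules", "abstract": "We study message passing on molecular graphs."},
    {"title": "Transformers for time series", "abstract": "Attention-based models for forecasting."},
]
keywords = ["graph", "neural", "networks"]

corpus = [p["title"] + " " + p["abstract"] for p in papers]    # one document per paper
vectorizer = TfidfVectorizer(stop_words="english")             # drop common English words
tfidf_matrix = vectorizer.fit_transform(corpus)                # one TF-IDF vector per paper
query_vec = vectorizer.transform([" ".join(keywords)])         # query in the same vector space
scores = cosine_similarity(query_vec, tfidf_matrix).flatten()  # similarity of each paper to the query

ranked = sorted(zip(papers, scores), key=lambda x: x[1], reverse=True)
print([p["title"] for p, _ in ranked])                         # best match first

Unlike the BM25 ranking it replaces, this path needs no NLTK tokenizer setup; only the "stopwords" corpus is downloaded in app.py itself.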

Files changed (1)
  1. app.py +98 -25
app.py CHANGED
@@ -49,40 +49,108 @@ from smolagents import CodeAgent, HfApiModel, tool
     # print(f"ERROR: {str(e)}")  # Debug errors
     # return [f"Error fetching research papers: {str(e)}"]
 
-from rank_bm25 import BM25Okapi
-import nltk
-
-import os
-import shutil
-
-nltk_data_path = os.path.join(nltk.data.path[0], "tokenizers", "punkt")
-if os.path.exists(nltk_data_path):
-    shutil.rmtree(nltk_data_path)  # Remove corrupted version
-    print("✅ Removed old NLTK 'punkt' data. Reinstalling...")
-
-# Step 2: Download the correct 'punkt' tokenizer
-nltk.download("punkt_tab")
-
-print("✅ Successfully installed 'punkt'!")
-
-@tool  # Register the function properly as a SmolAgents tool
+# """------Applied BM25 search for paper retrieval------"""
+# from rank_bm25 import BM25Okapi
+# import nltk
+
+# import os
+# import shutil
+
+# nltk_data_path = os.path.join(nltk.data.path[0], "tokenizers", "punkt")
+# if os.path.exists(nltk_data_path):
+#     shutil.rmtree(nltk_data_path)  # Remove corrupted version
+#     print("Removed old NLTK 'punkt' data. Reinstalling...")
+
+# # Step 2: Download the correct 'punkt' tokenizer
+# nltk.download("punkt_tab")
+
+# print("Successfully installed 'punkt'!")
+
+# @tool  # Register the function properly as a SmolAgents tool
+# def fetch_latest_arxiv_papers(keywords: list, num_results: int = 5) -> list:
+#     """Fetches and ranks arXiv papers using BM25 keyword relevance.
+
+#     Args:
+#         keywords: List of keywords for search.
+#         num_results: Number of results to return.
+
+#     Returns:
+#         List of the most relevant papers based on BM25 ranking.
+#     """
+#     try:
+#         print(f"DEBUG: Searching arXiv papers with keywords: {keywords}")
+
+#         # Use a general keyword search (without `ti:` and `abs:`)
+#         query = "+AND+".join([f"all:{kw}" for kw in keywords])
+#         query_encoded = urllib.parse.quote(query)
+#         url = f"http://export.arxiv.org/api/query?search_query={query_encoded}&start=0&max_results=50&sortBy=submittedDate&sortOrder=descending"
+
+#         print(f"DEBUG: Query URL - {url}")
+
+#         feed = feedparser.parse(url)
+#         papers = []
+
+#         # Extract papers from arXiv
+#         for entry in feed.entries:
+#             papers.append({
+#                 "title": entry.title,
+#                 "authors": ", ".join(author.name for author in entry.authors),
+#                 "year": entry.published[:4],
+#                 "abstract": entry.summary,
+#                 "link": entry.link
+#             })
+
+#         if not papers:
+#             return [{"error": "No results found. Try different keywords."}]
+
+#         # Apply BM25 ranking
+#         tokenized_corpus = [nltk.word_tokenize(paper["title"].lower() + " " + paper["abstract"].lower()) for paper in papers]
+#         bm25 = BM25Okapi(tokenized_corpus)
+
+#         tokenized_query = nltk.word_tokenize(" ".join(keywords).lower())
+#         scores = bm25.get_scores(tokenized_query)
+
+#         # Sort papers based on BM25 score
+#         ranked_papers = sorted(zip(papers, scores), key=lambda x: x[1], reverse=True)
+
+#         # Return the most relevant ones
+#         return [paper[0] for paper in ranked_papers[:num_results]]
+
+#     except Exception as e:
+#         print(f"ERROR: {str(e)}")
+#         return [{"error": f"Error fetching research papers: {str(e)}"}]
+
+
+"""------Applied TF-IDF for better semantic search------"""
+import numpy as np
+from sklearn.feature_extraction.text import TfidfVectorizer
+from sklearn.metrics.pairwise import cosine_similarity
+import gradio as gr
+from smolagents import CodeAgent, HfApiModel, tool
+import nltk
+
+nltk.download("stopwords")
+from nltk.corpus import stopwords
+
+@tool  # ✅ Register the function properly as a SmolAgents tool
 def fetch_latest_arxiv_papers(keywords: list, num_results: int = 5) -> list:
-    """Fetches and ranks arXiv papers using BM25 keyword relevance.
+    """Fetches and ranks arXiv papers using TF-IDF and cosine similarity.
 
     Args:
         keywords: List of keywords for search.
         num_results: Number of results to return.
 
     Returns:
-        List of the most relevant papers based on BM25 ranking.
+        List of the most relevant papers based on TF-IDF ranking.
     """
     try:
         print(f"DEBUG: Searching arXiv papers with keywords: {keywords}")
 
-        # Use a general keyword search (without `ti:` and `abs:`)
+        # Use a general keyword search
         query = "+AND+".join([f"all:{kw}" for kw in keywords])
         query_encoded = urllib.parse.quote(query)
         url = f"http://export.arxiv.org/api/query?search_query={query_encoded}&start=0&max_results=50&sortBy=submittedDate&sortOrder=descending"
@@ -105,17 +173,22 @@ def fetch_latest_arxiv_papers(keywords: list, num_results: int = 5) -> list:
         if not papers:
             return [{"error": "No results found. Try different keywords."}]
 
-        # Apply BM25 ranking
-        tokenized_corpus = [nltk.word_tokenize(paper["title"].lower() + " " + paper["abstract"].lower()) for paper in papers]
-        bm25 = BM25Okapi(tokenized_corpus)
-
-        tokenized_query = nltk.word_tokenize(" ".join(keywords).lower())
-        scores = bm25.get_scores(tokenized_query)
-
-        # Sort papers based on BM25 score
-        ranked_papers = sorted(zip(papers, scores), key=lambda x: x[1], reverse=True)
-
-        # Return the most relevant ones
+        # Prepare TF-IDF vectorization
+        corpus = [paper["title"] + " " + paper["abstract"] for paper in papers]
+        vectorizer = TfidfVectorizer(stop_words=stopwords.words('english'))  # Remove stopwords
+        tfidf_matrix = vectorizer.fit_transform(corpus)
+
+        # Transform the query into a TF-IDF vector
+        query_str = " ".join(keywords)
+        query_vec = vectorizer.transform([query_str])
+
+        # Compute cosine similarity
+        similarity_scores = cosine_similarity(query_vec, tfidf_matrix).flatten()
+
+        # Sort papers based on similarity score
+        ranked_papers = sorted(zip(papers, similarity_scores), key=lambda x: x[1], reverse=True)
+
+        # Return the most relevant papers
         return [paper[0] for paper in ranked_papers[:num_results]]
 
     except Exception as e:
@@ -188,11 +261,11 @@ def search_papers(user_input):
     results = fetch_latest_arxiv_papers(keywords, num_results=3)  # Fetch 3 results
    print(f"DEBUG: Results received - {results}")  # Debug function output
 
-    # Check if the API returned an error
+    # Check if the API returned an error
     if isinstance(results, list) and len(results) > 0 and "error" in results[0]:
         return results[0]["error"]  # Return the error message directly
 
-    # Format results only if valid papers exist
+    # Format results only if valid papers exist
    if isinstance(results, list) and results and isinstance(results[0], dict):
         formatted_results = "\n\n".join([
             f"---\n\n"