Update app.py
app.py CHANGED
@@ -303,7 +303,8 @@
 
 # # Launch Gradio App
 # demo.launch()
-
+
+"""------Enhanced ScholarAgent with Fixes and Features-----"""
 import feedparser
 import urllib.parse
 import yaml
@@ -314,13 +315,14 @@ from sklearn.metrics.pairwise import cosine_similarity
 import gradio as gr
 from smolagents import CodeAgent, HfApiModel, tool
 import nltk
+from transformers import pipeline
 
+# ✅ Ensure necessary NLTK data is downloaded
 nltk.download("stopwords")
 nltk.download("punkt")
 from nltk.corpus import stopwords
-from transformers import pipeline
 
-# GPT Summarization Pipeline
+# ✅ GPT Summarization Pipeline
 summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
 
 @tool
@@ -339,63 +341,100 @@ def fetch_latest_arxiv_papers(keywords: list, num_results: int = 5) -> list:
             - "year" (str): The year of publication.
             - "abstract" (str): A summary of the paper.
             - "link" (str): A URL to the full paper.
+            - "citations" (int): Number of citations (from Semantic Scholar).
+            - "summary" (str): A GPT-generated summary of the abstract.
     """
     try:
+        # ✅ Construct the query for ArXiv API
+        query = "+AND+".join([f"all:{kw}" for kw in keywords])
         query_encoded = urllib.parse.quote(query)
         url = f"http://export.arxiv.org/api/query?search_query={query_encoded}&start=0&max_results=50&sortBy=submittedDate&sortOrder=descending"
 
+        # ✅ Fetch papers from ArXiv
        feed = feedparser.parse(url)
        papers = []
+
+        # ✅ Extract papers
        for entry in feed.entries:
+            paper = {
                 "title": entry.title,
                 "authors": ", ".join(author.name for author in entry.authors),
                 "year": entry.published[:4],
                 "abstract": entry.summary,
-                "link": entry.link
-            }
+                "link": entry.link,
+            }
+            paper["citations"] = get_citation_count(paper["title"])  # ✅ Fetch citation count
+            papers.append(paper)
+
        if not papers:
            return [{"error": "No results found. Try different keywords."}]
 
+        # ✅ TF-IDF Vectorization
        corpus = [paper["title"] + " " + paper["abstract"] for paper in papers]
        vectorizer = TfidfVectorizer(stop_words=stopwords.words('english'))
        tfidf_matrix = vectorizer.fit_transform(corpus)
 
+        # ✅ Transform Query into TF-IDF Vector
        query_str = " ".join(keywords)
        query_vec = vectorizer.transform([query_str])
        similarity_scores = cosine_similarity(query_vec, tfidf_matrix).flatten()
 
+        # ✅ Sort papers based on similarity score
        ranked_papers = sorted(zip(papers, similarity_scores), key=lambda x: x[1], reverse=True)
+
+        # ✅ Assign TF-IDF scores and generate summaries
+        for paper, score in ranked_papers:
+            paper["tfidf_score"] = score
             paper["summary"] = summarizer(paper["abstract"], max_length=100, min_length=30, do_sample=False)[0]["summary_text"]
+
        return [paper for paper, _ in ranked_papers[:num_results]]
 
    except Exception as e:
        return [{"error": f"Error fetching research papers: {str(e)}"}]
 
+
 @tool
 def get_citation_count(paper_title: str) -> int:
-    """
+    """
+    Fetches citation count from Semantic Scholar API.
+
+    Args:
+        paper_title (str): Title of the research paper.
+
+    Returns:
+        int: Citation count (default 0 if not found).
+    """
     try:
+        base_url = "https://api.semanticscholar.org/graph/v1/paper/search"
+        params = {"query": paper_title, "fields": "citationCount"}
+        response = requests.get(base_url, params=params).json()
+
+        if "data" in response and response["data"]:
+            return response["data"][0].get("citationCount", 0)
+        return 0  # Default to 0 if no data found
+
+    except Exception as e:
+        print(f"ERROR fetching citation count: {e}")
         return 0
 
+
 @tool
 def rank_papers_by_citations(papers: list) -> list:
-    """
+    """
+    Ranks papers based on citation count and TF-IDF similarity.
+
+    Args:
+        papers (list): List of research papers.
+
+    Returns:
+        list: Papers sorted by citation count and TF-IDF score.
+    """
     for paper in papers:
         paper["citations"] = get_citation_count(paper["title"])
-    return sorted(papers, key=lambda x: (x["citations"], x
+    return sorted(papers, key=lambda x: (x["citations"], x.get("tfidf_score", 0)), reverse=True)
+
 
-# AI Model
+# ✅ AI Model
 model = HfApiModel(
     max_tokens=2096,
     temperature=0.5,
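
Note on the hunk above: the updated fetch_latest_arxiv_papers now builds the ArXiv search query explicitly from the keywords before URL-encoding it. A minimal standalone sketch of that query construction, using illustrative keyword values that are not part of the commit:

import urllib.parse

# Example keywords -- illustrative values only
keywords = ["deep learning", "reinforcement learning"]

# Same joining and encoding steps as the updated tool
query = "+AND+".join([f"all:{kw}" for kw in keywords])
query_encoded = urllib.parse.quote(query)
url = (
    "http://export.arxiv.org/api/query?"
    f"search_query={query_encoded}&start=0&max_results=50"
    "&sortBy=submittedDate&sortOrder=descending"
)

print(query)   # all:deep learning+AND+all:reinforcement learning
print(url)
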
@@ -403,11 +442,11 @@ model = HfApiModel(
     custom_role_conversions=None,
 )
 
-# Load prompt templates
+# ✅ Load prompt templates
 with open("prompts.yaml", 'r') as stream:
     prompt_templates = yaml.safe_load(stream)
 
-# Create the AI Agent
+# ✅ Create the AI Agent
 agent = CodeAgent(
     model=model,
     tools=[fetch_latest_arxiv_papers, get_citation_count, rank_papers_by_citations],
@@ -420,7 +459,8 @@ agent = CodeAgent(
     prompt_templates=prompt_templates
 )
 
-
+
+# ✅ Gradio UI
 with gr.Blocks() as demo:
     gr.Markdown("# ScholarAgent")
     keyword_input = gr.Textbox(label="Enter keywords or full sentences", placeholder="e.g., deep learning, reinforcement learning")
@@ -430,15 +470,17 @@ with gr.Blocks() as demo:
     def search_papers(user_input):
         keywords = [kw.strip() for kw in user_input.split(",") if kw.strip()]
         results = fetch_latest_arxiv_papers(keywords, num_results=3)
+
+        if isinstance(results, list) and results and "error" in results[0]:
             return results[0]["error"]
+
         return "\n\n".join([
             f"---\n\n"
             f"📌 **Title:** {paper['title']}\n\n"
             f"👨🔬 **Authors:** {paper['authors']}\n\n"
             f"📅 **Year:** {paper['year']}\n\n"
-            f"📖 **Summary:** {paper
-            f"🔢 **Citations:** {paper
+            f"📖 **Summary:** {paper.get('summary', 'No summary available')[:500]}... *(truncated)*\n\n"
+            f"🔢 **Citations:** {paper.get('citations', 0)}\n\n"
             f"[🔗 Read Full Paper]({paper['link']})\n\n"
             for paper in results
         ])
@@ -446,6 +488,7 @@ with gr.Blocks() as demo:
     search_button.click(search_papers, inputs=[keyword_input], outputs=[output_display])
     print("DEBUG: Gradio UI is running. Waiting for user input...")
 
-# Launch Gradio App
+# ✅ Launch Gradio App
 demo.launch()
 
+
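
For reference, a minimal standalone sketch of the Semantic Scholar lookup that the updated get_citation_count tool performs, handy for checking the endpoint outside the Space. The paper title is an example value, the timeout is an extra safeguard not in the commit, and the sketch assumes the requests package is available (the updated tool calls requests.get; the import is not shown in these hunks):

import requests

def citation_count(paper_title: str) -> int:
    # Same endpoint and fields as the updated get_citation_count tool
    base_url = "https://api.semanticscholar.org/graph/v1/paper/search"
    params = {"query": paper_title, "fields": "citationCount"}
    try:
        response = requests.get(base_url, params=params, timeout=10).json()
        if "data" in response and response["data"]:
            return response["data"][0].get("citationCount", 0)
        return 0  # default when the search returns no match
    except Exception as exc:
        print(f"ERROR fetching citation count: {exc}")
        return 0

# Example call with an illustrative title
print(citation_count("Attention Is All You Need"))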