pdx97 committed
Commit 47b1f89 · verified · 1 Parent(s): 32f1a74

Update app.py

Files changed (1):
  1. app.py +71 -28
app.py CHANGED
@@ -303,7 +303,8 @@
 
 # # Launch Gradio App
 # demo.launch()
-"""------New Features-----"""
+
+"""------Enhanced ScholarAgent with Fixes and Features-----"""
 import feedparser
 import urllib.parse
 import yaml
@@ -314,13 +315,14 @@ from sklearn.metrics.pairwise import cosine_similarity
 import gradio as gr
 from smolagents import CodeAgent, HfApiModel, tool
 import nltk
+from transformers import pipeline
 
+# ✅ Ensure necessary NLTK data is downloaded
 nltk.download("stopwords")
 nltk.download("punkt")
 from nltk.corpus import stopwords
-from transformers import pipeline
 
-# GPT Summarization Pipeline
+# GPT Summarization Pipeline
 summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
 
 @tool
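
Note on the hunk above: `pipeline("summarization", model="facebook/bart-large-cnn")` loads the BART-large CNN checkpoint at construction time and returns a list of dicts keyed by `"summary_text"`. A minimal standalone sketch of the same call (the abstract text here is made up for illustration):

```python
from transformers import pipeline

summarizer = pipeline("summarization", model="facebook/bart-large-cnn")

abstract = (
    "We study retrieval and ranking of scientific papers. Our system queries "
    "arXiv, ranks results by TF-IDF similarity to the user's keywords, and "
    "generates a short summary of each abstract for quick screening."
)
# Returns e.g. [{"summary_text": "..."}]; the app makes the same call per paper.
print(summarizer(abstract, max_length=100, min_length=30, do_sample=False)[0]["summary_text"])
```
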
@@ -339,25 +341,33 @@ def fetch_latest_arxiv_papers(keywords: list, num_results: int = 5) -> list:
         - "year" (str): The year of publication.
         - "abstract" (str): A summary of the paper.
         - "link" (str): A URL to the full paper.
+        - "citations" (int): Number of citations (from Semantic Scholar).
+        - "summary" (str): A GPT-generated summary of the abstract.
     """
     try:
-        query = "+AND+".join([f"all:{kw}" for kw in keywords])
+        # Construct the query for ArXiv API
+        query = "+AND+".join([f"all:{kw}" for kw in keywords])
         query_encoded = urllib.parse.quote(query)
         url = f"http://export.arxiv.org/api/query?search_query={query_encoded}&start=0&max_results=50&sortBy=submittedDate&sortOrder=descending"
 
+        # ✅ Fetch papers from ArXiv
         feed = feedparser.parse(url)
         papers = []
-
+
+        # ✅ Extract papers
         for entry in feed.entries:
-            papers.append({
+            paper = {
                 "title": entry.title,
                 "authors": ", ".join(author.name for author in entry.authors),
                 "year": entry.published[:4],
                 "abstract": entry.summary,
-                "link": entry.link
-            })
-
+                "link": entry.link,
+            }
+            paper["citations"] = get_citation_count(paper["title"])  # ✅ Fetch citation count
+            papers.append(paper)
+
         if not papers:
             return [{"error": "No results found. Try different keywords."}]
 
+        # ✅ TF-IDF Vectorization
         corpus = [paper["title"] + " " + paper["abstract"] for paper in papers]
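
For reference, the query built above uses the arXiv export API's fielded search syntax (`all:<term>` joined with `+AND+`). A standalone sketch of the same call; note that `urllib.parse.quote` with its default `safe` also percent-encodes `+` and `:`, so the sketch passes `safe=":+"` (an assumption, not in the diff) to keep arXiv's operators intact:

```python
import urllib.parse
import feedparser

keywords = ["transformer", "attention"]
# arXiv export API: terms use the form field:value, joined with +AND+ (spaces become +).
query = "+AND+".join(f"all:{kw}" for kw in keywords)
url = (
    "http://export.arxiv.org/api/query?search_query="
    + urllib.parse.quote(query, safe=":+")  # assumption: leave ':' and '+' unescaped
    + "&start=0&max_results=5&sortBy=submittedDate&sortOrder=descending"
)

feed = feedparser.parse(url)  # feedparser fetches and parses the Atom feed
for entry in feed.entries:
    print(entry.published[:4], "-", entry.title)
```
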
@@ -364,16 +374,20 @@ def fetch_latest_arxiv_papers(keywords: list, num_results: int = 5) -> list:
-        vectorizer = TfidfVectorizer(stop_words=stopwords.words('english'))
+        vectorizer = TfidfVectorizer(stop_words=stopwords.words('english'))
         tfidf_matrix = vectorizer.fit_transform(corpus)
 
+        # ✅ Transform Query into TF-IDF Vector
         query_str = " ".join(keywords)
         query_vec = vectorizer.transform([query_str])
         similarity_scores = cosine_similarity(query_vec, tfidf_matrix).flatten()
 
+        # ✅ Sort papers based on similarity score
         ranked_papers = sorted(zip(papers, similarity_scores), key=lambda x: x[1], reverse=True)
-
-        for paper, _ in ranked_papers:
+
+        # Assign TF-IDF scores and generate summaries
+        for paper, score in ranked_papers:
+            paper["tfidf_score"] = score
             paper["summary"] = summarizer(paper["abstract"], max_length=100, min_length=30, do_sample=False)[0]["summary_text"]
-
+
         return [paper for paper, _ in ranked_papers[:num_results]]
 
     except Exception as e:
         return [{"error": f"Error fetching research papers: {str(e)}"}]
@@ -380,10 +394,25 @@
 
+
 @tool
 def get_citation_count(paper_title: str) -> int:
-    """Fetches citation count from Semantic Scholar API."""
+    """
+    Fetches citation count from Semantic Scholar API.
+
+    Args:
+        paper_title (str): Title of the research paper.
+
+    Returns:
+        int: Citation count (default 0 if not found).
+    """
     try:
-        url = f"https://api.semanticscholar.org/v1/paper/search?query={urllib.parse.quote(paper_title)}"
-        response = requests.get(url).json()
-        return response["results"][0].get("citationCount", 0) if "results" in response else 0
-    except:
+        base_url = "https://api.semanticscholar.org/graph/v1/paper/search"
+        params = {"query": paper_title, "fields": "citationCount"}
+        response = requests.get(base_url, params=params).json()
+
+        if "data" in response and response["data"]:
+            return response["data"][0].get("citationCount", 0)
+        return 0  # Default to 0 if no data found
+
+    except Exception as e:
+        print(f"ERROR fetching citation count: {e}")
         return 0
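
The rewritten lookup targets the Semantic Scholar Graph API's `/graph/v1/paper/search` endpoint with a `fields` parameter, replacing the earlier `/v1/paper/search` URL. A standalone sketch; the `limit` and `timeout` arguments are additions for illustration, not in the diff:

```python
import requests

def citation_count(title: str) -> int:
    resp = requests.get(
        "https://api.semanticscholar.org/graph/v1/paper/search",
        params={"query": title, "fields": "citationCount", "limit": 1},
        timeout=10,  # assumption: avoid hanging the UI handler on a slow API
    )
    # Response shape: {"total": ..., "offset": ..., "data": [{"paperId": ..., "citationCount": ...}]}
    data = resp.json().get("data") or []
    return data[0].get("citationCount", 0) if data else 0

print(citation_count("Attention Is All You Need"))
```
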
@@ -390,12 +419,22 @@
 
+
 @tool
 def rank_papers_by_citations(papers: list) -> list:
-    """Ranks papers based on citation count and TF-IDF similarity."""
+    """
+    Ranks papers based on citation count and TF-IDF similarity.
+
+    Args:
+        papers (list): List of research papers.
+
+    Returns:
+        list: Papers sorted by citation count and TF-IDF score.
+    """
     for paper in papers:
         paper["citations"] = get_citation_count(paper["title"])
-    return sorted(papers, key=lambda x: (x["citations"], x["tfidf_score"]), reverse=True)
+    return sorted(papers, key=lambda x: (x["citations"], x.get("tfidf_score", 0)), reverse=True)
+
 
-# AI Model
+# AI Model
 model = HfApiModel(
     max_tokens=2096,
     temperature=0.5,
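
In `rank_papers_by_citations`, the tuple key sorts primarily by citation count and uses the TF-IDF score only to break ties; the switch to `.get("tfidf_score", 0)` guards papers that never went through the TF-IDF pass. A toy example with made-up entries:

```python
papers = [
    {"title": "A", "citations": 10, "tfidf_score": 0.20},
    {"title": "B", "citations": 10, "tfidf_score": 0.55},
    {"title": "C", "citations": 3,  "tfidf_score": 0.90},
]
# Tuples compare element-wise: citations decide first, TF-IDF breaks ties.
ranked = sorted(papers, key=lambda p: (p["citations"], p.get("tfidf_score", 0)), reverse=True)
print([p["title"] for p in ranked])  # ['B', 'A', 'C']
```
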
@@ -403,11 +442,11 @@ model = HfApiModel(
     custom_role_conversions=None,
 )
 
-# Load prompt templates
+# Load prompt templates
 with open("prompts.yaml", 'r') as stream:
     prompt_templates = yaml.safe_load(stream)
 
-# Create the AI Agent
+# Create the AI Agent
 agent = CodeAgent(
     model=model,
     tools=[fetch_latest_arxiv_papers, get_citation_count, rank_papers_by_citations],
@@ -420,7 +459,8 @@ agent = CodeAgent(
     prompt_templates=prompt_templates
 )
 
-# Gradio UI
+
+# ✅ Gradio UI
 with gr.Blocks() as demo:
     gr.Markdown("# ScholarAgent")
     keyword_input = gr.Textbox(label="Enter keywords or full sentences", placeholder="e.g., deep learning, reinforcement learning")
@@ -430,15 +470,17 @@ with gr.Blocks() as demo:
     def search_papers(user_input):
         keywords = [kw.strip() for kw in user_input.split(",") if kw.strip()]
         results = fetch_latest_arxiv_papers(keywords, num_results=3)
-        if isinstance(results, list) and len(results) > 0 and "error" in results[0]:
+
+        if isinstance(results, list) and results and "error" in results[0]:
             return results[0]["error"]
+
         return "\n\n".join([
             f"---\n\n"
             f"📌 **Title:** {paper['title']}\n\n"
             f"👨‍🔬 **Authors:** {paper['authors']}\n\n"
             f"📅 **Year:** {paper['year']}\n\n"
-            f"📖 **Summary:** {paper['summary']}\n\n"
-            f"🔢 **Citations:** {paper['citations']}\n\n"
+            f"📖 **Summary:** {paper.get('summary', 'No summary available')[:500]}... *(truncated)*\n\n"
+            f"🔢 **Citations:** {paper.get('citations', 0)}\n\n"
             f"[🔗 Read Full Paper]({paper['link']})\n\n"
             for paper in results
         ])
@@ -446,6 +488,7 @@ with gr.Blocks() as demo:
     search_button.click(search_papers, inputs=[keyword_input], outputs=[output_display])
     print("DEBUG: Gradio UI is running. Waiting for user input...")
 
-# Launch Gradio App
+# Launch Gradio App
 demo.launch()
 
+
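
Putting the UI hunks together: the Blocks layout wires `search_button` to `search_papers` and renders the returned Markdown string in `output_display`. A minimal runnable skeleton with a stub handler; the `search_button` and `output_display` definitions are assumed from unchanged lines of app.py not shown in this diff:

```python
import gradio as gr

def search_papers(user_input: str) -> str:
    keywords = [kw.strip() for kw in user_input.split(",") if kw.strip()]
    return f"Would search arXiv for: {', '.join(keywords)}"  # stub for the real tool call

with gr.Blocks() as demo:
    gr.Markdown("# ScholarAgent")
    keyword_input = gr.Textbox(label="Enter keywords or full sentences")
    search_button = gr.Button("Search")
    output_display = gr.Markdown()
    search_button.click(search_papers, inputs=[keyword_input], outputs=[output_display])

demo.launch()
```
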