yipengsun commited on
Commit
b5d35b3
·
verified ·
1 Parent(s): 663ebf5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -15
app.py CHANGED
@@ -1,6 +1,5 @@
1
  import asyncio
2
  import json
3
- import os
4
  import urllib.parse
5
  import re
6
  import xml.etree.ElementTree as ET
@@ -388,9 +387,9 @@ class CitationGenerator:
388
  return key
389
 
390
  async def process_text(self, text: str, num_queries: int, citations_per_query: int,
391
- use_arxiv: bool = True, use_crossref: bool = True) -> tuple[str, str]:
392
  if not (use_arxiv or use_crossref):
393
- return "Please select at least one source (ArXiv or CrossRef)", ""
394
 
395
  num_queries = min(max(1, num_queries), self.config.max_queries)
396
  citations_per_query = min(max(1, citations_per_query), self.config.max_citations_per_query)
@@ -449,7 +448,7 @@ class CitationGenerator:
449
  logger.error(f"Error inserting citations: {e}")
450
  return input_data["text"], ""
451
 
452
- async def agent_run(input_data: dict):
453
  queries = await generate_queries_tool(input_data)
454
  papers = await search_papers_tool({
455
  "queries": queries,
@@ -458,29 +457,29 @@ class CitationGenerator:
458
  "use_crossref": input_data["use_crossref"]
459
  })
460
  if not papers:
461
- return input_data["text"], ""
462
  cited_text, final_bibtex = await cite_text_tool({
463
  "text": input_data["text"],
464
  "papers": papers
465
  })
466
- return cited_text, final_bibtex
467
 
468
- final_text, final_bibtex = await agent_run({
469
  "text": text,
470
  "num_queries": num_queries,
471
  "citations_per_query": citations_per_query,
472
  "use_arxiv": use_arxiv,
473
  "use_crossref": use_crossref
474
  })
475
- return final_text, final_bibtex
476
 
477
  def create_gradio_interface() -> gr.Interface:
478
  async def process(api_key: str, text: str, num_queries: int, citations_per_query: int,
479
- use_arxiv: bool, use_crossref: bool) -> tuple[str, str]:
480
  if not api_key.strip():
481
- return "Please enter your Gemini API Key.", ""
482
  if not text.strip():
483
- return "Please enter text to process", ""
484
  try:
485
  config = Config(gemini_api_key=api_key)
486
  citation_gen = CitationGenerator(config)
@@ -489,9 +488,9 @@ def create_gradio_interface() -> gr.Interface:
489
  use_arxiv=use_arxiv, use_crossref=use_crossref
490
  )
491
  except ValueError as e:
492
- return f"Input validation error: {str(e)}", ""
493
  except Exception as e:
494
- return f"Error: {str(e)}", ""
495
 
496
  css = """
497
  :root {
@@ -677,11 +676,16 @@ def create_gradio_interface() -> gr.Interface:
677
  lines=10,
678
  show_copy_button=True
679
  )
 
 
 
 
 
680
 
681
  process_btn.click(
682
  fn=process,
683
  inputs=[api_key, input_text, num_queries, citations_per_query, use_arxiv, use_crossref],
684
- outputs=[cited_text, bibtex]
685
  )
686
 
687
  return demo
@@ -693,4 +697,4 @@ if __name__ == "__main__":
693
  except KeyboardInterrupt:
694
  print("\nShutting down server...")
695
  except Exception as e:
696
- print(f"Error starting server: {str(e)}")
 
1
  import asyncio
2
  import json
 
3
  import urllib.parse
4
  import re
5
  import xml.etree.ElementTree as ET
 
387
  return key
388
 
389
  async def process_text(self, text: str, num_queries: int, citations_per_query: int,
390
+ use_arxiv: bool = True, use_crossref: bool = True) -> tuple[str, str, str]:
391
  if not (use_arxiv or use_crossref):
392
+ return "Please select at least one source (ArXiv or CrossRef)", "", ""
393
 
394
  num_queries = min(max(1, num_queries), self.config.max_queries)
395
  citations_per_query = min(max(1, citations_per_query), self.config.max_citations_per_query)
 
448
  logger.error(f"Error inserting citations: {e}")
449
  return input_data["text"], ""
450
 
451
+ async def agent_run(input_data: dict) -> tuple[str, str, str]:
452
  queries = await generate_queries_tool(input_data)
453
  papers = await search_papers_tool({
454
  "queries": queries,
 
457
  "use_crossref": input_data["use_crossref"]
458
  })
459
  if not papers:
460
+ return input_data["text"], "", "\n".join([f"- {q}" for q in queries])
461
  cited_text, final_bibtex = await cite_text_tool({
462
  "text": input_data["text"],
463
  "papers": papers
464
  })
465
+ return cited_text, final_bibtex, "\n".join([f"- {q}" for q in queries])
466
 
467
+ final_text, final_bibtex, final_queries = await agent_run({
468
  "text": text,
469
  "num_queries": num_queries,
470
  "citations_per_query": citations_per_query,
471
  "use_arxiv": use_arxiv,
472
  "use_crossref": use_crossref
473
  })
474
+ return final_text, final_bibtex, final_queries
475
 
476
  def create_gradio_interface() -> gr.Interface:
477
  async def process(api_key: str, text: str, num_queries: int, citations_per_query: int,
478
+ use_arxiv: bool, use_crossref: bool) -> tuple[str, str, str]:
479
  if not api_key.strip():
480
+ return "Please enter your Gemini API Key.", "", ""
481
  if not text.strip():
482
+ return "Please enter text to process", "", ""
483
  try:
484
  config = Config(gemini_api_key=api_key)
485
  citation_gen = CitationGenerator(config)
 
488
  use_arxiv=use_arxiv, use_crossref=use_crossref
489
  )
490
  except ValueError as e:
491
+ return f"Input validation error: {str(e)}", "", ""
492
  except Exception as e:
493
+ return f"Error: {str(e)}", "", ""
494
 
495
  css = """
496
  :root {
 
676
  lines=10,
677
  show_copy_button=True
678
  )
679
+ queries_text = gr.Textbox(
680
+ label="Generated Queries",
681
+ lines=5,
682
+ show_copy_button=True
683
+ )
684
 
685
  process_btn.click(
686
  fn=process,
687
  inputs=[api_key, input_text, num_queries, citations_per_query, use_arxiv, use_crossref],
688
+ outputs=[cited_text, bibtex, queries_text]
689
  )
690
 
691
  return demo
 
697
  except KeyboardInterrupt:
698
  print("\nShutting down server...")
699
  except Exception as e:
700
+ print(f"Error starting server: {str(e)}")