Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,6 +1,5 @@
|
|
1 |
import asyncio
|
2 |
import json
|
3 |
-
import os
|
4 |
import urllib.parse
|
5 |
import re
|
6 |
import xml.etree.ElementTree as ET
|
@@ -388,9 +387,9 @@ class CitationGenerator:
|
|
388 |
return key
|
389 |
|
390 |
async def process_text(self, text: str, num_queries: int, citations_per_query: int,
|
391 |
-
use_arxiv: bool = True, use_crossref: bool = True) -> tuple[str, str]:
|
392 |
if not (use_arxiv or use_crossref):
|
393 |
-
return "Please select at least one source (ArXiv or CrossRef)", ""
|
394 |
|
395 |
num_queries = min(max(1, num_queries), self.config.max_queries)
|
396 |
citations_per_query = min(max(1, citations_per_query), self.config.max_citations_per_query)
|
@@ -449,7 +448,7 @@ class CitationGenerator:
|
|
449 |
logger.error(f"Error inserting citations: {e}")
|
450 |
return input_data["text"], ""
|
451 |
|
452 |
-
async def agent_run(input_data: dict):
|
453 |
queries = await generate_queries_tool(input_data)
|
454 |
papers = await search_papers_tool({
|
455 |
"queries": queries,
|
@@ -458,29 +457,29 @@ class CitationGenerator:
|
|
458 |
"use_crossref": input_data["use_crossref"]
|
459 |
})
|
460 |
if not papers:
|
461 |
-
return input_data["text"], ""
|
462 |
cited_text, final_bibtex = await cite_text_tool({
|
463 |
"text": input_data["text"],
|
464 |
"papers": papers
|
465 |
})
|
466 |
-
return cited_text, final_bibtex
|
467 |
|
468 |
-
final_text, final_bibtex = await agent_run({
|
469 |
"text": text,
|
470 |
"num_queries": num_queries,
|
471 |
"citations_per_query": citations_per_query,
|
472 |
"use_arxiv": use_arxiv,
|
473 |
"use_crossref": use_crossref
|
474 |
})
|
475 |
-
return final_text, final_bibtex
|
476 |
|
477 |
def create_gradio_interface() -> gr.Interface:
|
478 |
async def process(api_key: str, text: str, num_queries: int, citations_per_query: int,
|
479 |
-
use_arxiv: bool, use_crossref: bool) -> tuple[str, str]:
|
480 |
if not api_key.strip():
|
481 |
-
return "Please enter your Gemini API Key.", ""
|
482 |
if not text.strip():
|
483 |
-
return "Please enter text to process", ""
|
484 |
try:
|
485 |
config = Config(gemini_api_key=api_key)
|
486 |
citation_gen = CitationGenerator(config)
|
@@ -489,9 +488,9 @@ def create_gradio_interface() -> gr.Interface:
|
|
489 |
use_arxiv=use_arxiv, use_crossref=use_crossref
|
490 |
)
|
491 |
except ValueError as e:
|
492 |
-
return f"Input validation error: {str(e)}", ""
|
493 |
except Exception as e:
|
494 |
-
return f"Error: {str(e)}", ""
|
495 |
|
496 |
css = """
|
497 |
:root {
|
@@ -677,11 +676,16 @@ def create_gradio_interface() -> gr.Interface:
|
|
677 |
lines=10,
|
678 |
show_copy_button=True
|
679 |
)
|
|
|
|
|
|
|
|
|
|
|
680 |
|
681 |
process_btn.click(
|
682 |
fn=process,
|
683 |
inputs=[api_key, input_text, num_queries, citations_per_query, use_arxiv, use_crossref],
|
684 |
-
outputs=[cited_text, bibtex]
|
685 |
)
|
686 |
|
687 |
return demo
|
@@ -693,4 +697,4 @@ if __name__ == "__main__":
|
|
693 |
except KeyboardInterrupt:
|
694 |
print("\nShutting down server...")
|
695 |
except Exception as e:
|
696 |
-
print(f"Error starting server: {str(e)}")
|
|
|
1 |
import asyncio
|
2 |
import json
|
|
|
3 |
import urllib.parse
|
4 |
import re
|
5 |
import xml.etree.ElementTree as ET
|
|
|
387 |
return key
|
388 |
|
389 |
async def process_text(self, text: str, num_queries: int, citations_per_query: int,
|
390 |
+
use_arxiv: bool = True, use_crossref: bool = True) -> tuple[str, str, str]:
|
391 |
if not (use_arxiv or use_crossref):
|
392 |
+
return "Please select at least one source (ArXiv or CrossRef)", "", ""
|
393 |
|
394 |
num_queries = min(max(1, num_queries), self.config.max_queries)
|
395 |
citations_per_query = min(max(1, citations_per_query), self.config.max_citations_per_query)
|
|
|
448 |
logger.error(f"Error inserting citations: {e}")
|
449 |
return input_data["text"], ""
|
450 |
|
451 |
+
async def agent_run(input_data: dict) -> tuple[str, str, str]:
|
452 |
queries = await generate_queries_tool(input_data)
|
453 |
papers = await search_papers_tool({
|
454 |
"queries": queries,
|
|
|
457 |
"use_crossref": input_data["use_crossref"]
|
458 |
})
|
459 |
if not papers:
|
460 |
+
return input_data["text"], "", "\n".join([f"- {q}" for q in queries])
|
461 |
cited_text, final_bibtex = await cite_text_tool({
|
462 |
"text": input_data["text"],
|
463 |
"papers": papers
|
464 |
})
|
465 |
+
return cited_text, final_bibtex, "\n".join([f"- {q}" for q in queries])
|
466 |
|
467 |
+
final_text, final_bibtex, final_queries = await agent_run({
|
468 |
"text": text,
|
469 |
"num_queries": num_queries,
|
470 |
"citations_per_query": citations_per_query,
|
471 |
"use_arxiv": use_arxiv,
|
472 |
"use_crossref": use_crossref
|
473 |
})
|
474 |
+
return final_text, final_bibtex, final_queries
|
475 |
|
476 |
def create_gradio_interface() -> gr.Interface:
|
477 |
async def process(api_key: str, text: str, num_queries: int, citations_per_query: int,
|
478 |
+
use_arxiv: bool, use_crossref: bool) -> tuple[str, str, str]:
|
479 |
if not api_key.strip():
|
480 |
+
return "Please enter your Gemini API Key.", "", ""
|
481 |
if not text.strip():
|
482 |
+
return "Please enter text to process", "", ""
|
483 |
try:
|
484 |
config = Config(gemini_api_key=api_key)
|
485 |
citation_gen = CitationGenerator(config)
|
|
|
488 |
use_arxiv=use_arxiv, use_crossref=use_crossref
|
489 |
)
|
490 |
except ValueError as e:
|
491 |
+
return f"Input validation error: {str(e)}", "", ""
|
492 |
except Exception as e:
|
493 |
+
return f"Error: {str(e)}", "", ""
|
494 |
|
495 |
css = """
|
496 |
:root {
|
|
|
676 |
lines=10,
|
677 |
show_copy_button=True
|
678 |
)
|
679 |
+
queries_text = gr.Textbox(
|
680 |
+
label="Generated Queries",
|
681 |
+
lines=5,
|
682 |
+
show_copy_button=True
|
683 |
+
)
|
684 |
|
685 |
process_btn.click(
|
686 |
fn=process,
|
687 |
inputs=[api_key, input_text, num_queries, citations_per_query, use_arxiv, use_crossref],
|
688 |
+
outputs=[cited_text, bibtex, queries_text]
|
689 |
)
|
690 |
|
691 |
return demo
|
|
|
697 |
except KeyboardInterrupt:
|
698 |
print("\nShutting down server...")
|
699 |
except Exception as e:
|
700 |
+
print(f"Error starting server: {str(e)}")
|