yipengsun committed on
Commit
05d713a
·
verified ·
1 Parent(s): f1fefc1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +104 -55
app.py CHANGED
@@ -4,7 +4,7 @@ import urllib.parse
4
  import re
5
  import xml.etree.ElementTree as ET
6
  from dataclasses import dataclass, field
7
- from typing import Dict, List, Optional
8
  import sys
9
  from loguru import logger
10
 
@@ -12,13 +12,23 @@ import aiohttp
12
  import gradio as gr
13
 
14
  from langchain.prompts import PromptTemplate
15
-
16
  from langchain_google_genai import ChatGoogleGenerativeAI
17
 
18
  import bibtexparser
19
  from bibtexparser.bwriter import BibTexWriter
20
  from bibtexparser.bibdatabase import BibDatabase
21
 
 
 
 
 
 
 
 
 
 
 
 
22
  @dataclass
23
  class Config:
24
  gemini_api_key: str
@@ -28,18 +38,25 @@ class Config:
28
  max_citations_per_query: int = 10
29
  arxiv_base_url: str = 'http://export.arxiv.org/api/query?'
30
  crossref_base_url: str = 'https://api.crossref.org/works'
31
- default_headers: dict = field(default_factory=lambda: {
32
  'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
33
  })
34
  log_level: str = 'DEBUG'
35
 
 
36
  class ArxivXmlParser:
 
 
 
37
  NS = {
38
  'atom': 'http://www.w3.org/2005/Atom',
39
  'arxiv': 'http://arxiv.org/schemas/atom'
40
  }
41
 
42
- def parse_papers(self, data: str) -> List[Dict]:
 
 
 
43
  try:
44
  root = ET.fromstring(data)
45
  papers = []
@@ -52,12 +69,15 @@ class ArxivXmlParser:
52
  logger.error(f"Error parsing ArXiv XML: {e}")
53
  return []
54
 
55
- def parse_entry(self, entry) -> Optional[dict]:
 
 
 
56
  try:
57
  title_node = entry.find('atom:title', self.NS)
58
  if title_node is None:
59
  return None
60
- title = title_node.text.strip()
61
 
62
  authors = []
63
  for author in entry.findall('atom:author', self.NS):
@@ -66,15 +86,15 @@ class ArxivXmlParser:
66
  authors.append(self._format_author_name(author_name_node.text.strip()))
67
 
68
  arxiv_id_node = entry.find('atom:id', self.NS)
69
- if arxiv_id_node is None:
70
  return None
71
  arxiv_id = arxiv_id_node.text.split('/')[-1]
72
 
73
  published_node = entry.find('atom:published', self.NS)
74
- year = published_node.text[:4] if published_node is not None else "Unknown"
75
 
76
  abstract_node = entry.find('atom:summary', self.NS)
77
- abstract = abstract_node.text.strip() if abstract_node is not None else ""
78
 
79
  bibtex_key = f"{authors[0].split(',')[0]}{arxiv_id.replace('.', '')}" if authors else f"unknown{arxiv_id.replace('.', '')}"
80
  bibtex_entry = self._generate_bibtex_entry(bibtex_key, title, authors, arxiv_id, year)
@@ -94,12 +114,18 @@ class ArxivXmlParser:
94
 
95
  @staticmethod
96
  def _format_author_name(author: str) -> str:
 
 
 
97
  names = author.split()
98
  if len(names) > 1:
99
  return f"{names[-1]}, {' '.join(names[:-1])}"
100
  return author
101
 
102
  def _generate_bibtex_entry(self, key: str, title: str, authors: List[str], arxiv_id: str, year: str) -> str:
 
 
 
103
  db = BibDatabase()
104
  db.entries = [{
105
  'ENTRYTYPE': 'article',
@@ -109,13 +135,15 @@ class ArxivXmlParser:
109
  'journal': f'arXiv preprint arXiv:{arxiv_id}',
110
  'year': year
111
  }]
112
- writer = BibTexWriter()
113
- writer.indent = ' '
114
- writer.comma_first = False
115
  return writer.write(db).strip()
116
 
 
117
  class AsyncContextManager:
118
- async def __aenter__(self):
 
 
 
119
  self._session = aiohttp.ClientSession()
120
  return self._session
121
 
@@ -123,13 +151,17 @@ class AsyncContextManager:
123
  if self._session:
124
  await self._session.close()
125
 
 
126
  class CitationGenerator:
127
- def __init__(self, config: Config):
 
 
 
128
  self.config = config
129
  self.xml_parser = ArxivXmlParser()
130
  self.async_context = AsyncContextManager()
131
  self.llm = ChatGoogleGenerativeAI(
132
- model="gemini-2.0-flash-exp",
133
  temperature=0.3,
134
  google_api_key=config.gemini_api_key,
135
  streaming=True
@@ -165,6 +197,9 @@ class CitationGenerator:
165
  logger.add(sys.stderr, level=config.log_level)
166
 
167
  async def generate_queries(self, text: str, num_queries: int) -> List[str]:
 
 
 
168
  input_map = {
169
  "text": text,
170
  "num_queries": num_queries
@@ -186,14 +221,15 @@ class CitationGenerator:
186
  lines = [line.strip() for line in content.split('\n')
187
  if line.strip() and not line.strip().startswith(('[', ']'))]
188
  return lines[:num_queries]
189
-
190
  return ["deep learning neural networks"]
191
-
192
  except Exception as e:
193
  logger.error(f"Error generating queries: {e}")
194
  return ["deep learning neural networks"]
195
 
196
- async def search_arxiv(self, session: aiohttp.ClientSession, query: str, max_results: int) -> List[Dict]:
 
 
 
197
  try:
198
  params = {
199
  'search_query': f'all:{urllib.parse.quote(query)}',
@@ -202,8 +238,9 @@ class CitationGenerator:
202
  'sortBy': 'relevance',
203
  'sortOrder': 'descending'
204
  }
 
205
  async with session.get(
206
- self.config.arxiv_base_url + urllib.parse.urlencode(params),
207
  headers=self.config.default_headers,
208
  timeout=30
209
  ) as response:
@@ -215,20 +252,23 @@ class CitationGenerator:
215
  return []
216
 
217
  async def fix_author_name(self, author: str) -> str:
 
 
 
218
  if not re.search(r'[�]', author):
219
  return author
220
  try:
221
  prompt = f"""Fix this author name that contains corrupted characters (�):
222
 
223
- Name: {author}
224
 
225
- Requirements:
226
- 1. Return ONLY the fixed author name
227
- 2. Use proper diacritical marks for names
228
- 3. Consider common name patterns and languages
229
- 4. If unsure, use the most likely letter
230
- 5. Maintain the format: "Lastname, Firstname"
231
- """
232
  response = await self.llm.ainvoke(prompt)
233
  fixed_name = response.content.strip()
234
  return fixed_name if fixed_name else author
@@ -237,6 +277,9 @@ class CitationGenerator:
237
  return author
238
 
239
  async def format_bibtex_author_names(self, text: str) -> str:
 
 
 
240
  try:
241
  bib_database = bibtexparser.loads(text)
242
  for entry in bib_database.entries:
@@ -247,15 +290,16 @@ class CitationGenerator:
247
  fixed_author = await self.fix_author_name(author)
248
  cleaned_authors.append(fixed_author)
249
  entry['author'] = ' and '.join(cleaned_authors)
250
- writer = BibTexWriter()
251
- writer.indent = ' '
252
- writer.comma_first = False
253
  return writer.write(bib_database).strip()
254
  except Exception as e:
255
  logger.error(f"Error cleaning BibTeX special characters: {e}")
256
  return text
257
 
258
- async def search_crossref(self, session: aiohttp.ClientSession, query: str, max_results: int) -> List[Dict]:
 
 
 
259
  try:
260
  cleaned_query = query.replace("'", "").replace('"', "")
261
  if ' ' in cleaned_query:
@@ -316,7 +360,6 @@ class CitationGenerator:
316
  continue
317
 
318
  bibtex_text = await bibtex_response.text()
319
-
320
  bib_database = bibtexparser.loads(bibtex_text)
321
  if not bib_database.entries:
322
  continue
@@ -335,9 +378,7 @@ class CitationGenerator:
335
  entry['ID'] = key
336
  existing_keys.add(key)
337
 
338
- writer = BibTexWriter()
339
- writer.indent = ' '
340
- writer.comma_first = False
341
  formatted_bibtex = writer.write(bib_database).strip()
342
 
343
  papers.append({
@@ -364,7 +405,10 @@ class CitationGenerator:
364
  logger.error(f"Error searching CrossRef: {e}")
365
  return []
366
 
367
- def _generate_unique_bibtex_key(self, entry: Dict, existing_keys: set) -> str:
 
 
 
368
  entry_type = entry.get('ENTRYTYPE', '').lower()
369
  author_field = entry.get('author', '')
370
  year = entry.get('year', '')
@@ -373,10 +417,10 @@ class CitationGenerator:
373
 
374
  if entry_type == 'inbook':
375
  booktitle = entry.get('booktitle', '')
376
- title_word = re.sub(r'\W+', '', booktitle.split()[0]) if booktitle else 'untitled'
377
  else:
378
  title = entry.get('title', '')
379
- title_word = re.sub(r'\W+', '', title.split()[0]) if title else 'untitled'
380
 
381
  base_key = f"{first_author_last_name}{year}{title_word}"
382
  key = base_key
@@ -387,17 +431,20 @@ class CitationGenerator:
387
  return key
388
 
389
  async def process_text(self, text: str, num_queries: int, citations_per_query: int,
390
- use_arxiv: bool = True, use_crossref: bool = True) -> tuple[str, str, str]:
 
 
 
391
  if not (use_arxiv or use_crossref):
392
  return "Please select at least one source (ArXiv or CrossRef)", "", ""
393
 
394
  num_queries = min(max(1, num_queries), self.config.max_queries)
395
  citations_per_query = min(max(1, citations_per_query), self.config.max_citations_per_query)
396
 
397
- async def generate_queries_tool(input_data: dict):
398
  return await self.generate_queries(input_data["text"], input_data["num_queries"])
399
 
400
- async def search_papers_tool(input_data: dict):
401
  queries = input_data["queries"]
402
  papers = []
403
  async with self.async_context as session:
@@ -411,7 +458,7 @@ class CitationGenerator:
411
  for r in results:
412
  if not isinstance(r, Exception):
413
  papers.extend(r)
414
- # Deduplicate
415
  unique_papers = []
416
  seen_keys = set()
417
  for p in papers:
@@ -420,7 +467,7 @@ class CitationGenerator:
420
  unique_papers.append(p)
421
  return unique_papers
422
 
423
- async def cite_text_tool(input_data: dict):
424
  try:
425
  citation_input = {
426
  "text": input_data["text"],
@@ -430,7 +477,6 @@ class CitationGenerator:
430
  response = await self.llm.ainvoke(prompt)
431
  cited_text = response.content.strip()
432
 
433
- # Aggregate BibTeX entries
434
  bib_database = BibDatabase()
435
  for p in input_data["papers"]:
436
  if 'bibtex_entry' in p:
@@ -439,16 +485,14 @@ class CitationGenerator:
439
  bib_database.entries.append(bib_db.entries[0])
440
  else:
441
  logger.warning(f"Empty BibTeX entry for key: {p['bibtex_key']}")
442
- writer = BibTexWriter()
443
- writer.indent = ' '
444
- writer.comma_first = False
445
  bibtex_entries = writer.write(bib_database).strip()
446
  return cited_text, bibtex_entries
447
  except Exception as e:
448
  logger.error(f"Error inserting citations: {e}")
449
  return input_data["text"], ""
450
 
451
- async def agent_run(input_data: dict) -> tuple[str, str, str]:
452
  queries = await generate_queries_tool(input_data)
453
  papers = await search_papers_tool({
454
  "queries": queries,
@@ -473,9 +517,13 @@ class CitationGenerator:
473
  })
474
  return final_text, final_bibtex, final_queries
475
 
 
476
  def create_gradio_interface() -> gr.Interface:
 
 
 
477
  async def process(api_key: str, text: str, num_queries: int, citations_per_query: int,
478
- use_arxiv: bool, use_crossref: bool) -> tuple[str, str, str]:
479
  if not api_key.strip():
480
  return "Please enter your Gemini API Key.", "", ""
481
  if not text.strip():
@@ -494,14 +542,14 @@ def create_gradio_interface() -> gr.Interface:
494
 
495
  css = """
496
  :root {
497
- /* Modern, sophisticated color palette */
498
  --primary-bg: #F8F9FA;
499
  --secondary-bg: #FFFFFF;
500
- --accent-1: #4A90E2; /* Refined blue */
501
- --accent-2: #50C878; /* Emerald green */
502
- --accent-3: #F5B041; /* Warm orange */
503
- --text-primary: #2C3E50; /* Deep blue-gray */
504
- --text-secondary: #566573; /* Medium gray */
505
  --border: #E5E7E9;
506
  --shadow: rgba(0, 0, 0, 0.1);
507
  }
@@ -690,6 +738,7 @@ def create_gradio_interface() -> gr.Interface:
690
 
691
  return demo
692
 
 
693
  if __name__ == "__main__":
694
  demo = create_gradio_interface()
695
  try:
 
4
  import re
5
  import xml.etree.ElementTree as ET
6
  from dataclasses import dataclass, field
7
+ from typing import Dict, List, Optional, Any, Tuple
8
  import sys
9
  from loguru import logger
10
 
 
12
  import gradio as gr
13
 
14
  from langchain.prompts import PromptTemplate
 
15
  from langchain_google_genai import ChatGoogleGenerativeAI
16
 
17
  import bibtexparser
18
  from bibtexparser.bwriter import BibTexWriter
19
  from bibtexparser.bibdatabase import BibDatabase
20
 
21
+
22
+ def get_bibtex_writer() -> BibTexWriter:
23
+ """
24
+ Create and return a configured BibTexWriter instance.
25
+ """
26
+ writer = BibTexWriter()
27
+ writer.indent = ' '
28
+ writer.comma_first = False
29
+ return writer
30
+
31
+
32
  @dataclass
33
  class Config:
34
  gemini_api_key: str
 
38
  max_citations_per_query: int = 10
39
  arxiv_base_url: str = 'http://export.arxiv.org/api/query?'
40
  crossref_base_url: str = 'https://api.crossref.org/works'
41
+ default_headers: Dict[str, str] = field(default_factory=lambda: {
42
  'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
43
  })
44
  log_level: str = 'DEBUG'
45
 
46
+
47
  class ArxivXmlParser:
48
+ """
49
+ Class to parse ArXiv XML responses.
50
+ """
51
  NS = {
52
  'atom': 'http://www.w3.org/2005/Atom',
53
  'arxiv': 'http://arxiv.org/schemas/atom'
54
  }
55
 
56
+ def parse_papers(self, data: str) -> List[Dict[str, Any]]:
57
+ """
58
+ Parse ArXiv XML data and return a list of paper dictionaries.
59
+ """
60
  try:
61
  root = ET.fromstring(data)
62
  papers = []
 
69
  logger.error(f"Error parsing ArXiv XML: {e}")
70
  return []
71
 
72
+ def parse_entry(self, entry: ET.Element) -> Optional[Dict[str, Any]]:
73
+ """
74
+ Parse a single ArXiv entry element and return a dictionary with paper details.
75
+ """
76
  try:
77
  title_node = entry.find('atom:title', self.NS)
78
  if title_node is None:
79
  return None
80
+ title = title_node.text.strip() if title_node.text else ""
81
 
82
  authors = []
83
  for author in entry.findall('atom:author', self.NS):
 
86
  authors.append(self._format_author_name(author_name_node.text.strip()))
87
 
88
  arxiv_id_node = entry.find('atom:id', self.NS)
89
+ if arxiv_id_node is None or not arxiv_id_node.text:
90
  return None
91
  arxiv_id = arxiv_id_node.text.split('/')[-1]
92
 
93
  published_node = entry.find('atom:published', self.NS)
94
+ year = published_node.text[:4] if (published_node is not None and published_node.text) else "Unknown"
95
 
96
  abstract_node = entry.find('atom:summary', self.NS)
97
+ abstract = abstract_node.text.strip() if (abstract_node is not None and abstract_node.text) else ""
98
 
99
  bibtex_key = f"{authors[0].split(',')[0]}{arxiv_id.replace('.', '')}" if authors else f"unknown{arxiv_id.replace('.', '')}"
100
  bibtex_entry = self._generate_bibtex_entry(bibtex_key, title, authors, arxiv_id, year)
 
114
 
115
  @staticmethod
116
  def _format_author_name(author: str) -> str:
117
+ """
118
+ Format an author name as 'Lastname, Firstname'.
119
+ """
120
  names = author.split()
121
  if len(names) > 1:
122
  return f"{names[-1]}, {' '.join(names[:-1])}"
123
  return author
124
 
125
  def _generate_bibtex_entry(self, key: str, title: str, authors: List[str], arxiv_id: str, year: str) -> str:
126
+ """
127
+ Generate a BibTeX entry for a paper.
128
+ """
129
  db = BibDatabase()
130
  db.entries = [{
131
  'ENTRYTYPE': 'article',
 
135
  'journal': f'arXiv preprint arXiv:{arxiv_id}',
136
  'year': year
137
  }]
138
+ writer = get_bibtex_writer()
 
 
139
  return writer.write(db).strip()
140
 
141
+
142
  class AsyncContextManager:
143
+ """
144
+ Asynchronous context manager to handle aiohttp ClientSession.
145
+ """
146
+ async def __aenter__(self) -> aiohttp.ClientSession:
147
  self._session = aiohttp.ClientSession()
148
  return self._session
149
 
 
151
  if self._session:
152
  await self._session.close()
153
 
154
+
155
  class CitationGenerator:
156
+ """
157
+ Class that handles generating citations using AI and searching for academic papers.
158
+ """
159
+ def __init__(self, config: Config) -> None:
160
  self.config = config
161
  self.xml_parser = ArxivXmlParser()
162
  self.async_context = AsyncContextManager()
163
  self.llm = ChatGoogleGenerativeAI(
164
+ model="gemini-2.0-flash",
165
  temperature=0.3,
166
  google_api_key=config.gemini_api_key,
167
  streaming=True
 
197
  logger.add(sys.stderr, level=config.log_level)
198
 
199
  async def generate_queries(self, text: str, num_queries: int) -> List[str]:
200
+ """
201
+ Generate a list of academic search queries from the input text.
202
+ """
203
  input_map = {
204
  "text": text,
205
  "num_queries": num_queries
 
221
  lines = [line.strip() for line in content.split('\n')
222
  if line.strip() and not line.strip().startswith(('[', ']'))]
223
  return lines[:num_queries]
 
224
  return ["deep learning neural networks"]
 
225
  except Exception as e:
226
  logger.error(f"Error generating queries: {e}")
227
  return ["deep learning neural networks"]
228
 
229
+ async def search_arxiv(self, session: aiohttp.ClientSession, query: str, max_results: int) -> List[Dict[str, Any]]:
230
+ """
231
+ Search ArXiv for papers matching the query.
232
+ """
233
  try:
234
  params = {
235
  'search_query': f'all:{urllib.parse.quote(query)}',
 
238
  'sortBy': 'relevance',
239
  'sortOrder': 'descending'
240
  }
241
+ url = self.config.arxiv_base_url + urllib.parse.urlencode(params)
242
  async with session.get(
243
+ url,
244
  headers=self.config.default_headers,
245
  timeout=30
246
  ) as response:
 
252
  return []
253
 
254
  async def fix_author_name(self, author: str) -> str:
255
+ """
256
+ Correct an author name that contains corrupted characters.
257
+ """
258
  if not re.search(r'[�]', author):
259
  return author
260
  try:
261
  prompt = f"""Fix this author name that contains corrupted characters (�):
262
 
263
+ Name: {author}
264
 
265
+ Requirements:
266
+ 1. Return ONLY the fixed author name
267
+ 2. Use proper diacritical marks for names
268
+ 3. Consider common name patterns and languages
269
+ 4. If unsure, use the most likely letter
270
+ 5. Maintain the format: "Lastname, Firstname"
271
+ """
272
  response = await self.llm.ainvoke(prompt)
273
  fixed_name = response.content.strip()
274
  return fixed_name if fixed_name else author
 
277
  return author
278
 
279
  async def format_bibtex_author_names(self, text: str) -> str:
280
+ """
281
+ Clean and format author names in a BibTeX string.
282
+ """
283
  try:
284
  bib_database = bibtexparser.loads(text)
285
  for entry in bib_database.entries:
 
290
  fixed_author = await self.fix_author_name(author)
291
  cleaned_authors.append(fixed_author)
292
  entry['author'] = ' and '.join(cleaned_authors)
293
+ writer = get_bibtex_writer()
 
 
294
  return writer.write(bib_database).strip()
295
  except Exception as e:
296
  logger.error(f"Error cleaning BibTeX special characters: {e}")
297
  return text
298
 
299
+ async def search_crossref(self, session: aiohttp.ClientSession, query: str, max_results: int) -> List[Dict[str, Any]]:
300
+ """
301
+ Search CrossRef for papers matching the query.
302
+ """
303
  try:
304
  cleaned_query = query.replace("'", "").replace('"', "")
305
  if ' ' in cleaned_query:
 
360
  continue
361
 
362
  bibtex_text = await bibtex_response.text()
 
363
  bib_database = bibtexparser.loads(bibtex_text)
364
  if not bib_database.entries:
365
  continue
 
378
  entry['ID'] = key
379
  existing_keys.add(key)
380
 
381
+ writer = get_bibtex_writer()
 
 
382
  formatted_bibtex = writer.write(bib_database).strip()
383
 
384
  papers.append({
 
405
  logger.error(f"Error searching CrossRef: {e}")
406
  return []
407
 
408
+ def _generate_unique_bibtex_key(self, entry: Dict[str, Any], existing_keys: set) -> str:
409
+ """
410
+ Generate a unique BibTeX key for an entry.
411
+ """
412
  entry_type = entry.get('ENTRYTYPE', '').lower()
413
  author_field = entry.get('author', '')
414
  year = entry.get('year', '')
 
417
 
418
  if entry_type == 'inbook':
419
  booktitle = entry.get('booktitle', '')
420
+ title_word = re.sub(r'\W+', '', booktitle.split()[0]) if booktitle.split() else 'untitled'
421
  else:
422
  title = entry.get('title', '')
423
+ title_word = re.sub(r'\W+', '', title.split()[0]) if title.split() else 'untitled'
424
 
425
  base_key = f"{first_author_last_name}{year}{title_word}"
426
  key = base_key
 
431
  return key
432
 
433
  async def process_text(self, text: str, num_queries: int, citations_per_query: int,
434
+ use_arxiv: bool = True, use_crossref: bool = True) -> Tuple[str, str, str]:
435
+ """
436
+ Process the input text to generate citations and corresponding BibTeX entries.
437
+ """
438
  if not (use_arxiv or use_crossref):
439
  return "Please select at least one source (ArXiv or CrossRef)", "", ""
440
 
441
  num_queries = min(max(1, num_queries), self.config.max_queries)
442
  citations_per_query = min(max(1, citations_per_query), self.config.max_citations_per_query)
443
 
444
+ async def generate_queries_tool(input_data: Dict[str, Any]) -> List[str]:
445
  return await self.generate_queries(input_data["text"], input_data["num_queries"])
446
 
447
+ async def search_papers_tool(input_data: Dict[str, Any]) -> List[Dict[str, Any]]:
448
  queries = input_data["queries"]
449
  papers = []
450
  async with self.async_context as session:
 
458
  for r in results:
459
  if not isinstance(r, Exception):
460
  papers.extend(r)
461
+ # Remove duplicate papers
462
  unique_papers = []
463
  seen_keys = set()
464
  for p in papers:
 
467
  unique_papers.append(p)
468
  return unique_papers
469
 
470
+ async def cite_text_tool(input_data: Dict[str, Any]) -> Tuple[str, str]:
471
  try:
472
  citation_input = {
473
  "text": input_data["text"],
 
477
  response = await self.llm.ainvoke(prompt)
478
  cited_text = response.content.strip()
479
 
 
480
  bib_database = BibDatabase()
481
  for p in input_data["papers"]:
482
  if 'bibtex_entry' in p:
 
485
  bib_database.entries.append(bib_db.entries[0])
486
  else:
487
  logger.warning(f"Empty BibTeX entry for key: {p['bibtex_key']}")
488
+ writer = get_bibtex_writer()
 
 
489
  bibtex_entries = writer.write(bib_database).strip()
490
  return cited_text, bibtex_entries
491
  except Exception as e:
492
  logger.error(f"Error inserting citations: {e}")
493
  return input_data["text"], ""
494
 
495
+ async def agent_run(input_data: Dict[str, Any]) -> Tuple[str, str, str]:
496
  queries = await generate_queries_tool(input_data)
497
  papers = await search_papers_tool({
498
  "queries": queries,
 
517
  })
518
  return final_text, final_bibtex, final_queries
519
 
520
+
521
  def create_gradio_interface() -> gr.Interface:
522
+ """
523
+ Create and return a Gradio interface for the citation generator.
524
+ """
525
  async def process(api_key: str, text: str, num_queries: int, citations_per_query: int,
526
+ use_arxiv: bool, use_crossref: bool) -> Tuple[str, str, str]:
527
  if not api_key.strip():
528
  return "Please enter your Gemini API Key.", "", ""
529
  if not text.strip():
 
542
 
543
  css = """
544
  :root {
545
+ /* Modern color palette */
546
  --primary-bg: #F8F9FA;
547
  --secondary-bg: #FFFFFF;
548
+ --accent-1: #4A90E2;
549
+ --accent-2: #50C878;
550
+ --accent-3: #F5B041;
551
+ --text-primary: #2C3E50;
552
+ --text-secondary: #566573;
553
  --border: #E5E7E9;
554
  --shadow: rgba(0, 0, 0, 0.1);
555
  }
 
738
 
739
  return demo
740
 
741
+
742
  if __name__ == "__main__":
743
  demo = create_gradio_interface()
744
  try: