yipengsun commited on
Commit
5e63073
·
verified ·
1 Parent(s): 6252af5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +116 -125
app.py CHANGED
@@ -11,9 +11,9 @@ from loguru import logger
11
 
12
  import aiohttp
13
  import gradio as gr
14
- from langchain_core.output_parsers import StrOutputParser
15
- from langchain_core.prompts import PromptTemplate
16
- from langchain_core.runnables import RunnablePassthrough
17
  from langchain_google_genai import ChatGoogleGenerativeAI
18
 
19
  import bibtexparser
@@ -32,7 +32,7 @@ class Config:
32
  default_headers: dict = field(default_factory=lambda: {
33
  'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
34
  })
35
- log_level: str = 'DEBUG' # Add log level configuration
36
 
37
  class ArxivXmlParser:
38
  NS = {
@@ -111,8 +111,8 @@ class ArxivXmlParser:
111
  'year': year
112
  }]
113
  writer = BibTexWriter()
114
- writer.indent = ' ' # Indentation for entries
115
- writer.comma_first = False # Place the comma at the end of lines
116
  return writer.write(db).strip()
117
 
118
  class AsyncContextManager:
@@ -135,13 +135,7 @@ class CitationGenerator:
135
  google_api_key=config.gemini_api_key,
136
  streaming=True
137
  )
138
- self.citation_chain = self._create_citation_chain()
139
- self.generate_queries_chain = self._create_generate_queries_chain()
140
- logger.remove()
141
- logger.add(sys.stderr, level=config.log_level) # Configure logger
142
-
143
- def _create_citation_chain(self):
144
- citation_prompt = PromptTemplate.from_template(
145
  """Insert citations into the provided text using LaTeX \\cite{{key}} commands.
146
 
147
  You must not alter the original wording or structure of the text beyond adding citations.
@@ -154,15 +148,8 @@ class CitationGenerator:
154
  {papers}
155
  """
156
  )
157
- return (
158
- {"text": RunnablePassthrough(), "papers": RunnablePassthrough()}
159
- | citation_prompt
160
- | self.llm
161
- | StrOutputParser()
162
- )
163
-
164
- def _create_generate_queries_chain(self):
165
- generate_queries_prompt = PromptTemplate.from_template(
166
  """Generate {num_queries} diverse academic search queries based on the given text.
167
  The queries should be concise and relevant.
168
 
@@ -174,20 +161,18 @@ class CitationGenerator:
174
  Text: {text}
175
  """
176
  )
177
- return (
178
- {"text": RunnablePassthrough(), "num_queries": RunnablePassthrough()}
179
- | generate_queries_prompt
180
- | self.llm
181
- | StrOutputParser()
182
- )
183
 
184
  async def generate_queries(self, text: str, num_queries: int) -> List[str]:
 
 
 
 
185
  try:
186
- response = await self.generate_queries_chain.ainvoke({
187
- "text": text,
188
- "num_queries": num_queries
189
- })
190
-
191
  content = response.strip()
192
  if not content.startswith('['):
193
  start = content.find('[')
@@ -206,7 +191,7 @@ class CitationGenerator:
206
  return ["deep learning neural networks"]
207
 
208
  except Exception as e:
209
- logger.error(f"Error generating queries: {e}") # Replace print with logger
210
  return ["deep learning neural networks"]
211
 
212
  async def search_arxiv(self, session: aiohttp.ClientSession, query: str, max_results: int) -> List[Dict]:
@@ -218,7 +203,6 @@ class CitationGenerator:
218
  'sortBy': 'relevance',
219
  'sortOrder': 'descending'
220
  }
221
-
222
  async with session.get(
223
  self.config.arxiv_base_url + urllib.parse.urlencode(params),
224
  headers=self.config.default_headers,
@@ -234,7 +218,6 @@ class CitationGenerator:
234
  async def fix_author_name(self, author: str) -> str:
235
  if not re.search(r'[�]', author):
236
  return author
237
-
238
  try:
239
  prompt = f"""Fix this author name that contains corrupted characters (�):
240
 
@@ -244,18 +227,12 @@ class CitationGenerator:
244
  1. Return ONLY the fixed author name
245
  2. Use proper diacritical marks for names
246
  3. Consider common name patterns and languages
247
- 4. If unsure about a character, use the most likely letter
248
  5. Maintain the format: "Lastname, Firstname"
249
-
250
- Example fixes:
251
- - "Gonz�lez" -> "González"
252
- - "Cristi�n" -> "Cristi��n"
253
  """
254
-
255
- response = await self.llm.ainvoke([{"role": "user", "content": prompt}])
256
- fixed_name = response.content.strip()
257
  return fixed_name if fixed_name else author
258
-
259
  except Exception as e:
260
  logger.error(f"Error fixing author name: {e}")
261
  return author
@@ -276,7 +253,7 @@ class CitationGenerator:
276
  writer.comma_first = False
277
  return writer.write(bib_database).strip()
278
  except Exception as e:
279
- logger.error(f"Error cleaning BibTeX special characters: {e}")
280
  return text
281
 
282
  async def search_crossref(self, session: aiohttp.ClientSession, query: str, max_results: int) -> List[Dict]:
@@ -341,31 +318,24 @@ class CitationGenerator:
341
 
342
  bibtex_text = await bibtex_response.text()
343
 
344
- # Parse the BibTeX entry
345
  bib_database = bibtexparser.loads(bibtex_text)
346
  if not bib_database.entries:
347
  continue
348
  entry = bib_database.entries[0]
349
 
350
- # Check if 'title' or 'booktitle' is present
351
  if 'title' not in entry and 'booktitle' not in entry:
352
- continue # Skip entries without 'title' or 'booktitle'
353
-
354
- # Check if 'author' is present
355
  if 'author' not in entry:
356
- continue # Skip entries without 'author'
357
 
358
- # Extract necessary fields
359
  title = entry.get('title', 'No Title').replace('{', '').replace('}', '')
360
  authors = entry.get('author', 'Unknown').replace('\n', ' ').replace('\t', ' ').strip()
361
  year = entry.get('year', 'Unknown')
362
 
363
- # Generate a unique BibTeX key
364
  key = self._generate_unique_bibtex_key(entry, existing_keys)
365
  entry['ID'] = key
366
  existing_keys.add(key)
367
 
368
- # Use BibTexWriter to format the entry
369
  writer = BibTexWriter()
370
  writer.indent = ' '
371
  writer.comma_first = False
@@ -378,10 +348,9 @@ class CitationGenerator:
378
  'bibtex_key': key,
379
  'bibtex_entry': formatted_bibtex
380
  })
381
- except Exception as e:
382
- logger.error(f"Error processing CrossRef item: {e}") # Replace print with logger
383
- continue
384
 
 
 
385
  return papers
386
 
387
  except aiohttp.ClientError as e:
@@ -393,7 +362,7 @@ class CitationGenerator:
393
  await asyncio.sleep(delay)
394
 
395
  except Exception as e:
396
- logger.error(f"Error searching CrossRef: {e}") # Replace print with logger
397
  return []
398
 
399
  def _generate_unique_bibtex_key(self, entry: Dict, existing_keys: set) -> str:
@@ -402,18 +371,15 @@ class CitationGenerator:
402
  year = entry.get('year', '')
403
  authors = [a.strip() for a in author_field.split(' and ')]
404
  first_author_last_name = authors[0].split(',')[0] if authors else 'unknown'
405
-
406
  if entry_type == 'inbook':
407
- # Use 'booktitle' for 'inbook' entries
408
  booktitle = entry.get('booktitle', '')
409
  title_word = re.sub(r'\W+', '', booktitle.split()[0]) if booktitle else 'untitled'
410
  else:
411
- # Use regular 'title' for other entries
412
  title = entry.get('title', '')
413
  title_word = re.sub(r'\W+', '', title.split()[0]) if title else 'untitled'
414
-
415
  base_key = f"{first_author_last_name}{year}{title_word}"
416
- # Ensure the key is unique
417
  key = base_key
418
  index = 1
419
  while key in existing_keys:
@@ -422,70 +388,93 @@ class CitationGenerator:
422
  return key
423
 
424
  async def process_text(self, text: str, num_queries: int, citations_per_query: int,
425
- use_arxiv: bool = True, use_crossref: bool = True) -> tuple[str, str]:
426
  if not (use_arxiv or use_crossref):
427
  return "Please select at least one source (ArXiv or CrossRef)", ""
428
 
429
  num_queries = min(max(1, num_queries), self.config.max_queries)
430
  citations_per_query = min(max(1, citations_per_query), self.config.max_citations_per_query)
431
 
432
- queries = await self.generate_queries(text, num_queries)
433
- if not queries:
434
- return text, ""
435
-
436
- async with self.async_context as session:
437
- search_tasks = []
438
- for query in queries:
439
- if use_arxiv:
440
- search_tasks.append(self.search_arxiv(session, query, citations_per_query))
441
- if use_crossref:
442
- search_tasks.append(self.search_crossref(session, query, citations_per_query))
443
 
444
- results = await asyncio.gather(*search_tasks, return_exceptions=True)
445
-
446
- papers = []
447
- for r in results:
448
- if not isinstance(r, Exception):
449
- papers.extend(r)
450
-
451
- unique_papers = []
452
- seen_keys = set()
453
- for p in papers:
454
- if p['bibtex_key'] not in seen_keys:
455
- seen_keys.add(p['bibtex_key'])
456
- unique_papers.append(p)
457
- papers = unique_papers
458
-
459
- if not papers:
460
- return text, ""
 
 
 
 
 
461
 
462
- try:
463
- cited_text = await self.citation_chain.ainvoke({
464
- "text": text,
465
- "papers": json.dumps(papers, indent=2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
466
  })
 
 
 
 
 
 
 
467
 
468
- # Use bibtexparser to aggregate BibTeX entries
469
- bib_database = BibDatabase()
470
- for p in papers:
471
- if 'bibtex_entry' in p:
472
- bib_db = bibtexparser.loads(p['bibtex_entry'])
473
- if bib_db.entries:
474
- bib_database.entries.append(bib_db.entries[0])
475
- else:
476
- logger.warning(f"Empty BibTeX entry for key: {p['bibtex_key']}")
477
- writer = BibTexWriter()
478
- writer.indent = ' '
479
- writer.comma_first = False
480
- bibtex_entries = writer.write(bib_database).strip()
481
-
482
- return cited_text, bibtex_entries
483
- except Exception as e:
484
- logger.error(f"Error inserting citations: {e}") # Replace print with logger
485
- return text, ""
486
 
487
  def create_gradio_interface() -> gr.Interface:
488
- # Removed CitationGenerator initialization here
489
  async def process(api_key: str, text: str, num_queries: int, citations_per_query: int,
490
  use_arxiv: bool, use_crossref: bool) -> tuple[str, str]:
491
  if not api_key.strip():
@@ -600,12 +589,11 @@ def create_gradio_interface() -> gr.Interface:
600
 
601
  with gr.Blocks(css=css, theme=gr.themes.Default()) as demo:
602
  gr.HTML("""<div class="header">
603
- <h1>📚 AutoCitation</h1>
604
- <p>Insert citations into your academic text</p>
605
- </div>""")
606
 
607
  with gr.Group(elem_classes="input-group"):
608
- # Added API Key input field
609
  api_key = gr.Textbox(
610
  label="Gemini API Key",
611
  placeholder="Enter your Gemini API key...",
@@ -623,7 +611,7 @@ def create_gradio_interface() -> gr.Interface:
623
  label="Search Queries",
624
  value=3,
625
  minimum=1,
626
- maximum=5, # Changed to config.max_queries as 5
627
  step=1
628
  )
629
  with gr.Column(scale=1):
@@ -631,7 +619,7 @@ def create_gradio_interface() -> gr.Interface:
631
  label="Citations per Query",
632
  value=1,
633
  minimum=1,
634
- maximum=10, # Changed to config.max_citations_per_query as 10
635
  step=1
636
  )
637
 
@@ -669,7 +657,6 @@ def create_gradio_interface() -> gr.Interface:
669
  show_copy_button=True
670
  )
671
 
672
- # Updated the inputs and outputs
673
  process_btn.click(
674
  fn=process,
675
  inputs=[api_key, input_text, num_queries, citations_per_query, use_arxiv, use_crossref],
@@ -679,6 +666,10 @@ def create_gradio_interface() -> gr.Interface:
679
  return demo
680
 
681
  if __name__ == "__main__":
682
-
683
  demo = create_gradio_interface()
684
- demo.launch(server_port=7860, share=False)
 
 
 
 
 
 
11
 
12
  import aiohttp
13
  import gradio as gr
14
+
15
+ from langchain.prompts import PromptTemplate
16
+
17
  from langchain_google_genai import ChatGoogleGenerativeAI
18
 
19
  import bibtexparser
 
32
  default_headers: dict = field(default_factory=lambda: {
33
  'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
34
  })
35
+ log_level: str = 'DEBUG'
36
 
37
  class ArxivXmlParser:
38
  NS = {
 
111
  'year': year
112
  }]
113
  writer = BibTexWriter()
114
+ writer.indent = ' '
115
+ writer.comma_first = False
116
  return writer.write(db).strip()
117
 
118
  class AsyncContextManager:
 
135
  google_api_key=config.gemini_api_key,
136
  streaming=True
137
  )
138
+ self.citation_prompt = PromptTemplate.from_template(
 
 
 
 
 
 
139
  """Insert citations into the provided text using LaTeX \\cite{{key}} commands.
140
 
141
  You must not alter the original wording or structure of the text beyond adding citations.
 
148
  {papers}
149
  """
150
  )
151
+
152
+ self.generate_queries_prompt = PromptTemplate.from_template(
 
 
 
 
 
 
 
153
  """Generate {num_queries} diverse academic search queries based on the given text.
154
  The queries should be concise and relevant.
155
 
 
161
  Text: {text}
162
  """
163
  )
164
+
165
+ logger.remove()
166
+ logger.add(sys.stderr, level=config.log_level)
 
 
 
167
 
168
  async def generate_queries(self, text: str, num_queries: int) -> List[str]:
169
+ input_map = {
170
+ "text": text,
171
+ "num_queries": num_queries
172
+ }
173
  try:
174
+ prompt = self.generate_queries_prompt.format(**input_map)
175
+ response = await self.llm.apredict(prompt)
 
 
 
176
  content = response.strip()
177
  if not content.startswith('['):
178
  start = content.find('[')
 
191
  return ["deep learning neural networks"]
192
 
193
  except Exception as e:
194
+ logger.error(f"Error generating queries: {e}")
195
  return ["deep learning neural networks"]
196
 
197
  async def search_arxiv(self, session: aiohttp.ClientSession, query: str, max_results: int) -> List[Dict]:
 
203
  'sortBy': 'relevance',
204
  'sortOrder': 'descending'
205
  }
 
206
  async with session.get(
207
  self.config.arxiv_base_url + urllib.parse.urlencode(params),
208
  headers=self.config.default_headers,
 
218
  async def fix_author_name(self, author: str) -> str:
219
  if not re.search(r'[�]', author):
220
  return author
 
221
  try:
222
  prompt = f"""Fix this author name that contains corrupted characters (�):
223
 
 
227
  1. Return ONLY the fixed author name
228
  2. Use proper diacritical marks for names
229
  3. Consider common name patterns and languages
230
+ 4. If unsure, use the most likely letter
231
  5. Maintain the format: "Lastname, Firstname"
 
 
 
 
232
  """
233
+ response = await self.llm.apredict(prompt)
234
+ fixed_name = response.strip()
 
235
  return fixed_name if fixed_name else author
 
236
  except Exception as e:
237
  logger.error(f"Error fixing author name: {e}")
238
  return author
 
253
  writer.comma_first = False
254
  return writer.write(bib_database).strip()
255
  except Exception as e:
256
+ logger.error(f"Error cleaning BibTeX special characters: {e}")
257
  return text
258
 
259
  async def search_crossref(self, session: aiohttp.ClientSession, query: str, max_results: int) -> List[Dict]:
 
318
 
319
  bibtex_text = await bibtex_response.text()
320
 
 
321
  bib_database = bibtexparser.loads(bibtex_text)
322
  if not bib_database.entries:
323
  continue
324
  entry = bib_database.entries[0]
325
 
 
326
  if 'title' not in entry and 'booktitle' not in entry:
327
+ continue
 
 
328
  if 'author' not in entry:
329
+ continue
330
 
 
331
  title = entry.get('title', 'No Title').replace('{', '').replace('}', '')
332
  authors = entry.get('author', 'Unknown').replace('\n', ' ').replace('\t', ' ').strip()
333
  year = entry.get('year', 'Unknown')
334
 
 
335
  key = self._generate_unique_bibtex_key(entry, existing_keys)
336
  entry['ID'] = key
337
  existing_keys.add(key)
338
 
 
339
  writer = BibTexWriter()
340
  writer.indent = ' '
341
  writer.comma_first = False
 
348
  'bibtex_key': key,
349
  'bibtex_entry': formatted_bibtex
350
  })
 
 
 
351
 
352
+ except Exception as e:
353
+ logger.error(f"Error processing CrossRef item: {e}")
354
  return papers
355
 
356
  except aiohttp.ClientError as e:
 
362
  await asyncio.sleep(delay)
363
 
364
  except Exception as e:
365
+ logger.error(f"Error searching CrossRef: {e}")
366
  return []
367
 
368
  def _generate_unique_bibtex_key(self, entry: Dict, existing_keys: set) -> str:
 
371
  year = entry.get('year', '')
372
  authors = [a.strip() for a in author_field.split(' and ')]
373
  first_author_last_name = authors[0].split(',')[0] if authors else 'unknown'
374
+
375
  if entry_type == 'inbook':
 
376
  booktitle = entry.get('booktitle', '')
377
  title_word = re.sub(r'\W+', '', booktitle.split()[0]) if booktitle else 'untitled'
378
  else:
 
379
  title = entry.get('title', '')
380
  title_word = re.sub(r'\W+', '', title.split()[0]) if title else 'untitled'
381
+
382
  base_key = f"{first_author_last_name}{year}{title_word}"
 
383
  key = base_key
384
  index = 1
385
  while key in existing_keys:
 
388
  return key
389
 
390
  async def process_text(self, text: str, num_queries: int, citations_per_query: int,
391
+ use_arxiv: bool = True, use_crossref: bool = True) -> tuple[str, str]:
392
  if not (use_arxiv or use_crossref):
393
  return "Please select at least one source (ArXiv or CrossRef)", ""
394
 
395
  num_queries = min(max(1, num_queries), self.config.max_queries)
396
  citations_per_query = min(max(1, citations_per_query), self.config.max_citations_per_query)
397
 
398
+ async def generate_queries_tool(input_data: dict):
399
+ return await self.generate_queries(input_data["text"], input_data["num_queries"])
 
 
 
 
 
 
 
 
 
400
 
401
+ async def search_papers_tool(input_data: dict):
402
+ queries = input_data["queries"]
403
+ papers = []
404
+ async with self.async_context as session:
405
+ search_tasks = []
406
+ for q in queries:
407
+ if input_data["use_arxiv"]:
408
+ search_tasks.append(self.search_arxiv(session, q, input_data["citations_per_query"]))
409
+ if input_data["use_crossref"]:
410
+ search_tasks.append(self.search_crossref(session, q, input_data["citations_per_query"]))
411
+ results = await asyncio.gather(*search_tasks, return_exceptions=True)
412
+ for r in results:
413
+ if not isinstance(r, Exception):
414
+ papers.extend(r)
415
+ # Deduplicate
416
+ unique_papers = []
417
+ seen_keys = set()
418
+ for p in papers:
419
+ if p['bibtex_key'] not in seen_keys:
420
+ seen_keys.add(p['bibtex_key'])
421
+ unique_papers.append(p)
422
+ return unique_papers
423
 
424
+ async def cite_text_tool(input_data: dict):
425
+ try:
426
+ citation_input = {
427
+ "text": input_data["text"],
428
+ "papers": json.dumps(input_data["papers"], indent=2)
429
+ }
430
+ prompt = self.citation_prompt.format(**citation_input)
431
+ response = await self.llm.apredict(prompt)
432
+ cited_text = response.strip()
433
+
434
+ # Aggregate BibTeX entries
435
+ bib_database = BibDatabase()
436
+ for p in input_data["papers"]:
437
+ if 'bibtex_entry' in p:
438
+ bib_db = bibtexparser.loads(p['bibtex_entry'])
439
+ if bib_db.entries:
440
+ bib_database.entries.append(bib_db.entries[0])
441
+ else:
442
+ logger.warning(f"Empty BibTeX entry for key: {p['bibtex_key']}")
443
+ writer = BibTexWriter()
444
+ writer.indent = ' '
445
+ writer.comma_first = False
446
+ bibtex_entries = writer.write(bib_database).strip()
447
+ return cited_text, bibtex_entries
448
+ except Exception as e:
449
+ logger.error(f"Error inserting citations: {e}")
450
+ return input_data["text"], ""
451
+
452
+ async def agent_run(input_data: dict):
453
+ queries = await generate_queries_tool(input_data)
454
+ papers = await search_papers_tool({
455
+ "queries": queries,
456
+ "citations_per_query": input_data["citations_per_query"],
457
+ "use_arxiv": input_data["use_arxiv"],
458
+ "use_crossref": input_data["use_crossref"]
459
  })
460
+ if not papers:
461
+ return input_data["text"], ""
462
+ cited_text, final_bibtex = await cite_text_tool({
463
+ "text": input_data["text"],
464
+ "papers": papers
465
+ })
466
+ return cited_text, final_bibtex
467
 
468
+ final_text, final_bibtex = await agent_run({
469
+ "text": text,
470
+ "num_queries": num_queries,
471
+ "citations_per_query": citations_per_query,
472
+ "use_arxiv": use_arxiv,
473
+ "use_crossref": use_crossref
474
+ })
475
+ return final_text, final_bibtex
 
 
 
 
 
 
 
 
 
 
476
 
477
  def create_gradio_interface() -> gr.Interface:
 
478
  async def process(api_key: str, text: str, num_queries: int, citations_per_query: int,
479
  use_arxiv: bool, use_crossref: bool) -> tuple[str, str]:
480
  if not api_key.strip():
 
589
 
590
  with gr.Blocks(css=css, theme=gr.themes.Default()) as demo:
591
  gr.HTML("""<div class="header">
592
+ <h1>📚 AutoCitation</h1>
593
+ <p>Insert citations into your academic text</p>
594
+ </div>""")
595
 
596
  with gr.Group(elem_classes="input-group"):
 
597
  api_key = gr.Textbox(
598
  label="Gemini API Key",
599
  placeholder="Enter your Gemini API key...",
 
611
  label="Search Queries",
612
  value=3,
613
  minimum=1,
614
+ maximum=Config.max_queries,
615
  step=1
616
  )
617
  with gr.Column(scale=1):
 
619
  label="Citations per Query",
620
  value=1,
621
  minimum=1,
622
+ maximum=Config.max_citations_per_query,
623
  step=1
624
  )
625
 
 
657
  show_copy_button=True
658
  )
659
 
 
660
  process_btn.click(
661
  fn=process,
662
  inputs=[api_key, input_text, num_queries, citations_per_query, use_arxiv, use_crossref],
 
666
  return demo
667
 
668
  if __name__ == "__main__":
 
669
  demo = create_gradio_interface()
670
+ try:
671
+ demo.launch(server_port=7860, share=False)
672
+ except KeyboardInterrupt:
673
+ print("\nShutting down server...")
674
+ except Exception as e:
675
+ print(f"Error starting server: {str(e)}")