yipengsun committed on
Commit
c8a8531
·
verified ·
1 Parent(s): 9b14c03

Update autocitation.py

Browse files
Files changed (1) hide show
  1. autocitation.py +684 -689
autocitation.py CHANGED
@@ -1,689 +1,684 @@
1
- import asyncio
2
- import json
3
- import os
4
- import urllib.parse
5
- import re
6
- import xml.etree.ElementTree as ET
7
- from dataclasses import dataclass, field
8
- from typing import Dict, List, Optional
9
- import sys
10
- from loguru import logger
11
-
12
- import aiohttp
13
- import gradio as gr
14
- from langchain_core.output_parsers import StrOutputParser
15
- from langchain_core.prompts import PromptTemplate
16
- from langchain_core.runnables import RunnablePassthrough
17
- from langchain_google_genai import ChatGoogleGenerativeAI
18
-
19
- import bibtexparser
20
- from bibtexparser.bwriter import BibTexWriter
21
- from bibtexparser.bibdatabase import BibDatabase
22
-
23
@dataclass
class Config:
    """Runtime configuration for the citation generator."""

    gemini_api_key: str  # user-supplied Gemini API key (required)
    max_retries: int = 3  # CrossRef retry attempts on transient errors
    base_delay: int = 1  # base seconds for exponential backoff (delay = base * 2**attempt)
    max_queries: int = 5  # upper bound on LLM-generated search queries
    max_citations_per_query: int = 10  # upper bound on results fetched per query
    arxiv_base_url: str = 'http://export.arxiv.org/api/query?'
    crossref_base_url: str = 'https://api.crossref.org/works'
    # default_factory avoids sharing one mutable dict between Config instances
    default_headers: dict = field(default_factory=lambda: {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
    })
    log_level: str = 'DEBUG'  # loguru level installed by CitationGenerator.__init__
36
-
37
class ArxivXmlParser:
    """Parses ArXiv Atom XML feeds into paper dicts with ready-made BibTeX entries."""

    # XML namespaces used by the ArXiv Atom API responses.
    NS = {
        'atom': 'http://www.w3.org/2005/Atom',
        'arxiv': 'http://arxiv.org/schemas/atom'
    }

    def parse_papers(self, data: str) -> List[Dict]:
        """Parse a whole Atom feed string; return paper dicts (empty list on error)."""
        try:
            root = ET.fromstring(data)
            papers = []
            for entry in root.findall('atom:entry', self.NS):
                paper = self.parse_entry(entry)
                if paper:
                    papers.append(paper)
            return papers
        except Exception as e:
            # Consistency fix: every other error path in this file logs via
            # loguru's logger; this class was still using print().
            logger.error(f"Error parsing ArXiv XML: {e}")
            return []

    def parse_entry(self, entry) -> Optional[dict]:
        """Convert one Atom <entry> element to a paper dict, or None if malformed."""
        try:
            title_node = entry.find('atom:title', self.NS)
            if title_node is None:
                return None
            title = title_node.text.strip()

            authors = []
            for author in entry.findall('atom:author', self.NS):
                author_name_node = author.find('atom:name', self.NS)
                if author_name_node is not None and author_name_node.text:
                    authors.append(self._format_author_name(author_name_node.text.strip()))

            arxiv_id_node = entry.find('atom:id', self.NS)
            if arxiv_id_node is None:
                return None
            # The Atom id is a URL; the arXiv id is its last path component.
            arxiv_id = arxiv_id_node.text.split('/')[-1]

            published_node = entry.find('atom:published', self.NS)
            year = published_node.text[:4] if published_node is not None else "Unknown"

            abstract_node = entry.find('atom:summary', self.NS)
            abstract = abstract_node.text.strip() if abstract_node is not None else ""

            # Citation key = first author's last name + arXiv id without the dot.
            bibtex_key = f"{authors[0].split(',')[0]}{arxiv_id.replace('.', '')}" if authors else f"unknown{arxiv_id.replace('.', '')}"
            bibtex_entry = self._generate_bibtex_entry(bibtex_key, title, authors, arxiv_id, year)

            return {
                'title': title,
                'authors': authors,
                'arxiv_id': arxiv_id,
                'published': year,
                'abstract': abstract,
                'bibtex_key': bibtex_key,
                'bibtex_entry': bibtex_entry
            }
        except Exception as e:
            # Consistency fix: log instead of print (see parse_papers).
            logger.error(f"Error parsing ArXiv entry: {e}")
            return None

    @staticmethod
    def _format_author_name(author: str) -> str:
        """Reorder 'First [Middle] Last' into BibTeX-style 'Last, First [Middle]'."""
        names = author.split()
        if len(names) > 1:
            return f"{names[-1]}, {' '.join(names[:-1])}"
        return author

    def _generate_bibtex_entry(self, key: str, title: str, authors: List[str], arxiv_id: str, year: str) -> str:
        """Render a single @article BibTeX entry for an arXiv preprint."""
        db = BibDatabase()
        db.entries = [{
            'ENTRYTYPE': 'article',
            'ID': key,
            'title': title,
            'author': ' and '.join(authors),
            'journal': f'arXiv preprint arXiv:{arxiv_id}',
            'year': year
        }]
        writer = BibTexWriter()
        writer.indent = ' '  # Indentation for entries
        writer.comma_first = False  # Place the comma at the end of lines
        return writer.write(db).strip()
117
-
118
class AsyncContextManager:
    """Async context manager owning one aiohttp.ClientSession per `async with` use.

    A fresh session is created on enter and closed on exit.
    """

    def __init__(self):
        # Fix: the original created `_session` only inside __aenter__, so
        # __aexit__ raised AttributeError if enter never ran (or failed).
        self._session = None

    async def __aenter__(self):
        self._session = aiohttp.ClientSession()
        return self._session

    async def __aexit__(self, *_):
        if self._session:
            await self._session.close()
            self._session = None  # allow safe reuse of this manager instance
126
-
127
class CitationGenerator:
    """End-to-end citation pipeline.

    Generates search queries from the user's text with a Gemini LLM, fetches
    candidate papers from ArXiv and/or CrossRef, de-duplicates them, then asks
    the LLM to weave LaTeX \\cite{...} commands into the original text.
    """

    def __init__(self, config: Config):
        self.config = config
        self.xml_parser = ArxivXmlParser()
        self.async_context = AsyncContextManager()
        # Low temperature keeps the rewrite close to the source text.
        self.llm = ChatGoogleGenerativeAI(
            model="gemini-1.5-flash",
            temperature=0.3,
            google_api_key=config.gemini_api_key,
            streaming=True
        )
        self.citation_chain = self._create_citation_chain()
        self.generate_queries_chain = self._create_generate_queries_chain()
        # Replace loguru's default sink with one at the configured level.
        logger.remove()
        logger.add(sys.stderr, level=config.log_level)

    def _create_citation_chain(self):
        """Build the LCEL chain that inserts \\cite commands into text."""
        citation_prompt = PromptTemplate.from_template(
            """Insert citations into the provided text using LaTeX \\cite{{key}} commands.

            You must not alter the original wording or structure of the text beyond adding citations.
            You must include all provided references at least once. Place citations at suitable points.

            Input text:
            {text}

            Available papers (cite each at least once):
            {papers}
            """
        )
        # NOTE(review): in a RunnableParallel map each RunnablePassthrough
        # receives the *entire* input dict, not just its own key — verify the
        # rendered prompt contains what is intended for {text} and {papers}.
        return (
            {"text": RunnablePassthrough(), "papers": RunnablePassthrough()}
            | citation_prompt
            | self.llm
            | StrOutputParser()
        )

    def _create_generate_queries_chain(self):
        """Build the LCEL chain that turns input text into a JSON array of queries."""
        generate_queries_prompt = PromptTemplate.from_template(
            """Generate {num_queries} diverse academic search queries based on the given text.
            The queries should be concise and relevant.

            Requirements:
            1. Return ONLY a valid JSON array of strings.
            2. No additional text or formatting beyond JSON.
            3. Ensure uniqueness.

            Text: {text}
            """
        )
        # NOTE(review): same RunnablePassthrough caveat as _create_citation_chain.
        return (
            {"text": RunnablePassthrough(), "num_queries": RunnablePassthrough()}
            | generate_queries_prompt
            | self.llm
            | StrOutputParser()
        )

    async def generate_queries(self, text: str, num_queries: int) -> List[str]:
        """Ask the LLM for up to `num_queries` search queries.

        Tries to parse the response as a JSON array; falls back to
        line-splitting, and finally to a generic default query on any failure.
        """
        try:
            response = await self.generate_queries_chain.ainvoke({
                "text": text,
                "num_queries": num_queries
            })

            # Strip any prose the model wrapped around the JSON array.
            content = response.strip()
            if not content.startswith('['):
                start = content.find('[')
                end = content.rfind(']') + 1
                if start >= 0 and end > start:
                    content = content[start:end]
            try:
                queries = json.loads(content)
                if isinstance(queries, list):
                    return [q.strip() for q in queries if isinstance(q, str)][:num_queries]
            except json.JSONDecodeError:
                # Fallback: treat each non-bracket line as one query.
                lines = [line.strip() for line in content.split('\n')
                         if line.strip() and not line.strip().startswith(('[', ']'))]
                return lines[:num_queries]

            # JSON parsed but was not a list — use a generic default query.
            return ["deep learning neural networks"]

        except Exception as e:
            logger.error(f"Error generating queries: {e}")
            return ["deep learning neural networks"]

    async def search_arxiv(self, session: aiohttp.ClientSession, query: str, max_results: int) -> List[Dict]:
        """Query the ArXiv API and parse the Atom feed; returns [] on any error."""
        try:
            params = {
                'search_query': f'all:{urllib.parse.quote(query)}',
                'start': 0,
                'max_results': max_results,
                'sortBy': 'relevance',
                'sortOrder': 'descending'
            }

            async with session.get(
                self.config.arxiv_base_url + urllib.parse.urlencode(params),
                headers=self.config.default_headers,
                timeout=30
            ) as response:
                text_data = await response.text()
                papers = self.xml_parser.parse_papers(text_data)
                return papers
        except Exception as e:
            logger.error(f"Error searching ArXiv: {e}")
            return []

    async def fix_author_name(self, author: str) -> str:
        """Ask the LLM to repair an author name containing U+FFFD (�) characters."""
        # Fast path: no replacement character present, nothing to fix.
        if not re.search(r'[�]', author):
            return author

        try:
            prompt = f"""Fix this author name that contains corrupted characters (�):

            Name: {author}

            Requirements:
            1. Return ONLY the fixed author name
            2. Use proper diacritical marks for names
            3. Consider common name patterns and languages
            4. If unsure about a character, use the most likely letter
            5. Maintain the format: "Lastname, Firstname"

            Example fixes:
            - "Gonz�lez" -> "González"
            - "Cristi�n" -> "Cristián"
            """

            response = await self.llm.ainvoke([{"role": "user", "content": prompt}])
            fixed_name = response.content.strip()
            return fixed_name if fixed_name else author

        except Exception as e:
            logger.error(f"Error fixing author name: {e}")
            return author

    async def format_bibtex_author_names(self, text: str) -> str:
        """Run fix_author_name over every author in a BibTeX string.

        Returns the re-serialized BibTeX, or the input unchanged on error.
        """
        try:
            bib_database = bibtexparser.loads(text)
            for entry in bib_database.entries:
                if 'author' in entry:
                    # BibTeX separates multiple authors with ' and '.
                    authors = entry['author'].split(' and ')
                    cleaned_authors = []
                    for author in authors:
                        fixed_author = await self.fix_author_name(author)
                        cleaned_authors.append(fixed_author)
                    entry['author'] = ' and '.join(cleaned_authors)
            writer = BibTexWriter()
            writer.indent = ' '
            writer.comma_first = False
            return writer.write(bib_database).strip()
        except Exception as e:
            logger.error(f"Error cleaning BibTeX special characters: {e}")
            return text

    async def search_crossref(self, session: aiohttp.ClientSession, query: str, max_results: int) -> List[Dict]:
        """Search CrossRef, then fetch a BibTeX entry per DOI via doi.org.

        Retries with exponential backoff on HTTP 429 and client errors.
        Returns a list of paper dicts, or [] on failure.
        """
        try:
            # Strip quotes, then phrase-quote multi-word queries for the
            # bibliographic search endpoint.
            cleaned_query = query.replace("'", "").replace('"', "")
            if ' ' in cleaned_query:
                cleaned_query = f'"{cleaned_query}"'

            params = {
                'query.bibliographic': cleaned_query,
                'rows': max_results,
                'select': 'DOI,title,author,published-print,container-title',
                'sort': 'relevance',
                'order': 'desc'
            }

            headers = {
                'User-Agent': 'Mozilla/5.0 (compatible; CitationBot/1.0; mailto:[email protected])',
                'Accept': 'application/json'
            }

            for attempt in range(self.config.max_retries):
                try:
                    async with session.get(
                        self.config.crossref_base_url,
                        params=params,
                        headers=headers,
                        timeout=30
                    ) as response:
                        if response.status == 429:
                            # Rate limited: exponential backoff, then retry.
                            delay = self.config.base_delay * (2 ** attempt)
                            logger.warning(f"Rate limited by CrossRef. Retrying in {delay} seconds...")
                            await asyncio.sleep(delay)
                            continue

                        response.raise_for_status()
                        search_data = await response.json()
                        items = search_data.get('message', {}).get('items', [])

                        if not items:
                            return []

                        papers = []
                        existing_keys = set()
                        for item in items:
                            doi = item.get('DOI')
                            if not doi:
                                continue

                            try:
                                # doi.org content negotiation returns BibTeX
                                # when asked for application/x-bibtex.
                                bibtex_url = f"https://doi.org/{doi}"
                                async with session.get(
                                    bibtex_url,
                                    headers={
                                        'Accept': 'application/x-bibtex',
                                        'User-Agent': 'Mozilla/5.0 (compatible; CitationBot/1.0; mailto:[email protected])'
                                    },
                                    timeout=30
                                ) as bibtex_response:
                                    if bibtex_response.status != 200:
                                        continue

                                    bibtex_text = await bibtex_response.text()

                                    # Parse the BibTeX entry
                                    bib_database = bibtexparser.loads(bibtex_text)
                                    if not bib_database.entries:
                                        continue
                                    entry = bib_database.entries[0]

                                    # Skip entries without 'title' or 'booktitle'
                                    if 'title' not in entry and 'booktitle' not in entry:
                                        continue

                                    # Skip entries without 'author'
                                    if 'author' not in entry:
                                        continue

                                    # Extract necessary fields, flattening whitespace.
                                    title = entry.get('title', 'No Title').replace('{', '').replace('}', '')
                                    authors = entry.get('author', 'Unknown').replace('\n', ' ').replace('\t', ' ').strip()
                                    year = entry.get('year', 'Unknown')

                                    # Generate a unique BibTeX key
                                    key = self._generate_unique_bibtex_key(entry, existing_keys)
                                    entry['ID'] = key
                                    existing_keys.add(key)

                                    # Use BibTexWriter to format the entry
                                    writer = BibTexWriter()
                                    writer.indent = ' '
                                    writer.comma_first = False
                                    formatted_bibtex = writer.write(bib_database).strip()

                                    papers.append({
                                        'title': title,
                                        'authors': authors,
                                        'year': year,
                                        'bibtex_key': key,
                                        'bibtex_entry': formatted_bibtex
                                    })
                            except Exception as e:
                                logger.error(f"Error processing CrossRef item: {e}")
                                continue

                        return papers

                except aiohttp.ClientError as e:
                    if attempt == self.config.max_retries - 1:
                        logger.error(f"Max retries reached for CrossRef search. Error: {e}")
                        raise
                    delay = self.config.base_delay * (2 ** attempt)
                    logger.warning(f"Client error during CrossRef search: {e}. Retrying in {delay} seconds...")
                    await asyncio.sleep(delay)
            # NOTE(review): if every attempt is rate-limited the loop falls
            # through and this coroutine implicitly returns None rather than
            # [] — callers that extend() the result would fail; confirm.

        except Exception as e:
            logger.error(f"Error searching CrossRef: {e}")
            return []

    def _generate_unique_bibtex_key(self, entry: Dict, existing_keys: set) -> str:
        """Build a key '<lastname><year><firstword>' unique within existing_keys."""
        entry_type = entry.get('ENTRYTYPE', '').lower()
        author_field = entry.get('author', '')
        year = entry.get('year', '')
        authors = [a.strip() for a in author_field.split(' and ')]
        first_author_last_name = authors[0].split(',')[0] if authors else 'unknown'

        if entry_type == 'inbook':
            # Use 'booktitle' for 'inbook' entries
            booktitle = entry.get('booktitle', '')
            title_word = re.sub(r'\W+', '', booktitle.split()[0]) if booktitle else 'untitled'
        else:
            # Use regular 'title' for other entries
            title = entry.get('title', '')
            title_word = re.sub(r'\W+', '', title.split()[0]) if title else 'untitled'

        base_key = f"{first_author_last_name}{year}{title_word}"
        # Append a numeric suffix until the key is unique.
        key = base_key
        index = 1
        while key in existing_keys:
            key = f"{base_key}{index}"
            index += 1
        return key

    async def process_text(self, text: str, num_queries: int, citations_per_query: int,
                           use_arxiv: bool = True, use_crossref: bool = True) -> tuple[str, str]:
        """Full pipeline: queries -> searches -> dedupe -> cite -> BibTeX.

        Returns (cited_text, bibtex_entries). On failure or when no papers are
        found, returns the original text with an empty BibTeX string.
        """
        if not (use_arxiv or use_crossref):
            return "Please select at least one source (ArXiv or CrossRef)", ""

        # Clamp user-supplied counts to the configured bounds.
        num_queries = min(max(1, num_queries), self.config.max_queries)
        citations_per_query = min(max(1, citations_per_query), self.config.max_citations_per_query)

        queries = await self.generate_queries(text, num_queries)
        if not queries:
            return text, ""

        async with self.async_context as session:
            # Fan out one search task per (query, enabled source) pair.
            search_tasks = []
            for query in queries:
                if use_arxiv:
                    search_tasks.append(self.search_arxiv(session, query, citations_per_query))
                if use_crossref:
                    search_tasks.append(self.search_crossref(session, query, citations_per_query))

            results = await asyncio.gather(*search_tasks, return_exceptions=True)

            papers = []
            for r in results:
                # NOTE(review): this filters out Exceptions but not a None
                # result (see search_crossref fall-through) — confirm.
                if not isinstance(r, Exception):
                    papers.extend(r)

            # De-duplicate by BibTeX key, keeping the first occurrence.
            unique_papers = []
            seen_keys = set()
            for p in papers:
                if p['bibtex_key'] not in seen_keys:
                    seen_keys.add(p['bibtex_key'])
                    unique_papers.append(p)
            papers = unique_papers

            if not papers:
                return text, ""

            try:
                cited_text = await self.citation_chain.ainvoke({
                    "text": text,
                    "papers": json.dumps(papers, indent=2)
                })

                # Aggregate the per-paper BibTeX entries into one database.
                bib_database = BibDatabase()
                for p in papers:
                    if 'bibtex_entry' in p:
                        bib_db = bibtexparser.loads(p['bibtex_entry'])
                        if bib_db.entries:
                            bib_database.entries.append(bib_db.entries[0])
                        else:
                            logger.warning(f"Empty BibTeX entry for key: {p['bibtex_key']}")
                writer = BibTexWriter()
                writer.indent = ' '
                writer.comma_first = False
                bibtex_entries = writer.write(bib_database).strip()

                return cited_text, bibtex_entries
            except Exception as e:
                logger.error(f"Error inserting citations: {e}")
                return text, ""
486
-
487
def create_gradio_interface() -> gr.Interface:
    """Build the Gradio Blocks UI for AutoCitation.

    The CitationGenerator is constructed per request (inside `process`) from
    the user-supplied API key rather than once at startup.
    """
    # Removed CitationGenerator initialization here

    async def process(api_key: str, text: str, num_queries: int, citations_per_query: int,
                      use_arxiv: bool, use_crossref: bool) -> tuple[str, str]:
        """Validate inputs, build a CitationGenerator, and run the pipeline."""
        if not api_key.strip():
            return "Please enter your Gemini API Key.", ""
        if not text.strip():
            return "Please enter text to process", ""
        try:
            config = Config(
                gemini_api_key=api_key
            )
            citation_gen = CitationGenerator(config)
            return await citation_gen.process_text(
                text, num_queries, citations_per_query,
                use_arxiv=use_arxiv, use_crossref=use_crossref
            )
        except ValueError as e:
            return f"Input validation error: {str(e)}", ""
        except Exception as e:
            # Errors are surfaced in the output textbox rather than raised.
            return f"Error: {str(e)}", ""

    # Inline stylesheet for the Blocks layout.
    css = """
    :root {
        --primary: #6A7E76;
        --primary-hover: #566961;
        --bg: #FFFFFF;
        --text: #454442;
        --border: #B4B0AC;
        --control-bg: #F5F3F0;
    }

    .container, .header, .input-group, .controls-row {
        padding: 0.75rem;
    }

    .container {
        max-width: 100%;
        background: var(--bg);
    }

    .header {
        text-align: center;
        margin-bottom: 1rem;
        background: var(--bg);
        border-bottom: 1px solid var(--border);
    }

    .header h1 {
        font-size: 1.5rem;
        color: var(--primary);
        font-weight: 500;
        margin-bottom: 0.25rem;
    }

    .header p, label span {
        font-size: 0.9rem;
        color: var(--text);
    }

    .input-group {
        border-radius: 4px;
        border: 1px solid var(--border);
        margin-bottom: 0.75rem;
    }

    .controls-row {
        display: flex !important;
        gap: 0.75rem;
        margin-top: 0.5rem;
    }

    .source-controls {
        display: flex;
        gap: 0.75rem;
        margin-top: 0.5rem;
    }

    .checkbox-group {
        display: flex;
        align-items: center;
        gap: 0.5rem;
    }

    input[type="number"], textarea {
        border: 1px solid var(--border);
        border-radius: 4px;
        padding: 0.5rem;
        background: var(--control-bg);
        color: var(--text);
        font-size: 0.95rem;
    }

    .generate-btn {
        background: var(--primary);
        color: white;
        padding: 0.5rem 1.5rem;
        border-radius: 4px;
        border: none;
        font-size: 0.9rem;
        transition: background 0.2s;
        width: 100%;
    }

    .generate-btn:hover {
        background: var(--primary-hover);
    }

    .output-container {
        display: flex;
        gap: 0.75rem;
    }
    """

    with gr.Blocks(css=css, theme=gr.themes.Default()) as demo:
        gr.HTML("""<div class="header">
            <h1>📚 AutoCitation</h1>
            <p>Insert citations into your academic text</p>
        </div>""")

        with gr.Group(elem_classes="input-group"):
            # Added API Key input field
            api_key = gr.Textbox(
                label="Gemini API Key",
                placeholder="Enter your Gemini API key...",
                type="password",
                interactive=True
            )
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="Paste or type your text here...",
                lines=8
            )
            with gr.Row(elem_classes="controls-row"):
                with gr.Column(scale=1):
                    num_queries = gr.Number(
                        label="Search Queries",
                        value=3,
                        minimum=1,
                        maximum=5,  # Changed to config.max_queries as 5
                        step=1
                    )
                with gr.Column(scale=1):
                    citations_per_query = gr.Number(
                        label="Citations per Query",
                        value=1,
                        minimum=1,
                        maximum=10,  # Changed to config.max_citations_per_query as 10
                        step=1
                    )

            with gr.Row(elem_classes="source-controls"):
                with gr.Column(scale=1):
                    use_arxiv = gr.Checkbox(
                        label="Search ArXiv",
                        value=True,
                        elem_classes="checkbox-group"
                    )
                with gr.Column(scale=1):
                    use_crossref = gr.Checkbox(
                        label="Search CrossRef (Experimental)",
                        value=True,
                        elem_classes="checkbox-group"
                    )
                with gr.Column(scale=2):
                    process_btn = gr.Button(
                        "Generate",
                        elem_classes="generate-btn"
                    )

        with gr.Group(elem_classes="output-group"):
            with gr.Row():
                with gr.Column(scale=1):
                    cited_text = gr.Textbox(
                        label="Generated Text",
                        lines=10,
                        show_copy_button=True
                    )
                with gr.Column(scale=1):
                    bibtex = gr.Textbox(
                        label="BibTeX References",
                        lines=10,
                        show_copy_button=True
                    )

        # Updated the inputs and outputs
        process_btn.click(
            fn=process,
            inputs=[api_key, input_text, num_queries, citations_per_query, use_arxiv, use_crossref],
            outputs=[cited_text, bibtex]
        )

    return demo
680
-
681
if __name__ == "__main__":
    # Removed environment variable loading and config initialization;
    # the API key is collected through the UI instead.
    demo = create_gradio_interface()
    try:
        # Local-only server on the standard Gradio port; no public share link.
        demo.launch(server_port=7860, share=False)
    except KeyboardInterrupt:
        print("\nShutting down server...")
    except Exception as e:
        print(f"Error starting server: {str(e)}")
 
1
+ import asyncio
2
+ import json
3
+ import os
4
+ import urllib.parse
5
+ import re
6
+ import xml.etree.ElementTree as ET
7
+ from dataclasses import dataclass, field
8
+ from typing import Dict, List, Optional
9
+ import sys
10
+ from loguru import logger
11
+
12
+ import aiohttp
13
+ import gradio as gr
14
+ from langchain_core.output_parsers import StrOutputParser
15
+ from langchain_core.prompts import PromptTemplate
16
+ from langchain_core.runnables import RunnablePassthrough
17
+ from langchain_google_genai import ChatGoogleGenerativeAI
18
+
19
+ import bibtexparser
20
+ from bibtexparser.bwriter import BibTexWriter
21
+ from bibtexparser.bibdatabase import BibDatabase
22
+
23
@dataclass
class Config:
    """Runtime configuration for the citation generator."""

    gemini_api_key: str  # user-supplied Gemini API key (required)
    max_retries: int = 3  # CrossRef retry attempts on transient errors
    base_delay: int = 1  # base seconds for exponential backoff (delay = base * 2**attempt)
    max_queries: int = 5  # upper bound on LLM-generated search queries
    max_citations_per_query: int = 10  # upper bound on results fetched per query
    arxiv_base_url: str = 'http://export.arxiv.org/api/query?'
    crossref_base_url: str = 'https://api.crossref.org/works'
    # default_factory avoids sharing one mutable dict between Config instances
    default_headers: dict = field(default_factory=lambda: {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
    })
    log_level: str = 'DEBUG'  # loguru level installed by CitationGenerator.__init__
36
+
37
class ArxivXmlParser:
    """Parses ArXiv Atom XML feeds into paper dicts with ready-made BibTeX entries."""

    # XML namespaces used by the ArXiv Atom API responses.
    NS = {
        'atom': 'http://www.w3.org/2005/Atom',
        'arxiv': 'http://arxiv.org/schemas/atom'
    }

    def parse_papers(self, data: str) -> List[Dict]:
        """Parse a whole Atom feed string; return paper dicts (empty list on error)."""
        try:
            root = ET.fromstring(data)
            papers = []
            for entry in root.findall('atom:entry', self.NS):
                paper = self.parse_entry(entry)
                if paper:
                    papers.append(paper)
            return papers
        except Exception as e:
            # Consistency fix: every other error path in this file logs via
            # loguru's logger; this class was still using print().
            logger.error(f"Error parsing ArXiv XML: {e}")
            return []

    def parse_entry(self, entry) -> Optional[dict]:
        """Convert one Atom <entry> element to a paper dict, or None if malformed."""
        try:
            title_node = entry.find('atom:title', self.NS)
            if title_node is None:
                return None
            title = title_node.text.strip()

            authors = []
            for author in entry.findall('atom:author', self.NS):
                author_name_node = author.find('atom:name', self.NS)
                if author_name_node is not None and author_name_node.text:
                    authors.append(self._format_author_name(author_name_node.text.strip()))

            arxiv_id_node = entry.find('atom:id', self.NS)
            if arxiv_id_node is None:
                return None
            # The Atom id is a URL; the arXiv id is its last path component.
            arxiv_id = arxiv_id_node.text.split('/')[-1]

            published_node = entry.find('atom:published', self.NS)
            year = published_node.text[:4] if published_node is not None else "Unknown"

            abstract_node = entry.find('atom:summary', self.NS)
            abstract = abstract_node.text.strip() if abstract_node is not None else ""

            # Citation key = first author's last name + arXiv id without the dot.
            bibtex_key = f"{authors[0].split(',')[0]}{arxiv_id.replace('.', '')}" if authors else f"unknown{arxiv_id.replace('.', '')}"
            bibtex_entry = self._generate_bibtex_entry(bibtex_key, title, authors, arxiv_id, year)

            return {
                'title': title,
                'authors': authors,
                'arxiv_id': arxiv_id,
                'published': year,
                'abstract': abstract,
                'bibtex_key': bibtex_key,
                'bibtex_entry': bibtex_entry
            }
        except Exception as e:
            # Consistency fix: log instead of print (see parse_papers).
            logger.error(f"Error parsing ArXiv entry: {e}")
            return None

    @staticmethod
    def _format_author_name(author: str) -> str:
        """Reorder 'First [Middle] Last' into BibTeX-style 'Last, First [Middle]'."""
        names = author.split()
        if len(names) > 1:
            return f"{names[-1]}, {' '.join(names[:-1])}"
        return author

    def _generate_bibtex_entry(self, key: str, title: str, authors: List[str], arxiv_id: str, year: str) -> str:
        """Render a single @article BibTeX entry for an arXiv preprint."""
        db = BibDatabase()
        db.entries = [{
            'ENTRYTYPE': 'article',
            'ID': key,
            'title': title,
            'author': ' and '.join(authors),
            'journal': f'arXiv preprint arXiv:{arxiv_id}',
            'year': year
        }]
        writer = BibTexWriter()
        writer.indent = ' '  # Indentation for entries
        writer.comma_first = False  # Place the comma at the end of lines
        return writer.write(db).strip()
117
+
118
class AsyncContextManager:
    """Async context manager owning one aiohttp.ClientSession per `async with` use.

    A fresh session is created on enter and closed on exit.
    """

    def __init__(self):
        # Fix: the original created `_session` only inside __aenter__, so
        # __aexit__ raised AttributeError if enter never ran (or failed).
        self._session = None

    async def __aenter__(self):
        self._session = aiohttp.ClientSession()
        return self._session

    async def __aexit__(self, *_):
        if self._session:
            await self._session.close()
            self._session = None  # allow safe reuse of this manager instance
126
+
127
+ class CitationGenerator:
128
+ def __init__(self, config: Config):
129
+ self.config = config
130
+ self.xml_parser = ArxivXmlParser()
131
+ self.async_context = AsyncContextManager()
132
+ self.llm = ChatGoogleGenerativeAI(
133
+ model="gemini-1.5-flash",
134
+ temperature=0.3,
135
+ google_api_key=config.gemini_api_key,
136
+ streaming=True
137
+ )
138
+ self.citation_chain = self._create_citation_chain()
139
+ self.generate_queries_chain = self._create_generate_queries_chain()
140
+ logger.remove()
141
+ logger.add(sys.stderr, level=config.log_level) # Configure logger
142
+
143
+ def _create_citation_chain(self):
144
+ citation_prompt = PromptTemplate.from_template(
145
+ """Insert citations into the provided text using LaTeX \\cite{{key}} commands.
146
+
147
+ You must not alter the original wording or structure of the text beyond adding citations.
148
+ You must include all provided references at least once. Place citations at suitable points.
149
+
150
+ Input text:
151
+ {text}
152
+
153
+ Available papers (cite each at least once):
154
+ {papers}
155
+ """
156
+ )
157
+ return (
158
+ {"text": RunnablePassthrough(), "papers": RunnablePassthrough()}
159
+ | citation_prompt
160
+ | self.llm
161
+ | StrOutputParser()
162
+ )
163
+
164
+ def _create_generate_queries_chain(self):
165
+ generate_queries_prompt = PromptTemplate.from_template(
166
+ """Generate {num_queries} diverse academic search queries based on the given text.
167
+ The queries should be concise and relevant.
168
+
169
+ Requirements:
170
+ 1. Return ONLY a valid JSON array of strings.
171
+ 2. No additional text or formatting beyond JSON.
172
+ 3. Ensure uniqueness.
173
+
174
+ Text: {text}
175
+ """
176
+ )
177
+ return (
178
+ {"text": RunnablePassthrough(), "num_queries": RunnablePassthrough()}
179
+ | generate_queries_prompt
180
+ | self.llm
181
+ | StrOutputParser()
182
+ )
183
+
184
+ async def generate_queries(self, text: str, num_queries: int) -> List[str]:
185
+ try:
186
+ response = await self.generate_queries_chain.ainvoke({
187
+ "text": text,
188
+ "num_queries": num_queries
189
+ })
190
+
191
+ content = response.strip()
192
+ if not content.startswith('['):
193
+ start = content.find('[')
194
+ end = content.rfind(']') + 1
195
+ if start >= 0 and end > start:
196
+ content = content[start:end]
197
+ try:
198
+ queries = json.loads(content)
199
+ if isinstance(queries, list):
200
+ return [q.strip() for q in queries if isinstance(q, str)][:num_queries]
201
+ except json.JSONDecodeError:
202
+ lines = [line.strip() for line in content.split('\n')
203
+ if line.strip() and not line.strip().startswith(('[', ']'))]
204
+ return lines[:num_queries]
205
+
206
+ return ["deep learning neural networks"]
207
+
208
+ except Exception as e:
209
+ logger.error(f"Error generating queries: {e}") # Replace print with logger
210
+ return ["deep learning neural networks"]
211
+
212
+ async def search_arxiv(self, session: aiohttp.ClientSession, query: str, max_results: int) -> List[Dict]:
213
+ try:
214
+ params = {
215
+ 'search_query': f'all:{urllib.parse.quote(query)}',
216
+ 'start': 0,
217
+ 'max_results': max_results,
218
+ 'sortBy': 'relevance',
219
+ 'sortOrder': 'descending'
220
+ }
221
+
222
+ async with session.get(
223
+ self.config.arxiv_base_url + urllib.parse.urlencode(params),
224
+ headers=self.config.default_headers,
225
+ timeout=30
226
+ ) as response:
227
+ text_data = await response.text()
228
+ papers = self.xml_parser.parse_papers(text_data)
229
+ return papers
230
+ except Exception as e:
231
+ logger.error(f"Error searching ArXiv: {e}")
232
+ return []
233
+
234
+ async def fix_author_name(self, author: str) -> str:
235
+ if not re.search(r'[�]', author):
236
+ return author
237
+
238
+ try:
239
+ prompt = f"""Fix this author name that contains corrupted characters (�):
240
+
241
+ Name: {author}
242
+
243
+ Requirements:
244
+ 1. Return ONLY the fixed author name
245
+ 2. Use proper diacritical marks for names
246
+ 3. Consider common name patterns and languages
247
+ 4. If unsure about a character, use the most likely letter
248
+ 5. Maintain the format: "Lastname, Firstname"
249
+
250
+ Example fixes:
251
+ - "Gonz�lez" -> "González"
252
+ - "Cristi�n" -> "Cristián"
253
+ """
254
+
255
+ response = await self.llm.ainvoke([{"role": "user", "content": prompt}])
256
+ fixed_name = response.content.strip()
257
+ return fixed_name if fixed_name else author
258
+
259
+ except Exception as e:
260
+ logger.error(f"Error fixing author name: {e}")
261
+ return author
262
+
263
+ async def format_bibtex_author_names(self, text: str) -> str:
264
+ try:
265
+ bib_database = bibtexparser.loads(text)
266
+ for entry in bib_database.entries:
267
+ if 'author' in entry:
268
+ authors = entry['author'].split(' and ')
269
+ cleaned_authors = []
270
+ for author in authors:
271
+ fixed_author = await self.fix_author_name(author)
272
+ cleaned_authors.append(fixed_author)
273
+ entry['author'] = ' and '.join(cleaned_authors)
274
+ writer = BibTexWriter()
275
+ writer.indent = ' '
276
+ writer.comma_first = False
277
+ return writer.write(bib_database).strip()
278
+ except Exception as e:
279
+ logger.error(f"Error cleaning BibTeX special characters: {e}")
280
+ return text
281
+
282
+ async def search_crossref(self, session: aiohttp.ClientSession, query: str, max_results: int) -> List[Dict]:
283
+ try:
284
+ cleaned_query = query.replace("'", "").replace('"', "")
285
+ if ' ' in cleaned_query:
286
+ cleaned_query = f'"{cleaned_query}"'
287
+
288
+ params = {
289
+ 'query.bibliographic': cleaned_query,
290
+ 'rows': max_results,
291
+ 'select': 'DOI,title,author,published-print,container-title',
292
+ 'sort': 'relevance',
293
+ 'order': 'desc'
294
+ }
295
+
296
+ headers = {
297
+ 'User-Agent': 'Mozilla/5.0 (compatible; CitationBot/1.0; mailto:[email protected])',
298
+ 'Accept': 'application/json'
299
+ }
300
+
301
+ for attempt in range(self.config.max_retries):
302
+ try:
303
+ async with session.get(
304
+ self.config.crossref_base_url,
305
+ params=params,
306
+ headers=headers,
307
+ timeout=30
308
+ ) as response:
309
+ if response.status == 429:
310
+ delay = self.config.base_delay * (2 ** attempt)
311
+ logger.warning(f"Rate limited by CrossRef. Retrying in {delay} seconds...")
312
+ await asyncio.sleep(delay)
313
+ continue
314
+
315
+ response.raise_for_status()
316
+ search_data = await response.json()
317
+ items = search_data.get('message', {}).get('items', [])
318
+
319
+ if not items:
320
+ return []
321
+
322
+ papers = []
323
+ existing_keys = set()
324
+ for item in items:
325
+ doi = item.get('DOI')
326
+ if not doi:
327
+ continue
328
+
329
+ try:
330
+ bibtex_url = f"https://doi.org/{doi}"
331
+ async with session.get(
332
+ bibtex_url,
333
+ headers={
334
+ 'Accept': 'application/x-bibtex',
335
+ 'User-Agent': 'Mozilla/5.0 (compatible; CitationBot/1.0; mailto:[email protected])'
336
+ },
337
+ timeout=30
338
+ ) as bibtex_response:
339
+ if bibtex_response.status != 200:
340
+ continue
341
+
342
+ bibtex_text = await bibtex_response.text()
343
+
344
+ # Parse the BibTeX entry
345
+ bib_database = bibtexparser.loads(bibtex_text)
346
+ if not bib_database.entries:
347
+ continue
348
+ entry = bib_database.entries[0]
349
+
350
+ # Check if 'title' or 'booktitle' is present
351
+ if 'title' not in entry and 'booktitle' not in entry:
352
+ continue # Skip entries without 'title' or 'booktitle'
353
+
354
+ # Check if 'author' is present
355
+ if 'author' not in entry:
356
+ continue # Skip entries without 'author'
357
+
358
+ # Extract necessary fields
359
+ title = entry.get('title', 'No Title').replace('{', '').replace('}', '')
360
+ authors = entry.get('author', 'Unknown').replace('\n', ' ').replace('\t', ' ').strip()
361
+ year = entry.get('year', 'Unknown')
362
+
363
+ # Generate a unique BibTeX key
364
+ key = self._generate_unique_bibtex_key(entry, existing_keys)
365
+ entry['ID'] = key
366
+ existing_keys.add(key)
367
+
368
+ # Use BibTexWriter to format the entry
369
+ writer = BibTexWriter()
370
+ writer.indent = ' '
371
+ writer.comma_first = False
372
+ formatted_bibtex = writer.write(bib_database).strip()
373
+
374
+ papers.append({
375
+ 'title': title,
376
+ 'authors': authors,
377
+ 'year': year,
378
+ 'bibtex_key': key,
379
+ 'bibtex_entry': formatted_bibtex
380
+ })
381
+ except Exception as e:
382
+ logger.error(f"Error processing CrossRef item: {e}") # Replace print with logger
383
+ continue
384
+
385
+ return papers
386
+
387
+ except aiohttp.ClientError as e:
388
+ if attempt == self.config.max_retries - 1:
389
+ logger.error(f"Max retries reached for CrossRef search. Error: {e}")
390
+ raise
391
+ delay = self.config.base_delay * (2 ** attempt)
392
+ logger.warning(f"Client error during CrossRef search: {e}. Retrying in {delay} seconds...")
393
+ await asyncio.sleep(delay)
394
+
395
+ except Exception as e:
396
+ logger.error(f"Error searching CrossRef: {e}") # Replace print with logger
397
+ return []
398
+
399
+ def _generate_unique_bibtex_key(self, entry: Dict, existing_keys: set) -> str:
400
+ entry_type = entry.get('ENTRYTYPE', '').lower()
401
+ author_field = entry.get('author', '')
402
+ year = entry.get('year', '')
403
+ authors = [a.strip() for a in author_field.split(' and ')]
404
+ first_author_last_name = authors[0].split(',')[0] if authors else 'unknown'
405
+
406
+ if entry_type == 'inbook':
407
+ # Use 'booktitle' for 'inbook' entries
408
+ booktitle = entry.get('booktitle', '')
409
+ title_word = re.sub(r'\W+', '', booktitle.split()[0]) if booktitle else 'untitled'
410
+ else:
411
+ # Use regular 'title' for other entries
412
+ title = entry.get('title', '')
413
+ title_word = re.sub(r'\W+', '', title.split()[0]) if title else 'untitled'
414
+
415
+ base_key = f"{first_author_last_name}{year}{title_word}"
416
+ # Ensure the key is unique
417
+ key = base_key
418
+ index = 1
419
+ while key in existing_keys:
420
+ key = f"{base_key}{index}"
421
+ index += 1
422
+ return key
423
+
424
+ async def process_text(self, text: str, num_queries: int, citations_per_query: int,
425
+ use_arxiv: bool = True, use_crossref: bool = True) -> tuple[str, str]:
426
+ if not (use_arxiv or use_crossref):
427
+ return "Please select at least one source (ArXiv or CrossRef)", ""
428
+
429
+ num_queries = min(max(1, num_queries), self.config.max_queries)
430
+ citations_per_query = min(max(1, citations_per_query), self.config.max_citations_per_query)
431
+
432
+ queries = await self.generate_queries(text, num_queries)
433
+ if not queries:
434
+ return text, ""
435
+
436
+ async with self.async_context as session:
437
+ search_tasks = []
438
+ for query in queries:
439
+ if use_arxiv:
440
+ search_tasks.append(self.search_arxiv(session, query, citations_per_query))
441
+ if use_crossref:
442
+ search_tasks.append(self.search_crossref(session, query, citations_per_query))
443
+
444
+ results = await asyncio.gather(*search_tasks, return_exceptions=True)
445
+
446
+ papers = []
447
+ for r in results:
448
+ if not isinstance(r, Exception):
449
+ papers.extend(r)
450
+
451
+ unique_papers = []
452
+ seen_keys = set()
453
+ for p in papers:
454
+ if p['bibtex_key'] not in seen_keys:
455
+ seen_keys.add(p['bibtex_key'])
456
+ unique_papers.append(p)
457
+ papers = unique_papers
458
+
459
+ if not papers:
460
+ return text, ""
461
+
462
+ try:
463
+ cited_text = await self.citation_chain.ainvoke({
464
+ "text": text,
465
+ "papers": json.dumps(papers, indent=2)
466
+ })
467
+
468
+ # Use bibtexparser to aggregate BibTeX entries
469
+ bib_database = BibDatabase()
470
+ for p in papers:
471
+ if 'bibtex_entry' in p:
472
+ bib_db = bibtexparser.loads(p['bibtex_entry'])
473
+ if bib_db.entries:
474
+ bib_database.entries.append(bib_db.entries[0])
475
+ else:
476
+ logger.warning(f"Empty BibTeX entry for key: {p['bibtex_key']}")
477
+ writer = BibTexWriter()
478
+ writer.indent = ' '
479
+ writer.comma_first = False
480
+ bibtex_entries = writer.write(bib_database).strip()
481
+
482
+ return cited_text, bibtex_entries
483
+ except Exception as e:
484
+ logger.error(f"Error inserting citations: {e}") # Replace print with logger
485
+ return text, ""
486
+
487
def create_gradio_interface() -> gr.Blocks:
    """Build the AutoCitation Gradio app.

    Collects a Gemini API key, the text to cite, query/citation counts
    and source toggles; shows the cited text and aggregated BibTeX.
    A fresh CitationGenerator is built per request so each user's API
    key stays request-scoped. (Return annotation fixed: this returns a
    gr.Blocks, not a gr.Interface.)
    """
    async def process(api_key: str, text: str, num_queries: int, citations_per_query: int,
                      use_arxiv: bool, use_crossref: bool) -> tuple[str, str]:
        """Validate inputs, then run the citation pipeline for one request."""
        if not api_key.strip():
            return "Please enter your Gemini API Key.", ""
        if not text.strip():
            return "Please enter text to process", ""
        try:
            config = Config(
                gemini_api_key=api_key
            )
            citation_gen = CitationGenerator(config)
            return await citation_gen.process_text(
                text, num_queries, citations_per_query,
                use_arxiv=use_arxiv, use_crossref=use_crossref
            )
        except ValueError as e:
            return f"Input validation error: {str(e)}", ""
        except Exception as e:
            return f"Error: {str(e)}", ""

    css = """
    :root {
        --primary: #6A7E76;
        --primary-hover: #566961;
        --bg: #FFFFFF;
        --text: #454442;
        --border: #B4B0AC;
        --control-bg: #F5F3F0;
    }

    .container, .header, .input-group, .controls-row {
        padding: 0.75rem;
    }

    .container {
        max-width: 100%;
        background: var(--bg);
    }

    .header {
        text-align: center;
        margin-bottom: 1rem;
        background: var(--bg);
        border-bottom: 1px solid var(--border);
    }

    .header h1 {
        font-size: 1.5rem;
        color: var(--primary);
        font-weight: 500;
        margin-bottom: 0.25rem;
    }

    .header p, label span {
        font-size: 0.9rem;
        color: var(--text);
    }

    .input-group {
        border-radius: 4px;
        border: 1px solid var(--border);
        margin-bottom: 0.75rem;
    }

    .controls-row {
        display: flex !important;
        gap: 0.75rem;
        margin-top: 0.5rem;
    }

    .source-controls {
        display: flex;
        gap: 0.75rem;
        margin-top: 0.5rem;
    }

    .checkbox-group {
        display: flex;
        align-items: center;
        gap: 0.5rem;
    }

    input[type="number"], textarea {
        border: 1px solid var(--border);
        border-radius: 4px;
        padding: 0.5rem;
        background: var(--control-bg);
        color: var(--text);
        font-size: 0.95rem;
    }

    .generate-btn {
        background: var(--primary);
        color: white;
        padding: 0.5rem 1.5rem;
        border-radius: 4px;
        border: none;
        font-size: 0.9rem;
        transition: background 0.2s;
        width: 100%;
    }

    .generate-btn:hover {
        background: var(--primary-hover);
    }

    .output-container {
        display: flex;
        gap: 0.75rem;
    }
    """

    with gr.Blocks(css=css, theme=gr.themes.Default()) as demo:
        gr.HTML("""<div class="header">
            <h1>📚 AutoCitation</h1>
            <p>Insert citations into your academic text</p>
        </div>""")

        with gr.Group(elem_classes="input-group"):
            # API key is entered per-session; never stored server-side.
            api_key = gr.Textbox(
                label="Gemini API Key",
                placeholder="Enter your Gemini API key...",
                type="password",
                interactive=True
            )
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="Paste or type your text here...",
                lines=8
            )
            with gr.Row(elem_classes="controls-row"):
                with gr.Column(scale=1):
                    num_queries = gr.Number(
                        label="Search Queries",
                        value=3,
                        minimum=1,
                        # Keep the UI bound in sync with the backend limit.
                        maximum=Config.max_queries,
                        step=1
                    )
                with gr.Column(scale=1):
                    citations_per_query = gr.Number(
                        label="Citations per Query",
                        value=1,
                        minimum=1,
                        # Keep the UI bound in sync with the backend limit.
                        maximum=Config.max_citations_per_query,
                        step=1
                    )

            with gr.Row(elem_classes="source-controls"):
                with gr.Column(scale=1):
                    use_arxiv = gr.Checkbox(
                        label="Search ArXiv",
                        value=True,
                        elem_classes="checkbox-group"
                    )
                with gr.Column(scale=1):
                    use_crossref = gr.Checkbox(
                        label="Search CrossRef (Experimental)",
                        value=True,
                        elem_classes="checkbox-group"
                    )
                with gr.Column(scale=2):
                    process_btn = gr.Button(
                        "Generate",
                        elem_classes="generate-btn"
                    )

        with gr.Group(elem_classes="output-group"):
            with gr.Row():
                with gr.Column(scale=1):
                    cited_text = gr.Textbox(
                        label="Generated Text",
                        lines=10,
                        show_copy_button=True
                    )
                with gr.Column(scale=1):
                    bibtex = gr.Textbox(
                        label="BibTeX References",
                        lines=10,
                        show_copy_button=True
                    )

        process_btn.click(
            fn=process,
            inputs=[api_key, input_text, num_queries, citations_per_query, use_arxiv, use_crossref],
            outputs=[cited_text, bibtex]
        )

    return demo
680
+
681
if __name__ == "__main__":
    # Serve the app locally on port 7860 with no public share link.
    create_gradio_interface().launch(server_port=7860, share=False)