yipengsun committed on
Commit 9b14c03 · verified · 1 Parent(s): 39da4f4

Upload 4 files

Files changed (5)
  1. .gitattributes +1 -0
  2. .gitignore +44 -0
  3. autocitation.py +689 -0
  4. example.png +3 -0
  5. requirements.txt +13 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+example.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,44 @@
+# Environment variables
+.env
+*.pem
+
+# Python
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# Virtual Environment
+venv/
+env/
+ENV/
+
+# IDE specific files
+.idea/
+.vscode/
+*.swp
+*.swo
+.DS_Store
+
+# Flask
+instance/
+.webassets-cache
+
+# Logs
+*.log
autocitation.py ADDED
@@ -0,0 +1,689 @@
+import asyncio
+import json
+import os
+import urllib.parse
+import re
+import xml.etree.ElementTree as ET
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional
+import sys
+from loguru import logger
+
+import aiohttp
+import gradio as gr
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import PromptTemplate
+from langchain_core.runnables import RunnablePassthrough
+from langchain_google_genai import ChatGoogleGenerativeAI
+
+import bibtexparser
+from bibtexparser.bwriter import BibTexWriter
+from bibtexparser.bibdatabase import BibDatabase
+
+@dataclass
+class Config:
+    gemini_api_key: str
+    max_retries: int = 3
+    base_delay: int = 1
+    max_queries: int = 5
+    max_citations_per_query: int = 10
+    arxiv_base_url: str = 'http://export.arxiv.org/api/query?'
+    crossref_base_url: str = 'https://api.crossref.org/works'
+    default_headers: dict = field(default_factory=lambda: {
+        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
+    })
+    log_level: str = 'DEBUG'  # loguru log level
+
+class ArxivXmlParser:
+    NS = {
+        'atom': 'http://www.w3.org/2005/Atom',
+        'arxiv': 'http://arxiv.org/schemas/atom'
+    }
+
+    def parse_papers(self, data: str) -> List[Dict]:
+        try:
+            root = ET.fromstring(data)
+            papers = []
+            for entry in root.findall('atom:entry', self.NS):
+                paper = self.parse_entry(entry)
+                if paper:
+                    papers.append(paper)
+            return papers
+        except Exception as e:
+            logger.error(f"Error parsing ArXiv XML: {e}")
+            return []
+
+    def parse_entry(self, entry) -> Optional[dict]:
+        try:
+            title_node = entry.find('atom:title', self.NS)
+            if title_node is None:
+                return None
+            title = title_node.text.strip()
+
+            authors = []
+            for author in entry.findall('atom:author', self.NS):
+                author_name_node = author.find('atom:name', self.NS)
+                if author_name_node is not None and author_name_node.text:
+                    authors.append(self._format_author_name(author_name_node.text.strip()))
+
+            arxiv_id_node = entry.find('atom:id', self.NS)
+            if arxiv_id_node is None:
+                return None
+            arxiv_id = arxiv_id_node.text.split('/')[-1]
+
+            published_node = entry.find('atom:published', self.NS)
+            year = published_node.text[:4] if published_node is not None else "Unknown"
+
+            abstract_node = entry.find('atom:summary', self.NS)
+            abstract = abstract_node.text.strip() if abstract_node is not None else ""
+
+            bibtex_key = f"{authors[0].split(',')[0]}{arxiv_id.replace('.', '')}" if authors else f"unknown{arxiv_id.replace('.', '')}"
+            bibtex_entry = self._generate_bibtex_entry(bibtex_key, title, authors, arxiv_id, year)
+
+            return {
+                'title': title,
+                'authors': authors,
+                'arxiv_id': arxiv_id,
+                'published': year,
+                'abstract': abstract,
+                'bibtex_key': bibtex_key,
+                'bibtex_entry': bibtex_entry
+            }
+        except Exception as e:
+            logger.error(f"Error parsing ArXiv entry: {e}")
+            return None
+
+    @staticmethod
+    def _format_author_name(author: str) -> str:
+        names = author.split()
+        if len(names) > 1:
+            return f"{names[-1]}, {' '.join(names[:-1])}"
+        return author
+
+    def _generate_bibtex_entry(self, key: str, title: str, authors: List[str], arxiv_id: str, year: str) -> str:
+        db = BibDatabase()
+        db.entries = [{
+            'ENTRYTYPE': 'article',
+            'ID': key,
+            'title': title,
+            'author': ' and '.join(authors),
+            'journal': f'arXiv preprint arXiv:{arxiv_id}',
+            'year': year
+        }]
+        writer = BibTexWriter()
+        writer.indent = ' '  # Indentation for entries
+        writer.comma_first = False  # Place the comma at the end of lines
+        return writer.write(db).strip()
+
+class AsyncContextManager:
+    async def __aenter__(self):
+        self._session = aiohttp.ClientSession()
+        return self._session
+
+    async def __aexit__(self, *_):
+        if self._session:
+            await self._session.close()
+
+class CitationGenerator:
+    def __init__(self, config: Config):
+        self.config = config
+        self.xml_parser = ArxivXmlParser()
+        self.async_context = AsyncContextManager()
+        self.llm = ChatGoogleGenerativeAI(
+            model="gemini-1.5-flash",
+            temperature=0.3,
+            google_api_key=config.gemini_api_key,
+            streaming=True
+        )
+        self.citation_chain = self._create_citation_chain()
+        self.generate_queries_chain = self._create_generate_queries_chain()
+        logger.remove()
+        logger.add(sys.stderr, level=config.log_level)  # Configure logger
+
+    def _create_citation_chain(self):
+        citation_prompt = PromptTemplate.from_template(
+            """Insert citations into the provided text using LaTeX \\cite{{key}} commands.
+
+You must not alter the original wording or structure of the text beyond adding citations.
+You must include all provided references at least once. Place citations at suitable points.
+
+Input text:
+{text}
+
+Available papers (cite each at least once):
+{papers}
+"""
+        )
+        return (
+            {"text": RunnablePassthrough(), "papers": RunnablePassthrough()}
+            | citation_prompt
+            | self.llm
+            | StrOutputParser()
+        )
+
+    def _create_generate_queries_chain(self):
+        generate_queries_prompt = PromptTemplate.from_template(
+            """Generate {num_queries} diverse academic search queries based on the given text.
+The queries should be concise and relevant.
+
+Requirements:
+1. Return ONLY a valid JSON array of strings.
+2. No additional text or formatting beyond JSON.
+3. Ensure uniqueness.
+
+Text: {text}
+"""
+        )
+        return (
+            {"text": RunnablePassthrough(), "num_queries": RunnablePassthrough()}
+            | generate_queries_prompt
+            | self.llm
+            | StrOutputParser()
+        )
+
+    async def generate_queries(self, text: str, num_queries: int) -> List[str]:
+        try:
+            response = await self.generate_queries_chain.ainvoke({
+                "text": text,
+                "num_queries": num_queries
+            })
+
+            content = response.strip()
+            if not content.startswith('['):
+                # Extract the JSON array if the model wrapped it in extra text
+                start = content.find('[')
+                end = content.rfind(']') + 1
+                if start >= 0 and end > start:
+                    content = content[start:end]
+            try:
+                queries = json.loads(content)
+                if isinstance(queries, list):
+                    return [q.strip() for q in queries if isinstance(q, str)][:num_queries]
+            except json.JSONDecodeError:
+                # Fall back to treating each non-bracket line as a query
+                lines = [line.strip() for line in content.split('\n')
+                         if line.strip() and not line.strip().startswith(('[', ']'))]
+                return lines[:num_queries]
+
+            return ["deep learning neural networks"]
+
+        except Exception as e:
+            logger.error(f"Error generating queries: {e}")
+            return ["deep learning neural networks"]
+
+    async def search_arxiv(self, session: aiohttp.ClientSession, query: str, max_results: int) -> List[Dict]:
+        try:
+            params = {
+                'search_query': f'all:{urllib.parse.quote(query)}',
+                'start': 0,
+                'max_results': max_results,
+                'sortBy': 'relevance',
+                'sortOrder': 'descending'
+            }
+
+            async with session.get(
+                self.config.arxiv_base_url + urllib.parse.urlencode(params),
+                headers=self.config.default_headers,
+                timeout=30
+            ) as response:
+                text_data = await response.text()
+                papers = self.xml_parser.parse_papers(text_data)
+                return papers
+        except Exception as e:
+            logger.error(f"Error searching ArXiv: {e}")
+            return []
+
+    async def fix_author_name(self, author: str) -> str:
+        if not re.search(r'[�]', author):
+            return author
+
+        try:
+            prompt = f"""Fix this author name that contains corrupted characters (�):
+
+Name: {author}
+
+Requirements:
+1. Return ONLY the fixed author name
+2. Use proper diacritical marks for names
+3. Consider common name patterns and languages
+4. If unsure about a character, use the most likely letter
+5. Maintain the format: "Lastname, Firstname"
+
+Example fixes:
+- "Gonz�lez" -> "González"
+- "Cristi�n" -> "Cristián"
+"""
+
+            response = await self.llm.ainvoke([{"role": "user", "content": prompt}])
+            fixed_name = response.content.strip()
+            return fixed_name if fixed_name else author
+
+        except Exception as e:
+            logger.error(f"Error fixing author name: {e}")
+            return author
+
+    async def format_bibtex_author_names(self, text: str) -> str:
+        try:
+            bib_database = bibtexparser.loads(text)
+            for entry in bib_database.entries:
+                if 'author' in entry:
+                    authors = entry['author'].split(' and ')
+                    cleaned_authors = []
+                    for author in authors:
+                        fixed_author = await self.fix_author_name(author)
+                        cleaned_authors.append(fixed_author)
+                    entry['author'] = ' and '.join(cleaned_authors)
+            writer = BibTexWriter()
+            writer.indent = ' '
+            writer.comma_first = False
+            return writer.write(bib_database).strip()
+        except Exception as e:
+            logger.error(f"Error cleaning BibTeX special characters: {e}")
+            return text
+
+    async def search_crossref(self, session: aiohttp.ClientSession, query: str, max_results: int) -> List[Dict]:
+        try:
+            cleaned_query = query.replace("'", "").replace('"', "")
+            if ' ' in cleaned_query:
+                cleaned_query = f'"{cleaned_query}"'
+
+            params = {
+                'query.bibliographic': cleaned_query,
+                'rows': max_results,
+                'select': 'DOI,title,author,published-print,container-title',
+                'sort': 'relevance',
+                'order': 'desc'
+            }
+
+            headers = {
+                'User-Agent': 'Mozilla/5.0 (compatible; CitationBot/1.0; mailto:[email protected])',
+                'Accept': 'application/json'
+            }
+
+            for attempt in range(self.config.max_retries):
+                try:
+                    async with session.get(
+                        self.config.crossref_base_url,
+                        params=params,
+                        headers=headers,
+                        timeout=30
+                    ) as response:
+                        if response.status == 429:
+                            delay = self.config.base_delay * (2 ** attempt)
+                            logger.warning(f"Rate limited by CrossRef. Retrying in {delay} seconds...")
+                            await asyncio.sleep(delay)
+                            continue
+
+                        response.raise_for_status()
+                        search_data = await response.json()
+                        items = search_data.get('message', {}).get('items', [])
+
+                        if not items:
+                            return []
+
+                        papers = []
+                        existing_keys = set()
+                        for item in items:
+                            doi = item.get('DOI')
+                            if not doi:
+                                continue
+
+                            try:
+                                bibtex_url = f"https://doi.org/{doi}"
+                                async with session.get(
+                                    bibtex_url,
+                                    headers={
+                                        'Accept': 'application/x-bibtex',
+                                        'User-Agent': 'Mozilla/5.0 (compatible; CitationBot/1.0; mailto:[email protected])'
+                                    },
+                                    timeout=30
+                                ) as bibtex_response:
+                                    if bibtex_response.status != 200:
+                                        continue
+
+                                    bibtex_text = await bibtex_response.text()
+
+                                    # Parse the BibTeX entry
+                                    bib_database = bibtexparser.loads(bibtex_text)
+                                    if not bib_database.entries:
+                                        continue
+                                    entry = bib_database.entries[0]
+
+                                    # Skip entries without 'title' or 'booktitle'
+                                    if 'title' not in entry and 'booktitle' not in entry:
+                                        continue
+
+                                    # Skip entries without 'author'
+                                    if 'author' not in entry:
+                                        continue
+
+                                    # Extract necessary fields
+                                    title = entry.get('title', 'No Title').replace('{', '').replace('}', '')
+                                    authors = entry.get('author', 'Unknown').replace('\n', ' ').replace('\t', ' ').strip()
+                                    year = entry.get('year', 'Unknown')
+
+                                    # Generate a unique BibTeX key
+                                    key = self._generate_unique_bibtex_key(entry, existing_keys)
+                                    entry['ID'] = key
+                                    existing_keys.add(key)
+
+                                    # Use BibTexWriter to format the entry
+                                    writer = BibTexWriter()
+                                    writer.indent = ' '
+                                    writer.comma_first = False
+                                    formatted_bibtex = writer.write(bib_database).strip()
+
+                                    papers.append({
+                                        'title': title,
+                                        'authors': authors,
+                                        'year': year,
+                                        'bibtex_key': key,
+                                        'bibtex_entry': formatted_bibtex
+                                    })
+                            except Exception as e:
+                                logger.error(f"Error processing CrossRef item: {e}")
+                                continue
+
+                        return papers
+
+                except aiohttp.ClientError as e:
+                    if attempt == self.config.max_retries - 1:
+                        logger.error(f"Max retries reached for CrossRef search. Error: {e}")
+                        raise
+                    delay = self.config.base_delay * (2 ** attempt)
+                    logger.warning(f"Client error during CrossRef search: {e}. Retrying in {delay} seconds...")
+                    await asyncio.sleep(delay)
+
+            return []  # all retries were rate-limited
+
+        except Exception as e:
+            logger.error(f"Error searching CrossRef: {e}")
+            return []
+
+    def _generate_unique_bibtex_key(self, entry: Dict, existing_keys: set) -> str:
+        entry_type = entry.get('ENTRYTYPE', '').lower()
+        author_field = entry.get('author', '')
+        year = entry.get('year', '')
+        authors = [a.strip() for a in author_field.split(' and ')]
+        first_author_last_name = authors[0].split(',')[0] if authors else 'unknown'
+
+        if entry_type == 'inbook':
+            # Use 'booktitle' for 'inbook' entries
+            booktitle = entry.get('booktitle', '')
+            title_word = re.sub(r'\W+', '', booktitle.split()[0]) if booktitle else 'untitled'
+        else:
+            # Use regular 'title' for other entries
+            title = entry.get('title', '')
+            title_word = re.sub(r'\W+', '', title.split()[0]) if title else 'untitled'
+
+        base_key = f"{first_author_last_name}{year}{title_word}"
+        # Ensure the key is unique
+        key = base_key
+        index = 1
+        while key in existing_keys:
+            key = f"{base_key}{index}"
+            index += 1
+        return key
+
+    async def process_text(self, text: str, num_queries: int, citations_per_query: int,
+                           use_arxiv: bool = True, use_crossref: bool = True) -> tuple[str, str]:
+        if not (use_arxiv or use_crossref):
+            return "Please select at least one source (ArXiv or CrossRef)", ""
+
+        num_queries = min(max(1, num_queries), self.config.max_queries)
+        citations_per_query = min(max(1, citations_per_query), self.config.max_citations_per_query)
+
+        queries = await self.generate_queries(text, num_queries)
+        if not queries:
+            return text, ""
+
+        async with self.async_context as session:
+            search_tasks = []
+            for query in queries:
+                if use_arxiv:
+                    search_tasks.append(self.search_arxiv(session, query, citations_per_query))
+                if use_crossref:
+                    search_tasks.append(self.search_crossref(session, query, citations_per_query))
+
+            results = await asyncio.gather(*search_tasks, return_exceptions=True)
+
+        papers = []
+        for r in results:
+            if not isinstance(r, Exception):
+                papers.extend(r)
+
+        # Deduplicate papers by BibTeX key
+        unique_papers = []
+        seen_keys = set()
+        for p in papers:
+            if p['bibtex_key'] not in seen_keys:
+                seen_keys.add(p['bibtex_key'])
+                unique_papers.append(p)
+        papers = unique_papers
+
+        if not papers:
+            return text, ""
+
+        try:
+            cited_text = await self.citation_chain.ainvoke({
+                "text": text,
+                "papers": json.dumps(papers, indent=2)
+            })
+
+            # Use bibtexparser to aggregate BibTeX entries
+            bib_database = BibDatabase()
+            for p in papers:
+                if 'bibtex_entry' in p:
+                    bib_db = bibtexparser.loads(p['bibtex_entry'])
+                    if bib_db.entries:
+                        bib_database.entries.append(bib_db.entries[0])
+                    else:
+                        logger.warning(f"Empty BibTeX entry for key: {p['bibtex_key']}")
+            writer = BibTexWriter()
+            writer.indent = ' '
+            writer.comma_first = False
+            bibtex_entries = writer.write(bib_database).strip()
+
+            return cited_text, bibtex_entries
+        except Exception as e:
+            logger.error(f"Error inserting citations: {e}")
+            return text, ""
+
+def create_gradio_interface() -> gr.Blocks:
+    async def process(api_key: str, text: str, num_queries: int, citations_per_query: int,
+                      use_arxiv: bool, use_crossref: bool) -> tuple[str, str]:
+        if not api_key.strip():
+            return "Please enter your Gemini API Key.", ""
+        if not text.strip():
+            return "Please enter text to process", ""
+        try:
+            config = Config(
+                gemini_api_key=api_key
+            )
+            citation_gen = CitationGenerator(config)
+            return await citation_gen.process_text(
+                text, num_queries, citations_per_query,
+                use_arxiv=use_arxiv, use_crossref=use_crossref
+            )
+        except ValueError as e:
+            return f"Input validation error: {str(e)}", ""
+        except Exception as e:
+            return f"Error: {str(e)}", ""
+
+    css = """
+    :root {
+        --primary: #6A7E76;
+        --primary-hover: #566961;
+        --bg: #FFFFFF;
+        --text: #454442;
+        --border: #B4B0AC;
+        --control-bg: #F5F3F0;
+    }
+
+    .container, .header, .input-group, .controls-row {
+        padding: 0.75rem;
+    }
+
+    .container {
+        max-width: 100%;
+        background: var(--bg);
+    }
+
+    .header {
+        text-align: center;
+        margin-bottom: 1rem;
+        background: var(--bg);
+        border-bottom: 1px solid var(--border);
+    }
+
+    .header h1 {
+        font-size: 1.5rem;
+        color: var(--primary);
+        font-weight: 500;
+        margin-bottom: 0.25rem;
+    }
+
+    .header p, label span {
+        font-size: 0.9rem;
+        color: var(--text);
+    }
+
+    .input-group {
+        border-radius: 4px;
+        border: 1px solid var(--border);
+        margin-bottom: 0.75rem;
+    }
+
+    .controls-row {
+        display: flex !important;
+        gap: 0.75rem;
+        margin-top: 0.5rem;
+    }
+
+    .source-controls {
+        display: flex;
+        gap: 0.75rem;
+        margin-top: 0.5rem;
+    }
+
+    .checkbox-group {
+        display: flex;
+        align-items: center;
+        gap: 0.5rem;
+    }
+
+    input[type="number"], textarea {
+        border: 1px solid var(--border);
+        border-radius: 4px;
+        padding: 0.5rem;
+        background: var(--control-bg);
+        color: var(--text);
+        font-size: 0.95rem;
+    }
+
+    .generate-btn {
+        background: var(--primary);
+        color: white;
+        padding: 0.5rem 1.5rem;
+        border-radius: 4px;
+        border: none;
+        font-size: 0.9rem;
+        transition: background 0.2s;
+        width: 100%;
+    }
+
+    .generate-btn:hover {
+        background: var(--primary-hover);
+    }
+
+    .output-container {
+        display: flex;
+        gap: 0.75rem;
+    }
+    """
+
+    with gr.Blocks(css=css, theme=gr.themes.Default()) as demo:
+        gr.HTML("""<div class="header">
+            <h1>📚 AutoCitation</h1>
+            <p>Insert citations into your academic text</p>
+        </div>""")
+
+        with gr.Group(elem_classes="input-group"):
+            api_key = gr.Textbox(
+                label="Gemini API Key",
+                placeholder="Enter your Gemini API key...",
+                type="password",
+                interactive=True
+            )
+            input_text = gr.Textbox(
+                label="Input Text",
+                placeholder="Paste or type your text here...",
+                lines=8
+            )
+            with gr.Row(elem_classes="controls-row"):
+                with gr.Column(scale=1):
+                    num_queries = gr.Number(
+                        label="Search Queries",
+                        value=3,
+                        minimum=1,
+                        maximum=5,  # Config.max_queries
+                        step=1
+                    )
+                with gr.Column(scale=1):
+                    citations_per_query = gr.Number(
+                        label="Citations per Query",
+                        value=1,
+                        minimum=1,
+                        maximum=10,  # Config.max_citations_per_query
+                        step=1
+                    )
+
+            with gr.Row(elem_classes="source-controls"):
+                with gr.Column(scale=1):
+                    use_arxiv = gr.Checkbox(
+                        label="Search ArXiv",
+                        value=True,
+                        elem_classes="checkbox-group"
+                    )
+                with gr.Column(scale=1):
+                    use_crossref = gr.Checkbox(
+                        label="Search CrossRef (Experimental)",
+                        value=True,
+                        elem_classes="checkbox-group"
+                    )
+                with gr.Column(scale=2):
+                    process_btn = gr.Button(
+                        "Generate",
+                        elem_classes="generate-btn"
+                    )
+
+        with gr.Group(elem_classes="output-group"):
+            with gr.Row():
+                with gr.Column(scale=1):
+                    cited_text = gr.Textbox(
+                        label="Generated Text",
+                        lines=10,
+                        show_copy_button=True
+                    )
+                with gr.Column(scale=1):
+                    bibtex = gr.Textbox(
+                        label="BibTeX References",
+                        lines=10,
+                        show_copy_button=True
+                    )
+
+        process_btn.click(
+            fn=process,
+            inputs=[api_key, input_text, num_queries, citations_per_query, use_arxiv, use_crossref],
+            outputs=[cited_text, bibtex]
+        )
+
+    return demo
+
+if __name__ == "__main__":
+    demo = create_gradio_interface()
+    try:
+        demo.launch(server_port=7860, share=False)
+    except KeyboardInterrupt:
+        print("\nShutting down server...")
+    except Exception as e:
+        print(f"Error starting server: {str(e)}")
example.png ADDED

Git LFS Details

  • SHA256: 0832e3d4e5e84a16787334a3d4e1ad390012ffd144838161a3afba15e912f167
  • Pointer size: 133 Bytes
  • Size of remote file: 12.8 MB
requirements.txt ADDED
@@ -0,0 +1,13 @@
+aiohttp>=3.9.0
+gradio>=4.0.0
+langchain-core>=0.1.0
+langchain-google-genai>=0.0.5
+loguru>=0.7.0
+python-dotenv>=1.0.0
+customtkinter>=5.2.0
+google-generativeai>=0.3.0
+pyperclip>=1.8.2
+tqdm>=4.66.0
+backoff>=2.2.1
+zhipuai>=0.2.0
+bibtexparser>=1.2.0
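
A quick sanity check after installing these pins — a throwaway sketch, not part of the commit; it assumes only that the distributions above are installed:

# check_env.py — hypothetical smoke test (not in this commit)
from importlib.metadata import version

# Print the installed version of each distribution autocitation.py imports
for dist in ("aiohttp", "gradio", "langchain-core", "langchain-google-genai",
             "loguru", "bibtexparser"):
    print(dist, version(dist))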