Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -11,9 +11,9 @@ from loguru import logger
|
|
11 |
|
12 |
import aiohttp
|
13 |
import gradio as gr
|
14 |
-
|
15 |
-
from
|
16 |
-
|
17 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
18 |
|
19 |
import bibtexparser
|
@@ -32,7 +32,7 @@ class Config:
|
|
32 |
default_headers: dict = field(default_factory=lambda: {
|
33 |
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
|
34 |
})
|
35 |
-
log_level: str = 'DEBUG'
|
36 |
|
37 |
class ArxivXmlParser:
|
38 |
NS = {
|
@@ -111,8 +111,8 @@ class ArxivXmlParser:
|
|
111 |
'year': year
|
112 |
}]
|
113 |
writer = BibTexWriter()
|
114 |
-
writer.indent = ' '
|
115 |
-
writer.comma_first = False
|
116 |
return writer.write(db).strip()
|
117 |
|
118 |
class AsyncContextManager:
|
@@ -135,13 +135,7 @@ class CitationGenerator:
|
|
135 |
google_api_key=config.gemini_api_key,
|
136 |
streaming=True
|
137 |
)
|
138 |
-
self.
|
139 |
-
self.generate_queries_chain = self._create_generate_queries_chain()
|
140 |
-
logger.remove()
|
141 |
-
logger.add(sys.stderr, level=config.log_level) # Configure logger
|
142 |
-
|
143 |
-
def _create_citation_chain(self):
|
144 |
-
citation_prompt = PromptTemplate.from_template(
|
145 |
"""Insert citations into the provided text using LaTeX \\cite{{key}} commands.
|
146 |
|
147 |
You must not alter the original wording or structure of the text beyond adding citations.
|
@@ -154,15 +148,8 @@ class CitationGenerator:
|
|
154 |
{papers}
|
155 |
"""
|
156 |
)
|
157 |
-
|
158 |
-
|
159 |
-
| citation_prompt
|
160 |
-
| self.llm
|
161 |
-
| StrOutputParser()
|
162 |
-
)
|
163 |
-
|
164 |
-
def _create_generate_queries_chain(self):
|
165 |
-
generate_queries_prompt = PromptTemplate.from_template(
|
166 |
"""Generate {num_queries} diverse academic search queries based on the given text.
|
167 |
The queries should be concise and relevant.
|
168 |
|
@@ -174,20 +161,18 @@ class CitationGenerator:
|
|
174 |
Text: {text}
|
175 |
"""
|
176 |
)
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
| self.llm
|
181 |
-
| StrOutputParser()
|
182 |
-
)
|
183 |
|
184 |
async def generate_queries(self, text: str, num_queries: int) -> List[str]:
|
|
|
|
|
|
|
|
|
185 |
try:
|
186 |
-
|
187 |
-
|
188 |
-
"num_queries": num_queries
|
189 |
-
})
|
190 |
-
|
191 |
content = response.strip()
|
192 |
if not content.startswith('['):
|
193 |
start = content.find('[')
|
@@ -206,7 +191,7 @@ class CitationGenerator:
|
|
206 |
return ["deep learning neural networks"]
|
207 |
|
208 |
except Exception as e:
|
209 |
-
logger.error(f"Error generating queries: {e}")
|
210 |
return ["deep learning neural networks"]
|
211 |
|
212 |
async def search_arxiv(self, session: aiohttp.ClientSession, query: str, max_results: int) -> List[Dict]:
|
@@ -218,7 +203,6 @@ class CitationGenerator:
|
|
218 |
'sortBy': 'relevance',
|
219 |
'sortOrder': 'descending'
|
220 |
}
|
221 |
-
|
222 |
async with session.get(
|
223 |
self.config.arxiv_base_url + urllib.parse.urlencode(params),
|
224 |
headers=self.config.default_headers,
|
@@ -234,7 +218,6 @@ class CitationGenerator:
|
|
234 |
async def fix_author_name(self, author: str) -> str:
|
235 |
if not re.search(r'[�]', author):
|
236 |
return author
|
237 |
-
|
238 |
try:
|
239 |
prompt = f"""Fix this author name that contains corrupted characters (�):
|
240 |
|
@@ -244,18 +227,12 @@ class CitationGenerator:
|
|
244 |
1. Return ONLY the fixed author name
|
245 |
2. Use proper diacritical marks for names
|
246 |
3. Consider common name patterns and languages
|
247 |
-
4. If unsure
|
248 |
5. Maintain the format: "Lastname, Firstname"
|
249 |
-
|
250 |
-
Example fixes:
|
251 |
-
- "Gonz�lez" -> "González"
|
252 |
-
- "Cristi�n" -> "Cristi��n"
|
253 |
"""
|
254 |
-
|
255 |
-
|
256 |
-
fixed_name = response.content.strip()
|
257 |
return fixed_name if fixed_name else author
|
258 |
-
|
259 |
except Exception as e:
|
260 |
logger.error(f"Error fixing author name: {e}")
|
261 |
return author
|
@@ -276,7 +253,7 @@ class CitationGenerator:
|
|
276 |
writer.comma_first = False
|
277 |
return writer.write(bib_database).strip()
|
278 |
except Exception as e:
|
279 |
-
logger.error(f"Error cleaning BibTeX special characters: {e}")
|
280 |
return text
|
281 |
|
282 |
async def search_crossref(self, session: aiohttp.ClientSession, query: str, max_results: int) -> List[Dict]:
|
@@ -341,31 +318,24 @@ class CitationGenerator:
|
|
341 |
|
342 |
bibtex_text = await bibtex_response.text()
|
343 |
|
344 |
-
# Parse the BibTeX entry
|
345 |
bib_database = bibtexparser.loads(bibtex_text)
|
346 |
if not bib_database.entries:
|
347 |
continue
|
348 |
entry = bib_database.entries[0]
|
349 |
|
350 |
-
# Check if 'title' or 'booktitle' is present
|
351 |
if 'title' not in entry and 'booktitle' not in entry:
|
352 |
-
continue
|
353 |
-
|
354 |
-
# Check if 'author' is present
|
355 |
if 'author' not in entry:
|
356 |
-
continue
|
357 |
|
358 |
-
# Extract necessary fields
|
359 |
title = entry.get('title', 'No Title').replace('{', '').replace('}', '')
|
360 |
authors = entry.get('author', 'Unknown').replace('\n', ' ').replace('\t', ' ').strip()
|
361 |
year = entry.get('year', 'Unknown')
|
362 |
|
363 |
-
# Generate a unique BibTeX key
|
364 |
key = self._generate_unique_bibtex_key(entry, existing_keys)
|
365 |
entry['ID'] = key
|
366 |
existing_keys.add(key)
|
367 |
|
368 |
-
# Use BibTexWriter to format the entry
|
369 |
writer = BibTexWriter()
|
370 |
writer.indent = ' '
|
371 |
writer.comma_first = False
|
@@ -378,10 +348,9 @@ class CitationGenerator:
|
|
378 |
'bibtex_key': key,
|
379 |
'bibtex_entry': formatted_bibtex
|
380 |
})
|
381 |
-
except Exception as e:
|
382 |
-
logger.error(f"Error processing CrossRef item: {e}") # Replace print with logger
|
383 |
-
continue
|
384 |
|
|
|
|
|
385 |
return papers
|
386 |
|
387 |
except aiohttp.ClientError as e:
|
@@ -393,7 +362,7 @@ class CitationGenerator:
|
|
393 |
await asyncio.sleep(delay)
|
394 |
|
395 |
except Exception as e:
|
396 |
-
logger.error(f"Error searching CrossRef: {e}")
|
397 |
return []
|
398 |
|
399 |
def _generate_unique_bibtex_key(self, entry: Dict, existing_keys: set) -> str:
|
@@ -402,18 +371,15 @@ class CitationGenerator:
|
|
402 |
year = entry.get('year', '')
|
403 |
authors = [a.strip() for a in author_field.split(' and ')]
|
404 |
first_author_last_name = authors[0].split(',')[0] if authors else 'unknown'
|
405 |
-
|
406 |
if entry_type == 'inbook':
|
407 |
-
# Use 'booktitle' for 'inbook' entries
|
408 |
booktitle = entry.get('booktitle', '')
|
409 |
title_word = re.sub(r'\W+', '', booktitle.split()[0]) if booktitle else 'untitled'
|
410 |
else:
|
411 |
-
# Use regular 'title' for other entries
|
412 |
title = entry.get('title', '')
|
413 |
title_word = re.sub(r'\W+', '', title.split()[0]) if title else 'untitled'
|
414 |
-
|
415 |
base_key = f"{first_author_last_name}{year}{title_word}"
|
416 |
-
# Ensure the key is unique
|
417 |
key = base_key
|
418 |
index = 1
|
419 |
while key in existing_keys:
|
@@ -422,70 +388,93 @@ class CitationGenerator:
|
|
422 |
return key
|
423 |
|
424 |
async def process_text(self, text: str, num_queries: int, citations_per_query: int,
|
425 |
-
|
426 |
if not (use_arxiv or use_crossref):
|
427 |
return "Please select at least one source (ArXiv or CrossRef)", ""
|
428 |
|
429 |
num_queries = min(max(1, num_queries), self.config.max_queries)
|
430 |
citations_per_query = min(max(1, citations_per_query), self.config.max_citations_per_query)
|
431 |
|
432 |
-
|
433 |
-
|
434 |
-
return text, ""
|
435 |
-
|
436 |
-
async with self.async_context as session:
|
437 |
-
search_tasks = []
|
438 |
-
for query in queries:
|
439 |
-
if use_arxiv:
|
440 |
-
search_tasks.append(self.search_arxiv(session, query, citations_per_query))
|
441 |
-
if use_crossref:
|
442 |
-
search_tasks.append(self.search_crossref(session, query, citations_per_query))
|
443 |
|
444 |
-
|
445 |
-
|
446 |
-
|
447 |
-
|
448 |
-
|
449 |
-
|
450 |
-
|
451 |
-
|
452 |
-
|
453 |
-
|
454 |
-
|
455 |
-
|
456 |
-
|
457 |
-
|
458 |
-
|
459 |
-
|
460 |
-
|
|
|
|
|
|
|
|
|
|
|
461 |
|
462 |
-
|
463 |
-
|
464 |
-
|
465 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
466 |
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
467 |
|
468 |
-
|
469 |
-
|
470 |
-
|
471 |
-
|
472 |
-
|
473 |
-
|
474 |
-
|
475 |
-
|
476 |
-
logger.warning(f"Empty BibTeX entry for key: {p['bibtex_key']}")
|
477 |
-
writer = BibTexWriter()
|
478 |
-
writer.indent = ' '
|
479 |
-
writer.comma_first = False
|
480 |
-
bibtex_entries = writer.write(bib_database).strip()
|
481 |
-
|
482 |
-
return cited_text, bibtex_entries
|
483 |
-
except Exception as e:
|
484 |
-
logger.error(f"Error inserting citations: {e}") # Replace print with logger
|
485 |
-
return text, ""
|
486 |
|
487 |
def create_gradio_interface() -> gr.Interface:
|
488 |
-
# Removed CitationGenerator initialization here
|
489 |
async def process(api_key: str, text: str, num_queries: int, citations_per_query: int,
|
490 |
use_arxiv: bool, use_crossref: bool) -> tuple[str, str]:
|
491 |
if not api_key.strip():
|
@@ -600,12 +589,11 @@ def create_gradio_interface() -> gr.Interface:
|
|
600 |
|
601 |
with gr.Blocks(css=css, theme=gr.themes.Default()) as demo:
|
602 |
gr.HTML("""<div class="header">
|
603 |
-
<h1>📚 AutoCitation</h1>
|
604 |
-
<p>Insert citations into your academic text</p>
|
605 |
-
</div>""")
|
606 |
|
607 |
with gr.Group(elem_classes="input-group"):
|
608 |
-
# Added API Key input field
|
609 |
api_key = gr.Textbox(
|
610 |
label="Gemini API Key",
|
611 |
placeholder="Enter your Gemini API key...",
|
@@ -623,7 +611,7 @@ def create_gradio_interface() -> gr.Interface:
|
|
623 |
label="Search Queries",
|
624 |
value=3,
|
625 |
minimum=1,
|
626 |
-
maximum=
|
627 |
step=1
|
628 |
)
|
629 |
with gr.Column(scale=1):
|
@@ -631,7 +619,7 @@ def create_gradio_interface() -> gr.Interface:
|
|
631 |
label="Citations per Query",
|
632 |
value=1,
|
633 |
minimum=1,
|
634 |
-
maximum=
|
635 |
step=1
|
636 |
)
|
637 |
|
@@ -669,7 +657,6 @@ def create_gradio_interface() -> gr.Interface:
|
|
669 |
show_copy_button=True
|
670 |
)
|
671 |
|
672 |
-
# Updated the inputs and outputs
|
673 |
process_btn.click(
|
674 |
fn=process,
|
675 |
inputs=[api_key, input_text, num_queries, citations_per_query, use_arxiv, use_crossref],
|
@@ -679,6 +666,10 @@ def create_gradio_interface() -> gr.Interface:
|
|
679 |
return demo
|
680 |
|
681 |
if __name__ == "__main__":
|
682 |
-
|
683 |
demo = create_gradio_interface()
|
684 |
-
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
import aiohttp
|
13 |
import gradio as gr
|
14 |
+
|
15 |
+
from langchain.prompts import PromptTemplate
|
16 |
+
|
17 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
18 |
|
19 |
import bibtexparser
|
|
|
32 |
default_headers: dict = field(default_factory=lambda: {
|
33 |
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
|
34 |
})
|
35 |
+
log_level: str = 'DEBUG'
|
36 |
|
37 |
class ArxivXmlParser:
|
38 |
NS = {
|
|
|
111 |
'year': year
|
112 |
}]
|
113 |
writer = BibTexWriter()
|
114 |
+
writer.indent = ' '
|
115 |
+
writer.comma_first = False
|
116 |
return writer.write(db).strip()
|
117 |
|
118 |
class AsyncContextManager:
|
|
|
135 |
google_api_key=config.gemini_api_key,
|
136 |
streaming=True
|
137 |
)
|
138 |
+
self.citation_prompt = PromptTemplate.from_template(
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
"""Insert citations into the provided text using LaTeX \\cite{{key}} commands.
|
140 |
|
141 |
You must not alter the original wording or structure of the text beyond adding citations.
|
|
|
148 |
{papers}
|
149 |
"""
|
150 |
)
|
151 |
+
|
152 |
+
self.generate_queries_prompt = PromptTemplate.from_template(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
153 |
"""Generate {num_queries} diverse academic search queries based on the given text.
|
154 |
The queries should be concise and relevant.
|
155 |
|
|
|
161 |
Text: {text}
|
162 |
"""
|
163 |
)
|
164 |
+
|
165 |
+
logger.remove()
|
166 |
+
logger.add(sys.stderr, level=config.log_level)
|
|
|
|
|
|
|
167 |
|
168 |
async def generate_queries(self, text: str, num_queries: int) -> List[str]:
|
169 |
+
input_map = {
|
170 |
+
"text": text,
|
171 |
+
"num_queries": num_queries
|
172 |
+
}
|
173 |
try:
|
174 |
+
prompt = self.generate_queries_prompt.format(**input_map)
|
175 |
+
response = await self.llm.apredict(prompt)
|
|
|
|
|
|
|
176 |
content = response.strip()
|
177 |
if not content.startswith('['):
|
178 |
start = content.find('[')
|
|
|
191 |
return ["deep learning neural networks"]
|
192 |
|
193 |
except Exception as e:
|
194 |
+
logger.error(f"Error generating queries: {e}")
|
195 |
return ["deep learning neural networks"]
|
196 |
|
197 |
async def search_arxiv(self, session: aiohttp.ClientSession, query: str, max_results: int) -> List[Dict]:
|
|
|
203 |
'sortBy': 'relevance',
|
204 |
'sortOrder': 'descending'
|
205 |
}
|
|
|
206 |
async with session.get(
|
207 |
self.config.arxiv_base_url + urllib.parse.urlencode(params),
|
208 |
headers=self.config.default_headers,
|
|
|
218 |
async def fix_author_name(self, author: str) -> str:
|
219 |
if not re.search(r'[�]', author):
|
220 |
return author
|
|
|
221 |
try:
|
222 |
prompt = f"""Fix this author name that contains corrupted characters (�):
|
223 |
|
|
|
227 |
1. Return ONLY the fixed author name
|
228 |
2. Use proper diacritical marks for names
|
229 |
3. Consider common name patterns and languages
|
230 |
+
4. If unsure, use the most likely letter
|
231 |
5. Maintain the format: "Lastname, Firstname"
|
|
|
|
|
|
|
|
|
232 |
"""
|
233 |
+
response = await self.llm.apredict(prompt)
|
234 |
+
fixed_name = response.strip()
|
|
|
235 |
return fixed_name if fixed_name else author
|
|
|
236 |
except Exception as e:
|
237 |
logger.error(f"Error fixing author name: {e}")
|
238 |
return author
|
|
|
253 |
writer.comma_first = False
|
254 |
return writer.write(bib_database).strip()
|
255 |
except Exception as e:
|
256 |
+
logger.error(f"Error cleaning BibTeX special characters: {e}")
|
257 |
return text
|
258 |
|
259 |
async def search_crossref(self, session: aiohttp.ClientSession, query: str, max_results: int) -> List[Dict]:
|
|
|
318 |
|
319 |
bibtex_text = await bibtex_response.text()
|
320 |
|
|
|
321 |
bib_database = bibtexparser.loads(bibtex_text)
|
322 |
if not bib_database.entries:
|
323 |
continue
|
324 |
entry = bib_database.entries[0]
|
325 |
|
|
|
326 |
if 'title' not in entry and 'booktitle' not in entry:
|
327 |
+
continue
|
|
|
|
|
328 |
if 'author' not in entry:
|
329 |
+
continue
|
330 |
|
|
|
331 |
title = entry.get('title', 'No Title').replace('{', '').replace('}', '')
|
332 |
authors = entry.get('author', 'Unknown').replace('\n', ' ').replace('\t', ' ').strip()
|
333 |
year = entry.get('year', 'Unknown')
|
334 |
|
|
|
335 |
key = self._generate_unique_bibtex_key(entry, existing_keys)
|
336 |
entry['ID'] = key
|
337 |
existing_keys.add(key)
|
338 |
|
|
|
339 |
writer = BibTexWriter()
|
340 |
writer.indent = ' '
|
341 |
writer.comma_first = False
|
|
|
348 |
'bibtex_key': key,
|
349 |
'bibtex_entry': formatted_bibtex
|
350 |
})
|
|
|
|
|
|
|
351 |
|
352 |
+
except Exception as e:
|
353 |
+
logger.error(f"Error processing CrossRef item: {e}")
|
354 |
return papers
|
355 |
|
356 |
except aiohttp.ClientError as e:
|
|
|
362 |
await asyncio.sleep(delay)
|
363 |
|
364 |
except Exception as e:
|
365 |
+
logger.error(f"Error searching CrossRef: {e}")
|
366 |
return []
|
367 |
|
368 |
def _generate_unique_bibtex_key(self, entry: Dict, existing_keys: set) -> str:
|
|
|
371 |
year = entry.get('year', '')
|
372 |
authors = [a.strip() for a in author_field.split(' and ')]
|
373 |
first_author_last_name = authors[0].split(',')[0] if authors else 'unknown'
|
374 |
+
|
375 |
if entry_type == 'inbook':
|
|
|
376 |
booktitle = entry.get('booktitle', '')
|
377 |
title_word = re.sub(r'\W+', '', booktitle.split()[0]) if booktitle else 'untitled'
|
378 |
else:
|
|
|
379 |
title = entry.get('title', '')
|
380 |
title_word = re.sub(r'\W+', '', title.split()[0]) if title else 'untitled'
|
381 |
+
|
382 |
base_key = f"{first_author_last_name}{year}{title_word}"
|
|
|
383 |
key = base_key
|
384 |
index = 1
|
385 |
while key in existing_keys:
|
|
|
388 |
return key
|
389 |
|
390 |
async def process_text(self, text: str, num_queries: int, citations_per_query: int,
|
391 |
+
use_arxiv: bool = True, use_crossref: bool = True) -> tuple[str, str]:
|
392 |
if not (use_arxiv or use_crossref):
|
393 |
return "Please select at least one source (ArXiv or CrossRef)", ""
|
394 |
|
395 |
num_queries = min(max(1, num_queries), self.config.max_queries)
|
396 |
citations_per_query = min(max(1, citations_per_query), self.config.max_citations_per_query)
|
397 |
|
398 |
+
async def generate_queries_tool(input_data: dict):
|
399 |
+
return await self.generate_queries(input_data["text"], input_data["num_queries"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
400 |
|
401 |
+
async def search_papers_tool(input_data: dict):
|
402 |
+
queries = input_data["queries"]
|
403 |
+
papers = []
|
404 |
+
async with self.async_context as session:
|
405 |
+
search_tasks = []
|
406 |
+
for q in queries:
|
407 |
+
if input_data["use_arxiv"]:
|
408 |
+
search_tasks.append(self.search_arxiv(session, q, input_data["citations_per_query"]))
|
409 |
+
if input_data["use_crossref"]:
|
410 |
+
search_tasks.append(self.search_crossref(session, q, input_data["citations_per_query"]))
|
411 |
+
results = await asyncio.gather(*search_tasks, return_exceptions=True)
|
412 |
+
for r in results:
|
413 |
+
if not isinstance(r, Exception):
|
414 |
+
papers.extend(r)
|
415 |
+
# Deduplicate
|
416 |
+
unique_papers = []
|
417 |
+
seen_keys = set()
|
418 |
+
for p in papers:
|
419 |
+
if p['bibtex_key'] not in seen_keys:
|
420 |
+
seen_keys.add(p['bibtex_key'])
|
421 |
+
unique_papers.append(p)
|
422 |
+
return unique_papers
|
423 |
|
424 |
+
async def cite_text_tool(input_data: dict):
|
425 |
+
try:
|
426 |
+
citation_input = {
|
427 |
+
"text": input_data["text"],
|
428 |
+
"papers": json.dumps(input_data["papers"], indent=2)
|
429 |
+
}
|
430 |
+
prompt = self.citation_prompt.format(**citation_input)
|
431 |
+
response = await self.llm.apredict(prompt)
|
432 |
+
cited_text = response.strip()
|
433 |
+
|
434 |
+
# Aggregate BibTeX entries
|
435 |
+
bib_database = BibDatabase()
|
436 |
+
for p in input_data["papers"]:
|
437 |
+
if 'bibtex_entry' in p:
|
438 |
+
bib_db = bibtexparser.loads(p['bibtex_entry'])
|
439 |
+
if bib_db.entries:
|
440 |
+
bib_database.entries.append(bib_db.entries[0])
|
441 |
+
else:
|
442 |
+
logger.warning(f"Empty BibTeX entry for key: {p['bibtex_key']}")
|
443 |
+
writer = BibTexWriter()
|
444 |
+
writer.indent = ' '
|
445 |
+
writer.comma_first = False
|
446 |
+
bibtex_entries = writer.write(bib_database).strip()
|
447 |
+
return cited_text, bibtex_entries
|
448 |
+
except Exception as e:
|
449 |
+
logger.error(f"Error inserting citations: {e}")
|
450 |
+
return input_data["text"], ""
|
451 |
+
|
452 |
+
async def agent_run(input_data: dict):
|
453 |
+
queries = await generate_queries_tool(input_data)
|
454 |
+
papers = await search_papers_tool({
|
455 |
+
"queries": queries,
|
456 |
+
"citations_per_query": input_data["citations_per_query"],
|
457 |
+
"use_arxiv": input_data["use_arxiv"],
|
458 |
+
"use_crossref": input_data["use_crossref"]
|
459 |
})
|
460 |
+
if not papers:
|
461 |
+
return input_data["text"], ""
|
462 |
+
cited_text, final_bibtex = await cite_text_tool({
|
463 |
+
"text": input_data["text"],
|
464 |
+
"papers": papers
|
465 |
+
})
|
466 |
+
return cited_text, final_bibtex
|
467 |
|
468 |
+
final_text, final_bibtex = await agent_run({
|
469 |
+
"text": text,
|
470 |
+
"num_queries": num_queries,
|
471 |
+
"citations_per_query": citations_per_query,
|
472 |
+
"use_arxiv": use_arxiv,
|
473 |
+
"use_crossref": use_crossref
|
474 |
+
})
|
475 |
+
return final_text, final_bibtex
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
476 |
|
477 |
def create_gradio_interface() -> gr.Interface:
|
|
|
478 |
async def process(api_key: str, text: str, num_queries: int, citations_per_query: int,
|
479 |
use_arxiv: bool, use_crossref: bool) -> tuple[str, str]:
|
480 |
if not api_key.strip():
|
|
|
589 |
|
590 |
with gr.Blocks(css=css, theme=gr.themes.Default()) as demo:
|
591 |
gr.HTML("""<div class="header">
|
592 |
+
<h1>📚 AutoCitation</h1>
|
593 |
+
<p>Insert citations into your academic text</p>
|
594 |
+
</div>""")
|
595 |
|
596 |
with gr.Group(elem_classes="input-group"):
|
|
|
597 |
api_key = gr.Textbox(
|
598 |
label="Gemini API Key",
|
599 |
placeholder="Enter your Gemini API key...",
|
|
|
611 |
label="Search Queries",
|
612 |
value=3,
|
613 |
minimum=1,
|
614 |
+
maximum=Config.max_queries,
|
615 |
step=1
|
616 |
)
|
617 |
with gr.Column(scale=1):
|
|
|
619 |
label="Citations per Query",
|
620 |
value=1,
|
621 |
minimum=1,
|
622 |
+
maximum=Config.max_citations_per_query,
|
623 |
step=1
|
624 |
)
|
625 |
|
|
|
657 |
show_copy_button=True
|
658 |
)
|
659 |
|
|
|
660 |
process_btn.click(
|
661 |
fn=process,
|
662 |
inputs=[api_key, input_text, num_queries, citations_per_query, use_arxiv, use_crossref],
|
|
|
666 |
return demo
|
667 |
|
668 |
if __name__ == "__main__":
|
|
|
669 |
demo = create_gradio_interface()
|
670 |
+
try:
|
671 |
+
demo.launch(server_port=7860, share=False)
|
672 |
+
except KeyboardInterrupt:
|
673 |
+
print("\nShutting down server...")
|
674 |
+
except Exception as e:
|
675 |
+
print(f"Error starting server: {str(e)}")
|