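# NOTE: this file appears to have been captured from a hosted "Spaces" page;
# the original first lines were the page's status banner ("Spaces: Sleeping"),
# not part of the program.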
import html
import json
import os
import re
import shutil

import gradio as gr
import pandas as pd
import requests
import torch
import transformers
from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM, pipeline
# Run inference on CUDA when torch reports it as available, otherwise on CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Checkpoint used by both the classification pipeline and the tokenizer below.
editorial_model = "PleIAs/Bibliography-Formatter"

# Token-classification pipeline; its aggregated results are consumed
# downstream through their 'entity_group' and 'word' fields.
token_classifier = pipeline(
    "token-classification",
    model=editorial_model,
    aggregation_strategy="simple",
    device=device,
)

# Stand-alone tokenizer, used in this file to count tokens when deciding
# whether (and where) to split the input text.
tokenizer = AutoTokenizer.from_pretrained(editorial_model, model_max_length=512)
# Stylesheet prepended to the HTML fragment returned by the predictor.
# Each classified span is rendered as a ".manuscript" row made of a grey,
# right-aligned ".annotation" label and a ".content" cell.
css = """
<style>
.manuscript {
    display: flex;
    margin-bottom: 10px;
    align-items: baseline;
}
.annotation {
    width: 15%;
    padding-right: 20px;
    color: grey !important;
    font-style: italic;
    text-align: right;
}
.content {
    width: 80%;
}
h2 {
    margin: 0;
    font-size: 1.5em;
}
.title-content h2 {
    font-weight: bold;
}
.bibliography-content {
    color:darkgreen !important;
    margin-top: -5px; /* Adjust if needed to align with annotation */
}
.paratext-content {
    color:#a4a4a4 !important;
    margin-top: -5px; /* Adjust if needed to align with annotation */
}
</style>
"""
# Clean-up applied to the 'word' column of the classification results.
def preprocess_text(text):
    """Return *text* with tag-like markup removed and whitespace normalised.

    The substitutions run in order: anything shaped like ``<...>`` is dropped,
    newlines become spaces, and every run of whitespace collapses to a single
    space. Leading and trailing whitespace is trimmed from the result.
    """
    substitutions = (
        (r'<[^>]+>', ''),   # tag-like markup
        (r'\n', ' '),       # newlines
        (r'\s+', ' '),      # runs of whitespace
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)
    return text.strip()
def split_text(text, max_tokens=500, count_tokens=None):
    """Split *text* into chunks that each hold at most *max_tokens* tokens.

    Lines (separated by ``"\\n"``) are packed greedily: a line is added to the
    current chunk as long as the chunk stays within the budget, otherwise a
    new chunk is started. Any chunk that is still too long (a single line
    longer than the budget) is then halved recursively at whitespace until
    every piece fits.

    Args:
        text: The text to split.
        max_tokens: Maximum number of tokens allowed per chunk.
        count_tokens: Optional callable mapping a string to its token count.
            Defaults to counting with the module-level ``tokenizer``.

    Returns:
        A list of chunks in their original order. Chunks that fit without
        further splitting are returned unmodified; pieces produced by the
        recursive halving are stripped of surrounding whitespace. An empty
        input yields an empty list.
    """
    if count_tokens is None:
        def measure(segment):
            return len(tokenizer.tokenize(segment))
    else:
        measure = count_tokens

    def halve(segment):
        """Recursively split *segment* near its middle until each piece fits."""
        segment = segment.strip()
        if not segment:
            return []
        # A single character cannot be split any further.
        if len(segment) < 2 or measure(segment) <= max_tokens:
            return [segment]
        middle = len(segment) // 2
        # Prefer the first whitespace at/after the midpoint, then the nearest
        # one before it; with no whitespace at all, cut at the midpoint.
        forward = (i for i in range(middle, len(segment)) if segment[i].isspace())
        backward = (i for i in range(middle - 1, 0, -1) if segment[i].isspace())
        cut = next(forward, None)
        if cut is None:
            cut = next(backward, middle)
        # 1 <= cut < len(segment), so both halves are strictly shorter.
        return halve(segment[:cut]) + halve(segment[cut:])

    # Greedy packing of whole lines.
    chunks = []
    current = ""
    for part in text.split("\n"):
        candidate = current + "\n" + part if current else part
        if measure(candidate) <= max_tokens:
            current = candidate
        else:
            if current:
                chunks.append(current)
            current = part
    if current:
        chunks.append(current)

    # Every chunk is checked, not only a lone one: an oversized line can sit
    # between lines that fit.
    bounded = []
    for chunk in chunks:
        if measure(chunk) <= max_tokens:
            bounded.append(chunk)
        else:
            bounded.extend(halve(chunk))
    return bounded
def extract_year(text):
    """Return the first stand-alone four-digit number in *text*, or None.

    The match must be delimited by word boundaries, so four digits embedded
    in a longer run of digits do not count. The value is returned as a string.
    """
    found = re.search(r'\b(\d{4})\b', text)
    if found is None:
        return None
    return found.group(1)
def create_bibtex_entry(data):
    """Render a dictionary of bibliographic fields as a BibTeX entry string.

    Args:
        data: Mapping from field name (e.g. ``'Author'``, ``'Journal'``) to
            its text. A ``'None'`` key, if present, holds unlabelled text; it
            is only searched for a year and never emitted as a field. The
            mapping is not modified.

    Returns:
        The BibTeX entry. The entry type is ``article`` when a ``'Journal'``
        field is present, ``incollection`` when a ``'Booktitle'`` field is
        present, and ``book`` otherwise. Fields whose value is blank are
        omitted; field names are lower-cased.
    """
    # Work on a copy so the caller's dictionary is left untouched.
    fields = dict(data)

    # Determine the entry type.
    if 'Journal' in fields:
        entry_type = 'article'
    elif 'Booktitle' in fields:
        entry_type = 'incollection'
    else:
        entry_type = 'book'

    # A year found in the unlabelled text only fills in a missing 'Year'.
    none_year = extract_year(fields.pop('None', ''))
    if none_year and 'Year' not in fields:
        fields['Year'] = none_year

    # The citation key uses the year that is actually emitted, falling back
    # to the one found in the unlabelled text.
    key_year = extract_year(fields.get('Year', '')) or none_year

    # Keep only word characters in the key: a first author token such as
    # "Dupont," would otherwise put a comma inside the citation key.
    author_words = fields.get('Author', '').split()
    first_author = re.sub(r'\W+', '', author_words[0]) if author_words else ''
    if not first_author:
        first_author = 'Unknown'
    bibtex_id = f"{first_author}{key_year}" if key_year else first_author

    header = f"@{entry_type}{{{bibtex_id},\n"
    lines = [
        f" {key.lower()} = {{{value.strip()}}}"
        for key, value in fields.items()
        if value.strip()
    ]
    if not lines:
        # No usable field: keep the comma that terminates the citation key.
        return header + "}"
    return header + ",\n".join(lines) + "\n}"
def transform_chunks(marianne_segmentation):
    """Turn classified spans into an HTML listing and a BibTeX entry.

    Args:
        marianne_segmentation: Anything ``pd.DataFrame`` accepts, expected to
            provide an ``'entity_group'`` and a ``'word'`` column. Input
            without those columns (e.g. an empty result) is treated as
            containing no spans.

    Returns:
        A ``(html, bibtex)`` tuple. ``html`` has one ``.manuscript`` row per
        retained span, with text HTML-escaped; ``bibtex`` is the string built
        by ``create_bibtex_entry`` from the text accumulated per entity group.
    """
    frame = pd.DataFrame(marianne_segmentation)

    if {'entity_group', 'word'}.issubset(frame.columns):
        # .copy() so the column assignment below does not write into a slice.
        frame = frame[frame['entity_group'] != 'separator'].copy()
        frame['word'] = (
            frame['word']
            .astype(str)
            .str.replace('¶', '\n', regex=False)
            .apply(preprocess_text)
        )
        # preprocess_text strips its result, so blank spans are exactly ''.
        frame = frame[frame['word'] != '']
        rows = zip(frame['entity_group'], frame['word'])
    else:
        rows = ()

    html_output = []
    bibtex_data = {}
    current_entity = None
    for entity_group, word in rows:
        if entity_group != 'None':
            if entity_group in bibtex_data:
                bibtex_data[entity_group] += ' ' + word
            else:
                bibtex_data[entity_group] = word
            current_entity = entity_group
        elif current_entity:
            # Unlabelled text is attached to the most recent labelled field.
            bibtex_data[current_entity] += ' ' + word
        else:
            # Unlabelled text seen before any labelled field.
            bibtex_data['None'] = bibtex_data.get('None', '') + ' ' + word

        # Escape for display only; the BibTeX data keeps the raw text.
        annotation = html.escape("[" + entity_group.capitalize() + "]", quote=False)
        content = html.escape(word, quote=False)
        html_output.append(
            f'<div class="manuscript"><div class="annotation">{annotation}</div>'
            f'<div class="content">{content}</div></div>'
        )

    bibtex_entry = create_bibtex_entry(bibtex_data)
    final_html = '\n'.join(html_output)
    return final_html, bibtex_entry
# Wrapper exposing the token classifier as a callable for the UI.
class MistralChatBot:
    """Run the token-classification pipeline on user text.

    Despite its name, this class generates no text: ``predict`` classifies
    the input with ``token_classifier`` and formats the result.
    """

    # Token budget per classifier call.
    MAX_TOKENS = 500

    def __init__(self, system_prompt="Le dialogue suivant est une conversation"):
        # Stored on the instance; predict() does not read it.
        self.system_prompt = system_prompt

    def predict(self, user_message):
        """Classify *user_message* and return ``(html, bibtex_entry)``.

        Args:
            user_message: Raw text entered by the user.

        Returns:
            A tuple of the styled HTML fragment and the BibTeX entry string.
            Empty or whitespace-only input returns ``("", "")`` without
            calling the model.
        """
        if not user_message or not user_message.strip():
            return "", ""

        # Line breaks are marked with a pilcrow before classification.
        editorial_text = user_message.replace("\n", " ¶ ")

        # NOTE(review): the text no longer contains "\n" at this point, so
        # split_text cannot split on line boundaries here — confirm whether
        # splitting was meant to happen before the substitution.
        num_tokens = len(tokenizer.tokenize(editorial_text))
        if num_tokens > self.MAX_TOKENS:
            batch_prompts = split_text(editorial_text, max_tokens=self.MAX_TOKENS)
        else:
            batch_prompts = [editorial_text]

        out = token_classifier(batch_prompts)
        classified_list = pd.concat(
            [pd.DataFrame(classification) for classification in out],
            ignore_index=True,
        )

        # Debugging: Print the classified list
        print("Classified List:")
        print(classified_list)

        html_output, bibtex_entry = transform_chunks(classified_list)

        # Debugging: Print the outputs
        print("HTML Output:")
        print(html_output)
        print("BibTeX Entry:")
        print(bibtex_entry)

        generated_text = f'{css}<h2 style="text-align:center">Edited text</h2>\n<div class="generation">{html_output}</div>'
        return generated_text, bibtex_entry
# Single predictor instance; its predict() method is the button callback.
mistral_bot = MistralChatBot()

# Interface metadata.
# NOTE(review): title, description and examples are not referenced by the
# Blocks layout below; they are kept as module-level names in case other code
# reads them. The example row also carries a temperature value that
# predict() does not accept — confirm before wiring it into the UI.
title = "Éditorialisation"
description = "Un outil expérimental d'identification de la structure du texte à partir d'un encoder (Deberta)"
examples = [
    [
        "Qui peut bénéficier de l'AIP?",  # user_message
        0.7  # temperature
    ]
]

# Layout: a text box and a button feeding the predictor, whose two return
# values fill the HTML panel and the BibTeX text box.
with gr.Blocks(theme='JohnSmith9982/small_and_pretty') as demo:
    gr.HTML("""<h1 style="text-align:center">Reversed Zotero</h1>""")
    text_input = gr.Textbox(label="Your text", type="text", lines=5)
    text_button = gr.Button("Extract a structured bibtex")
    text_output = gr.HTML(label="Metadata")
    bibtex_output = gr.Textbox(label="BibTeX Entry", lines=10)
    text_button.click(mistral_bot.predict, inputs=text_input, outputs=[text_output, bibtex_output])

if __name__ == "__main__":
    demo.queue().launch()