import transformers import re from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM, pipeline import torch import gradio as gr import json import os import shutil import requests import pandas as pd # Define the device device = "cuda" if torch.cuda.is_available() else "cpu" editorial_model = "PleIAs/Estienne" token_classifier = pipeline( "token-classification", model=editorial_model, aggregation_strategy="simple", device=device ) tokenizer = AutoTokenizer.from_pretrained(editorial_model, model_max_length=512) def split_text(text, max_tokens=500): # Split the text by newline characters parts = text.split("\n") chunks = [] current_chunk = "" for part in parts: # Add part to current chunk if current_chunk: temp_chunk = current_chunk + "\n" + part else: temp_chunk = part # Tokenize the temporary chunk num_tokens = len(tokenizer.tokenize(temp_chunk)) if num_tokens <= max_tokens: current_chunk = temp_chunk else: if current_chunk: chunks.append(current_chunk) current_chunk = part if current_chunk: chunks.append(current_chunk) # If no newlines were found and still exceeding max_tokens, split further if len(chunks) == 1 and len(tokenizer.tokenize(chunks[0])) > max_tokens: long_text = chunks[0] chunks = [] while len(tokenizer.tokenize(long_text)) > max_tokens: split_point = len(long_text) // 2 while split_point < len(long_text) and not re.match(r'\s', long_text[split_point]): split_point += 1 # Ensure split_point does not go out of range if split_point >= len(long_text): split_point = len(long_text) - 1 chunks.append(long_text[:split_point].strip()) long_text = long_text[split_point:].strip() if long_text: chunks.append(long_text) return chunks # Class to encapsulate the Falcon chatbot class MistralChatBot: def __init__(self, system_prompt="Le dialogue suivant est une conversation"): self.system_prompt = system_prompt def predict(self, user_message): #We drop the newlines. editorial_text = re.sub("\n", " ¶ ", user_message) # Tokenize the prompt and check if it exceeds 500 tokens num_tokens = len(tokenizer.tokenize(prompt)) if num_tokens > 500: # Split the prompt into chunks batch_prompts = split_text(prompt, max_tokens=500) else: batch_prompts = [prompt] out = token_classifier(batch_prompts) out = "".join(out) generated_text = '