Pclanglais commited on
Commit
750020e
·
verified ·
1 Parent(s): 100e33a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +146 -0
app.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import transformers
2
+ import re
3
+ from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM
4
+ from vllm import LLM, SamplingParams
5
+ import torch
6
+ import gradio as gr
7
+ import json
8
+ import os
9
+ import shutil
10
+ import requests
11
+ import chromadb
12
+ import difflib
13
+ import pandas as pd
14
+ from chromadb.config import Settings
15
+ from chromadb.utils import embedding_functions
16
+
17
+ # Define the device
18
+ device = "cuda" if torch.cuda.is_available() else "cpu"
19
+
20
+ model_checkpoint = "PleIAs/Estienne"
21
+ token_classifier = pipeline(
22
+ "token-classification", model=editorial_model, aggregation_strategy="simple", device=device
23
+ )
24
+
25
+ tokenizer = AutoTokenizer.from_pretrained(editorial_model, model_max_length=512)
26
+
27
+
28
+ def split_text(text, max_tokens=500):
29
+ # Split the text by newline characters
30
+ parts = text.split("\n")
31
+ chunks = []
32
+ current_chunk = ""
33
+
34
+ for part in parts:
35
+ # Add part to current chunk
36
+ if current_chunk:
37
+ temp_chunk = current_chunk + "\n" + part
38
+ else:
39
+ temp_chunk = part
40
+
41
+ # Tokenize the temporary chunk
42
+ num_tokens = len(tokenizer.tokenize(temp_chunk))
43
+
44
+ if num_tokens <= max_tokens:
45
+ current_chunk = temp_chunk
46
+ else:
47
+ if current_chunk:
48
+ chunks.append(current_chunk)
49
+ current_chunk = part
50
+
51
+ if current_chunk:
52
+ chunks.append(current_chunk)
53
+
54
+ # If no newlines were found and still exceeding max_tokens, split further
55
+ if len(chunks) == 1 and len(tokenizer.tokenize(chunks[0])) > max_tokens:
56
+ long_text = chunks[0]
57
+ chunks = []
58
+ while len(tokenizer.tokenize(long_text)) > max_tokens:
59
+ split_point = len(long_text) // 2
60
+ while split_point < len(long_text) and not re.match(r'\s', long_text[split_point]):
61
+ split_point += 1
62
+ # Ensure split_point does not go out of range
63
+ if split_point >= len(long_text):
64
+ split_point = len(long_text) - 1
65
+ chunks.append(long_text[:split_point].strip())
66
+ long_text = long_text[split_point:].strip()
67
+ if long_text:
68
+ chunks.append(long_text)
69
+
70
+ return chunks
71
+
72
+
73
+ #Curtesy of claude
74
+ def generate_html_diff(old_text, new_text):
75
+ d = difflib.Differ()
76
+ diff = list(d.compare(old_text.split(), new_text.split()))
77
+
78
+ html_diff = []
79
+ for word in diff:
80
+ if word.startswith(' '):
81
+ html_diff.append(word[2:])
82
+ elif word.startswith('+ '):
83
+ html_diff.append(f'<span style="background-color: #90EE90;">{word[2:]}</span>')
84
+ # We're not adding anything for words that start with '- '
85
+
86
+ return ' '.join(html_diff)
87
+
88
+ # Class to encapsulate the Falcon chatbot
89
+ class MistralChatBot:
90
+ def __init__(self, system_prompt="Le dialogue suivant est une conversation"):
91
+ self.system_prompt = system_prompt
92
+
93
+ def predict(self, user_message):
94
+ #We drop the newlines.
95
+ editorial_text = re.sub("\n", " ¶ ", user_message)
96
+
97
+ # Tokenize the prompt and check if it exceeds 500 tokens
98
+ num_tokens = len(tokenizer.tokenize(prompt))
99
+
100
+ if num_tokens > 500:
101
+ # Split the prompt into chunks
102
+ batch_prompts = split_text(prompt, max_tokens=500)
103
+ else:
104
+ batch_prompts = [prompt]
105
+
106
+ out = token_classifier(batch_prompts)
107
+ out = "".join(out)
108
+ generated_text = '<h2 style="text-align:center">Réponse</h3>\n<div class="generation">' + html_diff + "</div>"
109
+ return generated_text
110
+
111
+ # Create the Falcon chatbot instance
112
+ mistral_bot = MistralChatBot()
113
+
114
+ # Define the Gradio interface
115
+ title = "Éditorialisation"
116
+ description = "Un outil expérimental d'identification de la structure du texte à partir d'un encoder (Deberta)"
117
+ examples = [
118
+ [
119
+ "Qui peut bénéficier de l'AIP?", # user_message
120
+ 0.7 # temperature
121
+ ]
122
+ ]
123
+
124
+ additional_inputs=[
125
+ gr.Slider(
126
+ label="Température",
127
+ value=0.2, # Default value
128
+ minimum=0.05,
129
+ maximum=1.0,
130
+ step=0.05,
131
+ interactive=True,
132
+ info="Des valeurs plus élevées donne plus de créativité, mais aussi d'étrangeté",
133
+ ),
134
+ ]
135
+
136
+ demo = gr.Blocks()
137
+
138
+ with gr.Blocks(theme='JohnSmith9982/small_and_pretty', css=css) as demo:
139
+ gr.HTML("""<h1 style="text-align:center">Correction d'OCR</h1>""")
140
+ text_input = gr.Textbox(label="Votre texte.", type="text", lines=1)
141
+ text_button = gr.Button("Identifier les structures éditoriales")
142
+ text_output = gr.HTML(label="Le texte corrigé")
143
+ text_button.click(mistral_bot.predict, inputs=text_input, outputs=[text_output])
144
+
145
+ if __name__ == "__main__":
146
+ demo.queue().launch()