harshildarji committed on
Commit
bdcb2c9
·
verified ·
1 Parent(s): 4635bc5

Add selective anonymization, DE option for UI

Browse files
Files changed (1) hide show
  1. app.py +102 -40
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import re
2
- import os
3
  import warnings
4
 
5
  import matplotlib.colors as mcolors
@@ -114,13 +113,53 @@ st.markdown(
114
  border-radius: 3px;
115
  padding: 2px 4px;
116
  }
 
 
 
 
 
 
117
  </style>
118
  """,
119
  unsafe_allow_html=True,
120
  )
121
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  # Initialization for German Legal NER
123
- tkn = os.getenv("tkn")
124
  tokenizer = AutoTokenizer.from_pretrained("harshildarji/JuraBERT", use_auth_token=tkn)
125
  model = AutoModelForTokenClassification.from_pretrained(
126
  "harshildarji/JuraBERT", use_auth_token=tkn
@@ -152,14 +191,14 @@ classes = {
152
  ner_labels = list(classes.keys())
153
 
154
 
155
- # Function to generate a list of colors for visualization
156
  def generate_colors(num_colors):
157
  cm = plt.get_cmap("tab20")
158
  colors = [mcolors.rgb2hex(cm(1.0 * i / num_colors)) for i in range(num_colors)]
159
  return colors
160
 
161
 
162
- # Function to color substrings based on NER results
163
  def color_substrings(input_string, model_output):
164
  colors = generate_colors(len(ner_labels))
165
  label_to_color = {
@@ -173,29 +212,41 @@ def color_substrings(input_string, model_output):
173
  start, end, label = entity["start"], entity["end"], entity["label"]
174
  html_output += input_string[last_end:start]
175
  tooltip = classes.get(label, "")
176
- html_output += f'<span class="tooltip" style="color: {label_to_color.get(label)}; font-weight: bold;">{input_string[start:end]}<span class="tooltiptext">{tooltip}</span></span>'
 
 
 
177
  last_end = end
178
 
179
  html_output += input_string[last_end:]
180
-
181
  return html_output
182
 
183
 
184
- # Function to anonymize entities
185
- def anonymize_text(input_string, model_output):
186
  anonymized_text = ""
187
  last_end = 0
 
 
 
 
188
 
189
  for entity in sorted(model_output, key=lambda x: x["start"]):
190
  start, end, label = entity["start"], entity["end"], entity["label"]
191
  anonymized_text += input_string[last_end:start]
192
- anonymized_text += (
193
- f'<span class="anonymized">[{classes.get(label, label)}]</span>'
194
- )
 
 
 
 
 
 
 
195
  last_end = end
196
 
197
  anonymized_text += input_string[last_end:]
198
-
199
  return anonymized_text
200
 
201
 
@@ -209,7 +260,6 @@ def merge_entities(ner_results):
209
  token_start, token_end = token["start"], token["end"]
210
  token_word = token["word"].replace("##", "") # Remove subword prefixes
211
 
212
- # Start a new entity if necessary
213
  if (
214
  tag.startswith("B-")
215
  or current_entity is None
@@ -228,11 +278,9 @@ def merge_entities(ner_results):
228
  and current_entity
229
  and current_entity["label"] == entity_type
230
  ):
231
- # Extend the current entity
232
  current_entity["end"] = token_end
233
  current_entity["word"] += token_word
234
  else:
235
- # Handle misclassifications or gaps in tokens
236
  if (
237
  current_entity
238
  and token_start == current_entity["end"]
@@ -241,7 +289,6 @@ def merge_entities(ner_results):
241
  current_entity["end"] = token_end
242
  current_entity["word"] += token_word
243
  else:
244
- # Treat it as a new entity if the above conditions aren't met
245
  if current_entity:
246
  merged_entities.append(current_entity)
247
  current_entity = {
@@ -251,32 +298,52 @@ def merge_entities(ner_results):
251
  "word": token_word,
252
  }
253
 
254
- # Append the last entity
255
  if current_entity:
256
  merged_entities.append(current_entity)
257
-
258
  return merged_entities
259
 
260
 
261
- st.title("Legal NER")
262
- st.markdown("<hr>", unsafe_allow_html=True)
263
-
264
- uploaded_file = st.file_uploader("Upload a .txt file", type="txt")
265
-
266
 
267
  if uploaded_file is not None:
268
  try:
269
  raw_content = uploaded_file.read()
270
-
271
  detected = detect(raw_content)
272
  encoding = detected["encoding"]
273
-
274
  if encoding is None:
275
  raise ValueError("Unable to detect file encoding.")
276
 
277
  lines = raw_content.decode(encoding).splitlines()
278
 
279
- anonymize_mode = st.checkbox("Anonymize")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
  st.markdown(
281
  "<hr style='margin-top: 10px; margin-bottom: 20px;'>",
282
  unsafe_allow_html=True,
@@ -285,13 +352,12 @@ if uploaded_file is not None:
285
  anonymized_lines = []
286
  displayed_lines = []
287
 
288
- for line_number, line in enumerate(lines, start=1):
289
  if line.strip():
290
- results = ner(line)
291
- merged_results = merge_entities(results)
292
-
293
  if anonymize_mode:
294
- anonymized_text = anonymize_text(line, merged_results)
 
 
295
  displayed_lines.append(anonymized_text)
296
  plain_text = re.sub(r"<.*?>", "", anonymized_text)
297
  anonymized_lines.append(plain_text.strip())
@@ -299,31 +365,27 @@ if uploaded_file is not None:
299
  colored_html = color_substrings(line, merged_results)
300
  st.markdown(f"{colored_html}", unsafe_allow_html=True)
301
  else:
302
- displayed_lines.append("<br>")
303
  anonymized_lines.append("")
304
 
305
  if anonymize_mode:
306
  original_file_name = uploaded_file.name
307
  download_file_name = f"Anon_{original_file_name}"
308
-
309
  anonymized_content = "\n".join(anonymized_lines)
310
-
311
  for displayed_line in displayed_lines:
312
  st.markdown(f"{displayed_line}", unsafe_allow_html=True)
313
-
314
  st.markdown("<hr>", unsafe_allow_html=True)
315
  st.download_button(
316
- label="Download Anonymized Text",
317
  data=anonymized_content,
318
  file_name=download_file_name,
319
  mime="text/plain",
320
  )
321
-
322
- if not anonymize_mode:
323
  st.markdown(
324
- '<div class="tip"><strong>Tip:</strong> Hover over the colored words to see its class.</div>',
325
  unsafe_allow_html=True,
326
  )
327
 
328
  except Exception as e:
329
- st.error(f"An error occurred while processing the file: {e}")
 
1
  import re
 
2
  import warnings
3
 
4
  import matplotlib.colors as mcolors
 
113
  border-radius: 3px;
114
  padding: 2px 4px;
115
  }
116
+ #language-container {
117
+ position: fixed;
118
+ top: 10px;
119
+ right: 10px;
120
+ z-index: 1000;
121
+ }
122
  </style>
123
  """,
124
  unsafe_allow_html=True,
125
  )
126
 
127
# UI text for English and German.
# Outer keys are the language codes offered by the selector below; inner keys
# name the individual UI strings looked up as ui_text[lang][...].
ui_text = {
    "EN": {
        "title": "Legal NER",
        "upload": "Upload a .txt file",
        "anonymize": "Anonymize",
        "select_entities": "Entity types to anonymize:",
        "download": "Download Anonymized Text",
        "tip": "Tip: Hover over the colored words to see its class.",
        "error": "An error occurred while processing the file: ",
    },
    "DE": {
        "title": "Juristische NER",
        "upload": "Lade eine .txt-Datei hoch",
        "anonymize": "Anonymisieren",
        "select_entities": "Entitätstypen zur Anonymisierung:",
        "download": "Anonymisierten Text herunterladen",
        "tip": "Tipp: Fahre mit der Maus über die farbigen Wörter, um deren Klasse zu sehen.",
        "error": "Beim Verarbeiten der Datei ist ein Fehler aufgetreten: ",
    },
}

# Two-column header: the wide left column holds the title, the narrow right
# column holds the language switch.
col1, col2 = st.columns([4, 1])
with col2:
    # Streamlit discourages an empty widget label (it logs a warning and may
    # reject it in the future), so pass a real label and keep it hidden with
    # label_visibility="collapsed".
    lang = st.radio(
        "Language",
        options=["EN", "DE"],
        horizontal=True,
        label_visibility="collapsed",
        key="language_selector",
    )
with col1:
    # The title follows the language chosen above.
    st.title(ui_text[lang]["title"])
160
+
161
  # Initialization for German Legal NER
162
# Read the access token from the local "./token" file. The context manager
# closes the file handle, and strip() removes surrounding whitespace such as
# a trailing newline, which would otherwise become part of the token value.
with open("./token", encoding="utf-8") as token_file:
    tkn = token_file.read().strip()
163
  tokenizer = AutoTokenizer.from_pretrained("harshildarji/JuraBERT", use_auth_token=tkn)
164
  model = AutoModelForTokenClassification.from_pretrained(
165
  "harshildarji/JuraBERT", use_auth_token=tkn
 
191
  ner_labels = list(classes.keys())
192
 
193
 
194
# Generate a list of colors for visualization
def generate_colors(num_colors):
    """Return ``num_colors`` hex color strings sampled from the "tab20" colormap.

    The i-th color is taken at position ``i / num_colors`` of the colormap
    and converted to a hex string with ``mcolors.rgb2hex``.
    """
    palette = plt.get_cmap("tab20")
    hex_codes = []
    for index in range(num_colors):
        position = 1.0 * index / num_colors
        hex_codes.append(mcolors.rgb2hex(palette(position)))
    return hex_codes
199
 
200
 
201
+ # Color substrings based on NER results
202
  def color_substrings(input_string, model_output):
203
  colors = generate_colors(len(ner_labels))
204
  label_to_color = {
 
212
  start, end, label = entity["start"], entity["end"], entity["label"]
213
  html_output += input_string[last_end:start]
214
  tooltip = classes.get(label, "")
215
+ html_output += (
216
+ f'<span class="tooltip" style="color: {label_to_color.get(label)}; font-weight: bold;">'
217
+ f'{input_string[start:end]}<span class="tooltiptext">{tooltip}</span></span>'
218
+ )
219
  last_end = end
220
 
221
  html_output += input_string[last_end:]
 
222
  return html_output
223
 
224
 
225
# Selectively anonymize entities
def anonymize_text(input_string, model_output, selected_entities=None):
    """Build an HTML rendering of ``input_string`` with selected entities masked.

    Entities are handled in ascending order of their ``"start"`` offset.
    An entity whose label is in ``selected_entities`` (or every entity when
    ``selected_entities`` is None) is replaced by a bracketed placeholder
    inside an ``anonymized`` span; the placeholder text is looked up in
    ``classes`` and falls back to the raw label. Every other entity keeps
    its original text, colored per label and wrapped in a ``tooltip`` span
    whose nested ``tooltiptext`` span carries the class name.

    NOTE(review): the class name of a non-masked entity is emitted as text
    inside the nested span, so a caller that derives plain text by stripping
    tags with a regex keeps that class name right after the entity text.

    Args:
        input_string: Text that the entity offsets refer to.
        model_output: Iterable of dicts with "start", "end" and "label" keys.
        selected_entities: Labels to mask; None masks every entity.

    Returns:
        The assembled HTML string.
    """
    palette = generate_colors(len(ner_labels))
    color_for = {
        name: palette[idx % len(palette)] for idx, name in enumerate(ner_labels)
    }

    pieces = []
    cursor = 0
    for span in sorted(model_output, key=lambda item: item["start"]):
        begin, stop, tag = span["start"], span["end"], span["label"]
        # Untouched text between the previous entity and this one.
        pieces.append(input_string[cursor:begin])
        mask_it = selected_entities is None or tag in selected_entities
        if mask_it:
            pieces.append(
                f'<span class="anonymized">[{classes.get(tag, tag)}]</span>'
            )
        else:
            hint = classes.get(tag, "")
            pieces.append(
                f'<span class="tooltip" style="color: {color_for.get(tag)}; font-weight: bold;">'
                f'{input_string[begin:stop]}<span class="tooltiptext">{hint}</span></span>'
            )
        cursor = stop

    # Remainder of the text after the last entity.
    pieces.append(input_string[cursor:])
    return "".join(pieces)
251
 
252
 
 
260
  token_start, token_end = token["start"], token["end"]
261
  token_word = token["word"].replace("##", "") # Remove subword prefixes
262
 
 
263
  if (
264
  tag.startswith("B-")
265
  or current_entity is None
 
278
  and current_entity
279
  and current_entity["label"] == entity_type
280
  ):
 
281
  current_entity["end"] = token_end
282
  current_entity["word"] += token_word
283
  else:
 
284
  if (
285
  current_entity
286
  and token_start == current_entity["end"]
 
289
  current_entity["end"] = token_end
290
  current_entity["word"] += token_word
291
  else:
 
292
  if current_entity:
293
  merged_entities.append(current_entity)
294
  current_entity = {
 
298
  "word": token_word,
299
  }
300
 
 
301
  if current_entity:
302
  merged_entities.append(current_entity)
 
303
  return merged_entities
304
 
305
 
306
+ uploaded_file = st.file_uploader(ui_text[lang]["upload"], type="txt")
 
 
 
 
307
 
308
  if uploaded_file is not None:
309
  try:
310
  raw_content = uploaded_file.read()
 
311
  detected = detect(raw_content)
312
  encoding = detected["encoding"]
 
313
  if encoding is None:
314
  raise ValueError("Unable to detect file encoding.")
315
 
316
  lines = raw_content.decode(encoding).splitlines()
317
 
318
+ line_results = []
319
+ for line in lines:
320
+ if line.strip():
321
+ results = ner(line)
322
+ merged_results = merge_entities(results)
323
+ line_results.append(merged_results)
324
+ else:
325
+ line_results.append([])
326
+
327
+ anonymize_mode = st.checkbox(ui_text[lang]["anonymize"])
328
+
329
+ selected_entities = None
330
+ if anonymize_mode:
331
+ detected_entity_tags = set()
332
+ for merged_results in line_results:
333
+ for entity in merged_results:
334
+ detected_entity_tags.add(entity["label"])
335
+
336
+ inverse_classes = {v: k for k, v in classes.items()}
337
+ detected_options = sorted([classes[tag] for tag in detected_entity_tags])
338
+ selected_options = st.multiselect(
339
+ ui_text[lang]["select_entities"],
340
+ options=detected_options,
341
+ default=detected_options,
342
+ )
343
+ selected_entities = [
344
+ inverse_classes[options] for options in selected_options
345
+ ]
346
+
347
  st.markdown(
348
  "<hr style='margin-top: 10px; margin-bottom: 20px;'>",
349
  unsafe_allow_html=True,
 
352
  anonymized_lines = []
353
  displayed_lines = []
354
 
355
+ for line, merged_results in zip(lines, line_results):
356
  if line.strip():
 
 
 
357
  if anonymize_mode:
358
+ anonymized_text = anonymize_text(
359
+ line, merged_results, selected_entities=selected_entities
360
+ )
361
  displayed_lines.append(anonymized_text)
362
  plain_text = re.sub(r"<.*?>", "", anonymized_text)
363
  anonymized_lines.append(plain_text.strip())
 
365
  colored_html = color_substrings(line, merged_results)
366
  st.markdown(f"{colored_html}", unsafe_allow_html=True)
367
  else:
368
+ # displayed_lines.append("<br>")
369
  anonymized_lines.append("")
370
 
371
  if anonymize_mode:
372
  original_file_name = uploaded_file.name
373
  download_file_name = f"Anon_{original_file_name}"
 
374
  anonymized_content = "\n".join(anonymized_lines)
 
375
  for displayed_line in displayed_lines:
376
  st.markdown(f"{displayed_line}", unsafe_allow_html=True)
 
377
  st.markdown("<hr>", unsafe_allow_html=True)
378
  st.download_button(
379
+ label=ui_text[lang]["download"],
380
  data=anonymized_content,
381
  file_name=download_file_name,
382
  mime="text/plain",
383
  )
384
+ else:
 
385
  st.markdown(
386
+ f'<div class="tip"><strong>{ui_text[lang]["tip"]}</strong></div>',
387
  unsafe_allow_html=True,
388
  )
389
 
390
  except Exception as e:
391
+ st.error(f"{ui_text[lang]['error']}{e}")