harshildarji committed on
Commit
da37e40
·
verified ·
1 Parent(s): 1b901a6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +159 -351
app.py CHANGED
@@ -1,9 +1,8 @@
1
- import re
2
  import os
3
- import warnings
 
4
 
5
- import matplotlib.colors as mcolors
6
- import matplotlib.pyplot as plt
7
  import streamlit as st
8
  from charset_normalizer import detect
9
  from transformers import (
@@ -13,153 +12,61 @@ from transformers import (
13
  pipeline,
14
  )
15
 
16
- warnings.simplefilter(action="ignore", category=Warning)
17
  logging.set_verbosity(logging.ERROR)
18
 
19
- st.set_page_config(page_title="Legal NER", page_icon="⚖️", layout="wide")
20
-
21
  st.markdown(
22
  """
23
- <style>
24
- body {
25
- font-family: 'Poppins', sans-serif;
26
- background-color: #f4f4f8;
27
- }
28
- .header {
29
- background-color: rgba(220, 219, 219, 0.25);
30
- color: #000;
31
- padding: 5px 0;
32
- text-align: center;
33
- border-radius: 7px;
34
- margin-bottom: 13px;
35
- border-bottom: 2px solid #333;
36
- }
37
- .container {
38
- background-color: #fff;
39
- padding: 30px;
40
- border-radius: 10px;
41
- box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
42
- width: 100%;
43
- max-width: 1000px;
44
- margin: 0 auto;
45
- position: absolute;
46
- top: 50%;
47
- left: 50%;
48
- transform: translate(-50%, -50%);
49
- }
50
- .btn-primary {
51
- background-color: #5477d1;
52
- border: none;
53
- transition: background-color 0.3s, transform 0.2s;
54
- border-radius: 25px;
55
- box-shadow: 0 1px 3px rgba(0, 0, 0, 0.08);
56
- }
57
- .btn-primary:hover {
58
- background-color: #4c6cbe;
59
- transform: translateY(-1px);
60
- }
61
- h2 {
62
- font-weight: 600;
63
- font-size: 24px;
64
- margin-bottom: 20px;
65
- }
66
- label {
67
- font-weight: 500;
68
- }
69
- .tip {
70
- background-color: rgba(180, 47, 109, 0.25);
71
- padding: 7px;
72
- border-radius: 7px;
73
- display: inline-block;
74
- margin-top: 15px;
75
- margin-bottom: 15px;
76
- }
77
- .sec {
78
- background-color: rgba(220, 219, 219, 0.10);
79
- padding: 7px;
80
- border-radius: 5px;
81
- display: inline-block;
82
- margin-top: 15px;
83
- margin-bottom: 15px;
84
- }
85
- .tooltip {
86
- position: relative;
87
- display: inline-block;
88
- cursor: pointer;
89
- }
90
- .tooltip .tooltiptext {
91
- visibility: hidden;
92
- width: 120px;
93
- background-color: #6c757d;
94
- color: #fff;
95
- text-align: center;
96
- border-radius: 3px;
97
- padding: 3px;
98
- position: absolute;
99
- z-index: 1;
100
- bottom: 125%;
101
- left: 50%;
102
- margin-left: -60px;
103
- opacity: 0;
104
- transition: opacity 0.3s;
105
- }
106
- .tooltip:hover .tooltiptext {
107
- visibility: visible;
108
- opacity: 1;
109
- }
110
- .anonymized {
111
- background-color: #ffcccb;
112
- color: #000;
113
- font-weight: bold;
114
- border-radius: 3px;
115
- padding: 2px 4px;
116
- }
117
- #language-container {
118
- position: fixed;
119
- top: 10px;
120
- right: 10px;
121
- z-index: 1000;
122
- }
123
- </style>
124
  """,
125
  unsafe_allow_html=True,
126
  )
127
 
128
- # UI text for English and German.
129
- ui_text = {
130
- "EN": {
131
- "title": "Legal NER",
132
- "upload": "Upload a .txt file",
133
- "anonymize": "Anonymize",
134
- "select_entities": "Entity types to anonymize:",
135
- "download": "Download Anonymized Text",
136
- "tip": "Tip: Hover over the colored words to see its class.",
137
- "error": "An error occurred while processing the file: ",
138
- },
139
- "DE": {
140
- "title": "Juristische NER",
141
- "upload": "Lade eine .txt-Datei hoch",
142
- "anonymize": "Anonymisieren",
143
- "select_entities": "Entitätstypen zur Anonymisierung:",
144
- "download": "Anonymisierten Text herunterladen",
145
- "tip": "Tipp: Fahre mit der Maus über die farbigen Wörter, um deren Klasse zu sehen.",
146
- "error": "Beim Verarbeiten der Datei ist ein Fehler aufgetreten: ",
147
- },
148
- }
149
-
150
- col1, col2 = st.columns([4, 1])
151
- with col2:
152
- lang = st.radio(
153
- "Language:",
154
- options=["EN", "DE"],
155
- horizontal=True,
156
- label_visibility="hidden",
157
- key="language_selector",
158
- )
159
- with col1:
160
- st.title(ui_text[lang]["title"])
161
-
162
- # Initialization for German Legal NER
163
  tkn = os.getenv("tkn")
164
  tokenizer = AutoTokenizer.from_pretrained("harshildarji/JuraNER", use_auth_token=tkn)
165
  model = AutoModelForTokenClassification.from_pretrained(
@@ -167,8 +74,8 @@ model = AutoModelForTokenClassification.from_pretrained(
167
  )
168
  ner = pipeline("ner", model=model, tokenizer=tokenizer)
169
 
170
- # Define class labels for the model
171
- classes = {
172
  "AN": "Lawyer",
173
  "EUN": "European legal norm",
174
  "GRT": "Court",
@@ -189,223 +96,124 @@ classes = {
189
  "VS": "Regulation",
190
  "VT": "Contract",
191
  }
192
- ner_labels = list(classes.keys())
193
 
194
 
195
- # Generate a list of colors for visualization
196
- def generate_colors(num_colors):
197
- cm = plt.get_cmap("tab20")
198
- colors = [mcolors.rgb2hex(cm(1.0 * i / num_colors)) for i in range(num_colors)]
199
- return colors
 
 
 
 
200
 
201
 
202
- # Color substrings based on NER results
203
- def color_substrings(input_string, model_output):
204
- colors = generate_colors(len(ner_labels))
205
- label_to_color = {
206
- label: colors[i % len(colors)] for i, label in enumerate(ner_labels)
207
- }
208
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  last_end = 0
210
- html_output = ""
211
-
212
- for entity in sorted(model_output, key=lambda x: x["start"]):
213
- start, end, label = entity["start"], entity["end"], entity["label"]
214
- html_output += input_string[last_end:start]
215
- tooltip = classes.get(label, "")
216
- html_output += (
217
- f'<span class="tooltip" style="color: {label_to_color.get(label)}; font-weight: bold;">'
218
- f'{input_string[start:end]}<span class="tooltiptext">{tooltip}</span></span>'
 
 
 
 
 
 
 
219
  )
 
220
  last_end = end
221
 
222
- html_output += input_string[last_end:]
223
- return html_output
224
-
225
-
226
- # Selectively anonymize entities
227
- def anonymize_text(input_string, model_output, selected_entities=None):
228
- merged_model_output = []
229
- sorted_entities = sorted(model_output, key=lambda x: x["start"])
230
- if sorted_entities:
231
- current = sorted_entities[0]
232
- for entity in sorted_entities[1:]:
233
- if (
234
- entity["label"] == current["label"]
235
- and input_string[current["end"] : entity["start"]].strip() == ""
236
- ):
237
- current["end"] = entity["end"]
238
- current["word"] = input_string[current["start"] : current["end"]]
239
- else:
240
- merged_model_output.append(current)
241
- current = entity
242
- merged_model_output.append(current)
243
- else:
244
- merged_model_output = sorted_entities
245
 
246
- anonymized_text = ""
247
- last_end = 0
248
- colors = generate_colors(len(ner_labels))
249
- label_to_color = {
250
- label: colors[i % len(colors)] for i, label in enumerate(ner_labels)
251
- }
252
-
253
- for entity in merged_model_output:
254
- start, end, label = entity["start"], entity["end"], entity["label"]
255
- anonymized_text += input_string[last_end:start]
256
- if selected_entities is None or label in selected_entities:
257
- anonymized_text += (
258
- f'<span class="anonymized">[{classes.get(label, label)}]</span>'
259
- )
260
- else:
261
- tooltip = classes.get(label, "")
262
- anonymized_text += (
263
- f'<span class="tooltip" style="color: {label_to_color.get(label)}; font-weight: bold;">'
264
- f'{input_string[start:end]}<span class="tooltiptext">{tooltip}</span></span>'
265
- )
266
- last_end = end
267
 
268
- anonymized_text += input_string[last_end:]
269
- return anonymized_text
270
-
271
-
272
- def merge_entities(ner_results):
273
- merged_entities = []
274
- current_entity = None
275
-
276
- for token in ner_results:
277
- tag = token["entity"]
278
- entity_type = tag.split("-")[-1] if "-" in tag else tag
279
- token_start, token_end = token["start"], token["end"]
280
- token_word = token["word"].replace("##", "") # Remove subword prefixes
281
-
282
- if (
283
- tag.startswith("B-")
284
- or current_entity is None
285
- or current_entity["label"] != entity_type
286
- ):
287
- if current_entity:
288
- merged_entities.append(current_entity)
289
- current_entity = {
290
- "start": token_start,
291
- "end": token_end,
292
- "label": entity_type,
293
- "word": token_word,
294
- }
295
- elif (
296
- tag.startswith("I-")
297
- and current_entity
298
- and current_entity["label"] == entity_type
299
- ):
300
- current_entity["end"] = token_end
301
- current_entity["word"] += token_word
302
- else:
303
- if (
304
- current_entity
305
- and token_start == current_entity["end"]
306
- and current_entity["label"] == entity_type
307
- ):
308
- current_entity["end"] = token_end
309
- current_entity["word"] += token_word
310
- else:
311
- if current_entity:
312
- merged_entities.append(current_entity)
313
- current_entity = {
314
- "start": token_start,
315
- "end": token_end,
316
- "label": entity_type,
317
- "word": token_word,
318
- }
319
-
320
- if current_entity:
321
- merged_entities.append(current_entity)
322
- return merged_entities
323
-
324
-
325
- uploaded_file = st.file_uploader(ui_text[lang]["upload"], type="txt")
326
-
327
- if uploaded_file is not None:
328
- try:
329
- raw_content = uploaded_file.read()
330
- detected = detect(raw_content)
331
- encoding = detected["encoding"]
332
- if encoding is None:
333
- raise ValueError("Unable to detect file encoding.")
334
-
335
- lines = raw_content.decode(encoding).splitlines()
336
-
337
- line_results = []
338
- for line in lines:
339
- if line.strip():
340
- results = ner(line)
341
- merged_results = merge_entities(results)
342
- line_results.append(merged_results)
343
- else:
344
- line_results.append([])
345
-
346
- anonymize_mode = st.checkbox(ui_text[lang]["anonymize"])
347
-
348
- selected_entities = None
349
- if anonymize_mode:
350
- detected_entity_tags = set()
351
- for merged_results in line_results:
352
- for entity in merged_results:
353
- detected_entity_tags.add(entity["label"])
354
-
355
- inverse_classes = {v: k for k, v in classes.items()}
356
- detected_options = sorted([classes[tag] for tag in detected_entity_tags])
357
- selected_options = st.multiselect(
358
- ui_text[lang]["select_entities"],
359
- options=detected_options,
360
- default=detected_options,
361
- )
362
- selected_entities = [
363
- inverse_classes[options] for options in selected_options
364
- ]
365
-
366
- st.markdown(
367
- "<hr style='margin-top: 10px; margin-bottom: 20px;'>",
368
- unsafe_allow_html=True,
369
- )
370
 
371
- anonymized_lines = []
372
- displayed_lines = []
373
-
374
- for line, merged_results in zip(lines, line_results):
375
- if line.strip():
376
- if anonymize_mode:
377
- anonymized_text = anonymize_text(
378
- line, merged_results, selected_entities=selected_entities
379
- )
380
- displayed_lines.append(anonymized_text)
381
- plain_text = re.sub(r"<.*?>", "", anonymized_text)
382
- anonymized_lines.append(plain_text.strip())
383
- else:
384
- colored_html = color_substrings(line, merged_results)
385
- st.markdown(f"{colored_html}", unsafe_allow_html=True)
386
- else:
387
- # displayed_lines.append("<br>")
388
- anonymized_lines.append("")
389
-
390
- if anonymize_mode:
391
- original_file_name = uploaded_file.name
392
- download_file_name = f"Anon_{original_file_name}"
393
- anonymized_content = "\n".join(anonymized_lines)
394
- for displayed_line in displayed_lines:
395
- st.markdown(f"{displayed_line}", unsafe_allow_html=True)
396
- st.markdown("<hr>", unsafe_allow_html=True)
397
- st.download_button(
398
- label=ui_text[lang]["download"],
399
- data=anonymized_content,
400
- file_name=download_file_name,
401
- mime="text/plain",
402
- )
403
- else:
404
- st.markdown("<hr>", unsafe_allow_html=True)
405
  st.markdown(
406
- f'<div class="tip"><strong>{ui_text[lang]["tip"]}</strong></div>',
407
  unsafe_allow_html=True,
408
- )
409
-
410
- except Exception as e:
411
- st.error(f"{ui_text[lang]['error']}{e}")
 
 
1
  import os
2
+ import re
3
+ import string
4
 
5
+ import matplotlib.cm as cm
 
6
  import streamlit as st
7
  from charset_normalizer import detect
8
  from transformers import (
 
12
  pipeline,
13
  )
14
 
15
+ st.set_page_config(page_title="German Legal NER", page_icon="⚖️", layout="wide")
16
  logging.set_verbosity(logging.ERROR)
17
 
 
 
18
  st.markdown(
19
  """
20
+ <style>
21
+ .block-container {
22
+ padding-top: 1rem;
23
+ padding-bottom: 5rem;
24
+ padding-left: 3rem;
25
+ padding-right: 3rem;
26
+ }
27
+
28
+ header, footer {visibility: hidden;}
29
+
30
+ .entity {
31
+ position: relative;
32
+ display: inline-block;
33
+ background-color: transparent;
34
+ font-weight: normal;
35
+ cursor: help;
36
+ }
37
+
38
+ .entity .tooltip {
39
+ visibility: hidden;
40
+ background-color: #333;
41
+ color: #fff;
42
+ text-align: center;
43
+ border-radius: 4px;
44
+ padding: 2px 6px;
45
+ position: absolute;
46
+ z-index: 1;
47
+ bottom: 125%;
48
+ left: 50%;
49
+ transform: translateX(-50%);
50
+ white-space: nowrap;
51
+ opacity: 0;
52
+ transition: opacity 0.05s;
53
+ font-size: 11px;
54
+ }
55
+
56
+ .entity:hover .tooltip {
57
+ visibility: visible;
58
+ opacity: 1;
59
+ }
60
+
61
+ .entity.marked {
62
+ background-color: rgba(255, 230, 0, 0.4);
63
+ }
64
+ </style>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  """,
66
  unsafe_allow_html=True,
67
  )
68
 
69
+ # Load model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  tkn = os.getenv("tkn")
71
  tokenizer = AutoTokenizer.from_pretrained("harshildarji/JuraNER", use_auth_token=tkn)
72
  model = AutoModelForTokenClassification.from_pretrained(
 
74
  )
75
  ner = pipeline("ner", model=model, tokenizer=tokenizer)
76
 
77
+ # Entity labels
78
+ entity_labels = {
79
  "AN": "Lawyer",
80
  "EUN": "European legal norm",
81
  "GRT": "Court",
 
96
  "VS": "Regulation",
97
  "VT": "Contract",
98
  }
 
99
 
100
 
101
# Fixed colors
def generate_fixed_colors(keys, alpha=0.25):
    """Assign each key a fixed CSS ``rgba(...)`` colour from the "tab20" colormap.

    Args:
        keys: Iterable of hashable keys (here: entity tags). The i-th key gets
            the i-th colour of "tab20" resampled to ``len(keys)`` entries.
        alpha: Opacity written verbatim into every ``rgba(...)`` string.

    Returns:
        Dict mapping each key to a string of the form ``"rgba(R, G, B, alpha)"``.
        An empty ``keys`` yields an empty dict.
    """
    keys = list(keys)
    if not keys:
        # Nothing to colour; also avoids asking Matplotlib for a 0-entry colormap.
        return {}

    # Local import: the file only binds ``matplotlib.cm`` as ``cm``, and the
    # colormap registry lives on the top-level package.
    import matplotlib

    registry = getattr(matplotlib, "colormaps", None)
    if registry is not None and hasattr(registry["tab20"], "resampled"):
        # ``cm.get_cmap`` was deprecated in Matplotlib 3.7 and later removed;
        # the registry plus ``resampled`` is the supported replacement.
        cmap = registry["tab20"].resampled(len(keys))
    else:
        # Older Matplotlib without the registry / ``resampled`` API.
        cmap = cm.get_cmap("tab20", len(keys))

    rgba_colors = {}
    for i, key in enumerate(keys):
        r, g, b, _ = cmap(i)  # the colormap's own alpha is replaced by ``alpha``
        rgba_colors[key] = f"rgba({int(r*255)}, {int(g*255)}, {int(b*255)}, {alpha})"
    return rgba_colors
110
 
111
 
112
+ ENTITY_COLORS = generate_fixed_colors(list(entity_labels.keys()), alpha=0.30)
 
 
 
 
 
113
 
114
+ # UI
115
+ st.markdown("#### German Legal NER")
116
+ uploaded_file = st.file_uploader("Upload a .txt file", type="txt")
117
+ threshold = st.slider("Confidence threshold:", 0.0, 1.0, 0.8, 0.01)
118
+ st.markdown("---")
119
+
120
+
121
# Merge logic
def _entity_type(ent):
    """Return the tag of *ent* without its ``B-``/``I-`` prefix ("" if absent)."""
    return ent.get("entity", "").split("-")[-1]


def merge_entities(entities):
    """Merge consecutive token-level predictions into entity spans.

    Each item is expected to be a dict with at least ``index``, ``word``,
    ``score`` and ``end`` keys (presumably the raw output of the ``"ner"``
    pipeline -- confirm against the caller). Tokens are merged into the
    previous span when their ``index`` directly follows it AND either the
    token is a ``##`` sub-word piece or it carries the same entity type.
    Requiring the same type keeps two different, directly adjacent entities
    from being fused under the first one's label.

    Args:
        entities: List of token prediction dicts. The dicts are not mutated.

    Returns:
        A new list of span dicts. Each keeps the first token's fields, with
        ``word`` rebuilt and cleaned, ``end``/``index`` taken from the last
        merged token, and ``score`` set to the mean token score. Spans whose
        cleaned word is shorter than two characters or contains no word
        character are dropped.
    """
    if not entities:
        return []

    spans = []
    for tok in sorted(entities, key=lambda e: e["index"]):
        prev = spans[-1] if spans else None
        piece = tok["word"]
        is_subword = piece.startswith("##")

        if (
            prev is not None
            and tok["index"] == prev["index"] + 1
            and (is_subword or _entity_type(tok) == _entity_type(prev))
        ):
            # Sub-word pieces are glued on; whole tokens are space-separated.
            prev["word"] += piece[2:] if is_subword else " " + piece
            prev["end"] = tok["end"]
            prev["index"] = tok["index"]
            prev["score_sum"] += tok["score"]
            prev["count"] += 1
        else:
            span = tok.copy()  # copy so the caller's dicts stay untouched
            span["score_sum"] = tok["score"]
            span["count"] = 1
            spans.append(span)

    final = []
    for span in spans:
        # Replace the running totals with the mean score of the merged tokens.
        span["score"] = span.pop("score_sum") / span.pop("count")

        w = span["word"].strip()
        # Undo the spaces that token joining put around punctuation.
        w = re.sub(r"\s*\.\s*", ".", w)
        w = re.sub(r"\s*,\s*", ", ", w)
        w = re.sub(r"\s*/\s*", "/", w)
        w = w.strip(string.whitespace + string.punctuation)

        if len(w) > 1 and re.search(r"\w", w):
            span["word"] = w
            final.append(span)

    return final
170
+
171
+
172
# HTML highlighting
def highlight_entities(line, merged_entities, threshold):
    """Render *line* as HTML with each confident entity wrapped in a tooltip span.

    The result is meant for ``st.markdown(..., unsafe_allow_html=True)``, so all
    text taken from the input line is HTML-escaped: the uploaded file is
    untrusted and must not be able to inject markup.

    Args:
        line: The original text line the entities were predicted on.
        merged_entities: Span dicts with ``score``, ``start``, ``end`` and
            ``entity`` keys, ordered by position (``start``/``end`` are used
            as character offsets into ``line``).
        threshold: Spans with ``score`` below this value are left unmarked.

    Returns:
        An HTML string in which every kept span is wrapped in
        ``<span class="entity marked">`` with a nested tooltip holding the
        label description; all other text is passed through escaped.
    """
    from html import escape  # local import; the file does not import ``html``

    parts = []
    last_end = 0

    for ent in merged_entities:
        if ent["score"] < threshold:
            continue

        start, end = ent["start"], ent["end"]
        if start < last_end:
            # Overlaps a span that was already emitted; rendering it would
            # duplicate text, so leave it unmarked.
            continue

        label = ent["entity"].split("-")[-1]  # drop the B-/I- prefix
        label_desc = entity_labels.get(label, label)
        color = ENTITY_COLORS.get(label, "#cccccc")

        parts.append(escape(line[last_end:start]))

        highlight_style = f"background-color:{color}; font-weight:600;"
        # Show the exact source text of the span (not the cleaned ``word``),
        # so the rendered line never drops or alters characters.
        parts.append(
            f'<span class="entity marked" style="{highlight_style}">'
            f'{escape(line[start:end])}'
            f'<span class="tooltip">{escape(label_desc)}</span></span>'
        )

        last_end = end

    parts.append(escape(line[last_end:]))
    return "".join(parts)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
 
200
if uploaded_file:
    raw_bytes = uploaded_file.read()
    encoding = detect(raw_bytes)["encoding"]
    text = None

    if encoding is None:
        st.error("Could not detect file encoding.")
    else:
        try:
            # The encoding is only a guess, so decoding can still fail.
            text = raw_bytes.decode(encoding)
        except (UnicodeDecodeError, LookupError) as exc:
            st.error(f"Could not decode file as {encoding}: {exc}")

    if text is not None:
        # Run NER line by line so character offsets refer to the line itself.
        for line in text.splitlines():
            if not line.strip():
                # Keep blank lines as vertical spacing in the rendered output.
                st.write("")
                continue

            tokens = ner(line)
            merged = merge_entities(tokens)
            html_line = highlight_entities(line, merged, threshold)
            st.markdown(
                f'<div style="margin:0;padding:0;line-height:1.4;">{html_line}</div>',
                unsafe_allow_html=True,
            )