Legal-NER-Demo / app.py
harshildarji's picture
merge adjacent subtokens during anonymization
85ed894 verified
raw
history blame
13.2 kB
import html
import os
import re
import warnings

import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import streamlit as st
from charset_normalizer import detect
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    logging,
    pipeline,
)
# Keep the console quiet: transformers' own logger only reports errors, and
# every Python warning category is ignored.
logging.set_verbosity(logging.ERROR)
warnings.simplefilter("ignore", Warning)

# Page-wide Streamlit settings; this precedes every other Streamlit call in
# this file.
st.set_page_config(layout="wide", page_icon="⚖️", page_title="Legal NER")
# Inject the app-wide stylesheet. The classes used by HTML generated in this
# file are: .tooltip / .tooltiptext (entity class shown on hover),
# .anonymized (replacement placeholders) and .tip (hint box under results).
# NOTE(review): .header, .container, .btn-primary, .sec and
# #language-container are not referenced by any HTML emitted in this file —
# confirm whether they are still needed before removing them.
st.markdown(
    """
<style>
body {
font-family: 'Poppins', sans-serif;
background-color: #f4f4f8;
}
.header {
background-color: rgba(220, 219, 219, 0.25);
color: #000;
padding: 5px 0;
text-align: center;
border-radius: 7px;
margin-bottom: 13px;
border-bottom: 2px solid #333;
}
.container {
background-color: #fff;
padding: 30px;
border-radius: 10px;
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
width: 100%;
max-width: 1000px;
margin: 0 auto;
position: absolute;
top: 50%;
left: 50%;
transform: translate(-50%, -50%);
}
.btn-primary {
background-color: #5477d1;
border: none;
transition: background-color 0.3s, transform 0.2s;
border-radius: 25px;
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.08);
}
.btn-primary:hover {
background-color: #4c6cbe;
transform: translateY(-1px);
}
h2 {
font-weight: 600;
font-size: 24px;
margin-bottom: 20px;
}
label {
font-weight: 500;
}
.tip {
background-color: rgba(180, 47, 109, 0.25);
padding: 7px;
border-radius: 7px;
display: inline-block;
margin-top: 15px;
margin-bottom: 15px;
}
.sec {
background-color: rgba(220, 219, 219, 0.10);
padding: 7px;
border-radius: 5px;
display: inline-block;
margin-top: 15px;
margin-bottom: 15px;
}
.tooltip {
position: relative;
display: inline-block;
cursor: pointer;
}
.tooltip .tooltiptext {
visibility: hidden;
width: 120px;
background-color: #6c757d;
color: #fff;
text-align: center;
border-radius: 3px;
padding: 3px;
position: absolute;
z-index: 1;
bottom: 125%;
left: 50%;
margin-left: -60px;
opacity: 0;
transition: opacity 0.3s;
}
.tooltip:hover .tooltiptext {
visibility: visible;
opacity: 1;
}
.anonymized {
background-color: #ffcccb;
color: #000;
font-weight: bold;
border-radius: 3px;
padding: 2px 4px;
}
#language-container {
position: fixed;
top: 10px;
right: 10px;
z-index: 1000;
}
</style>
""",
    # Required so Streamlit renders the <style> tag instead of escaping it.
    unsafe_allow_html=True,
)
# User-facing UI strings, keyed first by language code ("EN" / "DE") and then
# by message id. Both language tables expose the same set of message ids so
# any ui_text[lang][key] lookup works for either language.
ui_text = {
    "EN": {
        "title": "Legal NER",
        "upload": "Upload a .txt file",
        "anonymize": "Anonymize",
        "select_entities": "Entity types to anonymize:",
        "download": "Download Anonymized Text",
        # "words" is plural, so the possessive must be "their".
        "tip": "Tip: Hover over the colored words to see their class.",
        # Displayed with the exception message appended directly after it.
        "error": "An error occurred while processing the file: ",
    },
    "DE": {
        "title": "Juristische NER",
        "upload": "Lade eine .txt-Datei hoch",
        "anonymize": "Anonymisieren",
        "select_entities": "Entitätstypen zur Anonymisierung:",
        "download": "Anonymisierten Text herunterladen",
        "tip": "Tipp: Fahre mit der Maus über die farbigen Wörter, um deren Klasse zu sehen.",
        "error": "Beim Verarbeiten der Datei ist ein Fehler aufgetreten: ",
    },
}
# Header row: a wide column for the page title and a narrow one for the
# language switch. The chosen code ("EN" or "DE") selects the ui_text table.
col1, col2 = st.columns([4, 1])
lang = col2.radio(
    "Language:",
    options=["EN", "DE"],
    horizontal=True,
    label_visibility="hidden",
    key="language_selector",
)
col1.title(ui_text[lang]["title"])
# Initialization for the German legal NER model. The Hugging Face access
# token is read from the "tkn" environment variable (None when unset).
tkn = os.getenv("tkn")


@st.cache_resource(show_spinner=False)
def _load_ner_components(model_name, auth_token):
    """Load the tokenizer, model and "ner" pipeline for ``model_name``.

    Streamlit re-executes this whole script on every widget interaction
    (language switch, anonymize checkbox, multiselect). Caching the loaded
    objects with ``st.cache_resource`` avoids re-instantiating the model on
    each of those reruns.

    Returns:
        ``(tokenizer, model, ner_pipeline)``.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(
        model_name, use_auth_token=auth_token
    )
    loaded_model = AutoModelForTokenClassification.from_pretrained(
        model_name, use_auth_token=auth_token
    )
    ner_pipeline = pipeline("ner", model=loaded_model, tokenizer=loaded_tokenizer)
    return loaded_tokenizer, loaded_model, ner_pipeline


tokenizer, model, ner = _load_ner_components("harshildarji/JuraBERT", tkn)
# Entity tags predicted by the model, mapped to the human-readable class
# names shown in tooltips, anonymization placeholders and the multiselect.
classes = dict(
    AN="Lawyer",
    EUN="European legal norm",
    GRT="Court",
    GS="Law",
    INN="Institution",
    LD="Country",
    LDS="Landscape",
    LIT="Legal literature",
    MRK="Brand",
    ORG="Organization",
    PER="Person",
    RR="Judge",
    RS="Court decision",
    ST="City",
    STR="Street",
    UN="Company",
    VO="Ordinance",
    VS="Regulation",
    VT="Contract",
)
# Tags in insertion order; a tag's position here is the index used to pick
# its display color.
ner_labels = [*classes]
def generate_colors(num_colors):
    """Return ``num_colors`` hex color strings sampled from the "tab20" colormap.

    Sample positions are evenly spaced fractions ``k / num_colors`` for
    ``k = 0 .. num_colors - 1``. ``num_colors == 0`` yields an empty list.
    """
    palette = plt.get_cmap("tab20")
    positions = [step / num_colors for step in range(num_colors)]
    return [mcolors.rgb2hex(palette(pos)) for pos in positions]
def color_substrings(input_string, model_output):
    """Render one line as HTML with every detected entity colored by class.

    Args:
        input_string: one line of the uploaded document.
        model_output: entity dicts with "start", "end" (character offsets into
            ``input_string``) and "label" keys, e.g. from ``merge_entities``.

    Returns:
        An HTML string intended for ``st.markdown(..., unsafe_allow_html=True)``.
        Each entity is wrapped in a ``tooltip`` span whose ``tooltiptext``
        child holds the human-readable class name. All text taken from
        ``input_string`` is HTML-escaped, so ``<``, ``>`` and ``&`` in the
        uploaded file are shown literally instead of being parsed as markup.
    """
    colors = generate_colors(len(ner_labels))
    label_to_color = {
        label: colors[i % len(colors)] for i, label in enumerate(ner_labels)
    }

    parts = []
    last_end = 0
    for entity in sorted(model_output, key=lambda x: x["start"]):
        label = entity["label"]
        # Clamp to the previous entity's end so overlapping spans never
        # repeat text that has already been emitted.
        start = max(entity["start"], last_end)
        end = entity["end"]
        if end <= start:
            continue
        parts.append(html.escape(input_string[last_end:start], quote=False))
        tooltip = html.escape(classes.get(label, ""), quote=False)
        parts.append(
            f'<span class="tooltip" style="color: {label_to_color.get(label)}; font-weight: bold;">'
            f"{html.escape(input_string[start:end], quote=False)}"
            f'<span class="tooltiptext">{tooltip}</span></span>'
        )
        last_end = end
    parts.append(html.escape(input_string[last_end:], quote=False))
    return "".join(parts)
def anonymize_text(input_string, model_output, selected_entities=None):
    """Render one line as HTML with selected entity types replaced.

    Same-label entities separated only by whitespace (or touching/overlapping)
    are first merged into one span, so a multi-token name produces a single
    placeholder.

    Args:
        input_string: the original line of text.
        model_output: entity dicts with "start", "end" and "label" keys
            (character offsets into ``input_string``). The caller's dicts are
            never modified; merging works on copies.
        selected_entities: labels to replace with a ``[Class name]``
            placeholder; ``None`` replaces every label. Entities whose label
            is not selected keep their text and are rendered with the colored
            tooltip markup.

    Returns:
        An HTML string. Text outside entities is inserted as-is (unescaped).
    """
    # 1) Merge adjacent same-label entities (works on shallow copies).
    merged_model_output = []
    for entity in sorted(model_output, key=lambda x: x["start"]):
        previous = merged_model_output[-1] if merged_model_output else None
        if (
            previous is not None
            and entity["label"] == previous["label"]
            and input_string[previous["end"] : entity["start"]].strip() == ""
        ):
            # max(): an overlapping entity must never shrink the merged span.
            previous["end"] = max(previous["end"], entity["end"])
            previous["word"] = input_string[previous["start"] : previous["end"]]
        else:
            merged_model_output.append(dict(entity))

    # 2) Render: placeholders for selected labels, tooltips for the rest.
    colors = generate_colors(len(ner_labels))
    label_to_color = {
        label: colors[i % len(colors)] for i, label in enumerate(ner_labels)
    }

    parts = []
    last_end = 0
    for entity in merged_model_output:
        start, end, label = entity["start"], entity["end"], entity["label"]
        parts.append(input_string[last_end:start])
        if selected_entities is None or label in selected_entities:
            parts.append(
                f'<span class="anonymized">[{classes.get(label, label)}]</span>'
            )
        else:
            tooltip = classes.get(label, "")
            parts.append(
                f'<span class="tooltip" style="color: {label_to_color.get(label)}; font-weight: bold;">'
                f'{input_string[start:end]}<span class="tooltiptext">{tooltip}</span></span>'
            )
        last_end = end
    parts.append(input_string[last_end:])
    return "".join(parts)
def merge_entities(ner_results):
    """Merge token-level NER predictions into entity spans.

    ``ner_results`` is the token list returned by the ``"ner"`` pipeline
    without aggregation. Each dict has ``"entity"`` (a tag such as ``"B-PER"``,
    ``"I-PER"`` or a bare ``"PER"``), ``"start"``, ``"end"`` and ``"word"``.
    It may also have ``"index"``, the token's position in the tokenized
    input.

    Merging rules:
    * the first token, a ``B-`` tag, or a label change opens a new entity;
    * an ``I-`` tag with the same label extends the current entity only when
      it directly follows the previous token (consecutive ``"index"``). A gap
      means tokens were dropped in between (presumably pipeline-filtered
      ``O`` tokens), so the token opens a new entity instead of absorbing the
      unrelated text. Without ``"index"``, the token is treated as adjacent;
    * a bare tag with the same label extends the current entity only when the
      character offsets touch (``start == current end``).

    Returns:
        A list of dicts with ``"start"``, ``"end"``, ``"label"`` and ``"word"``.
        In ``"word"``, WordPiece continuations (``##`` prefix) are joined
        directly and pieces separated by a character gap are joined with a
        single space.
    """
    merged_entities = []
    current_entity = None
    previous_index = None

    for token in ner_results:
        tag = token["entity"]
        entity_type = tag.split("-")[-1] if "-" in tag else tag
        token_start, token_end = token["start"], token["end"]
        token_word = token["word"]
        if token_word.startswith("##"):
            token_word = token_word[2:]  # strip only the continuation prefix
        token_index = token.get("index")

        same_label = (
            current_entity is not None and current_entity["label"] == entity_type
        )
        if tag.startswith("B-") or not same_label:
            extend = False
        elif tag.startswith("I-"):
            extend = (
                token_index is None
                or previous_index is None
                or token_index == previous_index + 1
            )
        else:
            extend = token_start == current_entity["end"]

        if extend:
            if token_start > current_entity["end"]:
                current_entity["word"] += " "  # separate words, not sub-words
            current_entity["end"] = token_end
            current_entity["word"] += token_word
        else:
            if current_entity is not None:
                merged_entities.append(current_entity)
            current_entity = {
                "start": token_start,
                "end": token_end,
                "label": entity_type,
                "word": token_word,
            }
        previous_index = token_index

    if current_entity is not None:
        merged_entities.append(current_entity)
    return merged_entities
def _anonymized_html_to_plain_text(rendered_html):
    """Convert one line of anonymize_text() HTML back to plain text.

    Only the markup that anonymize_text() emits is removed:
    * the tooltip label span, tag *and* content. A generic tag strip would
      otherwise leave the class name glued to every non-anonymized entity
      ("Max MüllerPerson");
    * the opening ``anonymized`` / ``tooltip`` spans;
    * closing ``</span>`` tags.
    Other angle-bracket text from the document (e.g. "<Abs. 2>") survives.
    NOTE(review): a literal "</span>" in the source text would still be
    dropped.
    """
    without_labels = re.sub(
        r'<span class="tooltiptext">[^<]*</span>', "", rendered_html
    )
    return re.sub(
        r'<span class="(?:anonymized|tooltip)"[^>]*>|</span>', "", without_labels
    )


uploaded_file = st.file_uploader(ui_text[lang]["upload"], type="txt")
if uploaded_file is not None:
    try:
        raw_content = uploaded_file.read()
        detected = detect(raw_content)
        encoding = detected["encoding"]
        if encoding is None:
            raise ValueError("Unable to detect file encoding.")
        lines = raw_content.decode(encoding).splitlines()

        # NER per non-empty line; blank lines get an empty entity list so
        # line_results stays index-aligned with lines.
        line_results = [
            merge_entities(ner(line)) if line.strip() else [] for line in lines
        ]

        anonymize_mode = st.checkbox(ui_text[lang]["anonymize"])
        selected_entities = None
        if anonymize_mode:
            detected_entity_tags = {
                entity["label"]
                for merged_results in line_results
                for entity in merged_results
            }
            # Display name -> tag. Tags missing from `classes` fall back to
            # the raw tag instead of raising KeyError (anonymize_text uses
            # the same fallback for its placeholder text).
            option_to_tag = {
                classes.get(tag, tag): tag for tag in detected_entity_tags
            }
            detected_options = sorted(option_to_tag)
            selected_options = st.multiselect(
                ui_text[lang]["select_entities"],
                options=detected_options,
                default=detected_options,
            )
            selected_entities = [option_to_tag[option] for option in selected_options]
            st.markdown(
                "<hr style='margin-top: 10px; margin-bottom: 20px;'>",
                unsafe_allow_html=True,
            )

        anonymized_lines = []
        displayed_lines = []
        for line, merged_results in zip(lines, line_results):
            if not line.strip():
                anonymized_lines.append("")
                continue
            if anonymize_mode:
                anonymized_text = anonymize_text(
                    line, merged_results, selected_entities=selected_entities
                )
                displayed_lines.append(anonymized_text)
                plain_text = _anonymized_html_to_plain_text(anonymized_text)
                anonymized_lines.append(plain_text.strip())
            else:
                colored_html = color_substrings(line, merged_results)
                st.markdown(colored_html, unsafe_allow_html=True)

        if anonymize_mode:
            original_file_name = uploaded_file.name
            download_file_name = f"Anon_{original_file_name}"
            anonymized_content = "\n".join(anonymized_lines)
            for displayed_line in displayed_lines:
                st.markdown(f"{displayed_line}", unsafe_allow_html=True)
            st.markdown("<hr>", unsafe_allow_html=True)
            st.download_button(
                label=ui_text[lang]["download"],
                data=anonymized_content,
                file_name=download_file_name,
                mime="text/plain",
            )
        else:
            st.markdown("<hr>", unsafe_allow_html=True)
            st.markdown(
                f'<div class="tip"><strong>{ui_text[lang]["tip"]}</strong></div>',
                unsafe_allow_html=True,
            )
    except Exception as e:
        # UI boundary: report any failure (decoding, NER, rendering) to the user.
        st.error(f"{ui_text[lang]['error']}{e}")