Legal-NER-Demo / app.py
harshildarji's picture
Update app.py
1b901a6 verified
import re
import os
import warnings
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import streamlit as st
from charset_normalizer import detect
from transformers import (
AutoModelForTokenClassification,
AutoTokenizer,
logging,
pipeline,
)
warnings.simplefilter(action="ignore", category=Warning)
logging.set_verbosity(logging.ERROR)
st.set_page_config(page_title="Legal NER", page_icon="⚖️", layout="wide")
st.markdown(
"""
<style>
body {
font-family: 'Poppins', sans-serif;
background-color: #f4f4f8;
}
.header {
background-color: rgba(220, 219, 219, 0.25);
color: #000;
padding: 5px 0;
text-align: center;
border-radius: 7px;
margin-bottom: 13px;
border-bottom: 2px solid #333;
}
.container {
background-color: #fff;
padding: 30px;
border-radius: 10px;
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
width: 100%;
max-width: 1000px;
margin: 0 auto;
position: absolute;
top: 50%;
left: 50%;
transform: translate(-50%, -50%);
}
.btn-primary {
background-color: #5477d1;
border: none;
transition: background-color 0.3s, transform 0.2s;
border-radius: 25px;
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.08);
}
.btn-primary:hover {
background-color: #4c6cbe;
transform: translateY(-1px);
}
h2 {
font-weight: 600;
font-size: 24px;
margin-bottom: 20px;
}
label {
font-weight: 500;
}
.tip {
background-color: rgba(180, 47, 109, 0.25);
padding: 7px;
border-radius: 7px;
display: inline-block;
margin-top: 15px;
margin-bottom: 15px;
}
.sec {
background-color: rgba(220, 219, 219, 0.10);
padding: 7px;
border-radius: 5px;
display: inline-block;
margin-top: 15px;
margin-bottom: 15px;
}
.tooltip {
position: relative;
display: inline-block;
cursor: pointer;
}
.tooltip .tooltiptext {
visibility: hidden;
width: 120px;
background-color: #6c757d;
color: #fff;
text-align: center;
border-radius: 3px;
padding: 3px;
position: absolute;
z-index: 1;
bottom: 125%;
left: 50%;
margin-left: -60px;
opacity: 0;
transition: opacity 0.3s;
}
.tooltip:hover .tooltiptext {
visibility: visible;
opacity: 1;
}
.anonymized {
background-color: #ffcccb;
color: #000;
font-weight: bold;
border-radius: 3px;
padding: 2px 4px;
}
#language-container {
position: fixed;
top: 10px;
right: 10px;
z-index: 1000;
}
</style>
""",
unsafe_allow_html=True,
)
# UI text for English and German.
ui_text = {
"EN": {
"title": "Legal NER",
"upload": "Upload a .txt file",
"anonymize": "Anonymize",
"select_entities": "Entity types to anonymize:",
"download": "Download Anonymized Text",
"tip": "Tip: Hover over the colored words to see its class.",
"error": "An error occurred while processing the file: ",
},
"DE": {
"title": "Juristische NER",
"upload": "Lade eine .txt-Datei hoch",
"anonymize": "Anonymisieren",
"select_entities": "Entitätstypen zur Anonymisierung:",
"download": "Anonymisierten Text herunterladen",
"tip": "Tipp: Fahre mit der Maus über die farbigen Wörter, um deren Klasse zu sehen.",
"error": "Beim Verarbeiten der Datei ist ein Fehler aufgetreten: ",
},
}
col1, col2 = st.columns([4, 1])
with col2:
lang = st.radio(
"Language:",
options=["EN", "DE"],
horizontal=True,
label_visibility="hidden",
key="language_selector",
)
with col1:
st.title(ui_text[lang]["title"])
# Initialization for German Legal NER
tkn = os.getenv("tkn")
tokenizer = AutoTokenizer.from_pretrained("harshildarji/JuraNER", use_auth_token=tkn)
model = AutoModelForTokenClassification.from_pretrained(
"harshildarji/JuraNER", use_auth_token=tkn
)
ner = pipeline("ner", model=model, tokenizer=tokenizer)
# Define class labels for the model
classes = {
"AN": "Lawyer",
"EUN": "European legal norm",
"GRT": "Court",
"GS": "Law",
"INN": "Institution",
"LD": "Country",
"LDS": "Landscape",
"LIT": "Legal literature",
"MRK": "Brand",
"ORG": "Organization",
"PER": "Person",
"RR": "Judge",
"RS": "Court decision",
"ST": "City",
"STR": "Street",
"UN": "Company",
"VO": "Ordinance",
"VS": "Regulation",
"VT": "Contract",
}
ner_labels = list(classes.keys())
# Generate a list of colors for visualization
def generate_colors(num_colors):
cm = plt.get_cmap("tab20")
colors = [mcolors.rgb2hex(cm(1.0 * i / num_colors)) for i in range(num_colors)]
return colors
# Color substrings based on NER results
def color_substrings(input_string, model_output):
colors = generate_colors(len(ner_labels))
label_to_color = {
label: colors[i % len(colors)] for i, label in enumerate(ner_labels)
}
last_end = 0
html_output = ""
for entity in sorted(model_output, key=lambda x: x["start"]):
start, end, label = entity["start"], entity["end"], entity["label"]
html_output += input_string[last_end:start]
tooltip = classes.get(label, "")
html_output += (
f'<span class="tooltip" style="color: {label_to_color.get(label)}; font-weight: bold;">'
f'{input_string[start:end]}<span class="tooltiptext">{tooltip}</span></span>'
)
last_end = end
html_output += input_string[last_end:]
return html_output
# Selectively anonymize entities
def anonymize_text(input_string, model_output, selected_entities=None):
merged_model_output = []
sorted_entities = sorted(model_output, key=lambda x: x["start"])
if sorted_entities:
current = sorted_entities[0]
for entity in sorted_entities[1:]:
if (
entity["label"] == current["label"]
and input_string[current["end"] : entity["start"]].strip() == ""
):
current["end"] = entity["end"]
current["word"] = input_string[current["start"] : current["end"]]
else:
merged_model_output.append(current)
current = entity
merged_model_output.append(current)
else:
merged_model_output = sorted_entities
anonymized_text = ""
last_end = 0
colors = generate_colors(len(ner_labels))
label_to_color = {
label: colors[i % len(colors)] for i, label in enumerate(ner_labels)
}
for entity in merged_model_output:
start, end, label = entity["start"], entity["end"], entity["label"]
anonymized_text += input_string[last_end:start]
if selected_entities is None or label in selected_entities:
anonymized_text += (
f'<span class="anonymized">[{classes.get(label, label)}]</span>'
)
else:
tooltip = classes.get(label, "")
anonymized_text += (
f'<span class="tooltip" style="color: {label_to_color.get(label)}; font-weight: bold;">'
f'{input_string[start:end]}<span class="tooltiptext">{tooltip}</span></span>'
)
last_end = end
anonymized_text += input_string[last_end:]
return anonymized_text
def merge_entities(ner_results):
merged_entities = []
current_entity = None
for token in ner_results:
tag = token["entity"]
entity_type = tag.split("-")[-1] if "-" in tag else tag
token_start, token_end = token["start"], token["end"]
token_word = token["word"].replace("##", "") # Remove subword prefixes
if (
tag.startswith("B-")
or current_entity is None
or current_entity["label"] != entity_type
):
if current_entity:
merged_entities.append(current_entity)
current_entity = {
"start": token_start,
"end": token_end,
"label": entity_type,
"word": token_word,
}
elif (
tag.startswith("I-")
and current_entity
and current_entity["label"] == entity_type
):
current_entity["end"] = token_end
current_entity["word"] += token_word
else:
if (
current_entity
and token_start == current_entity["end"]
and current_entity["label"] == entity_type
):
current_entity["end"] = token_end
current_entity["word"] += token_word
else:
if current_entity:
merged_entities.append(current_entity)
current_entity = {
"start": token_start,
"end": token_end,
"label": entity_type,
"word": token_word,
}
if current_entity:
merged_entities.append(current_entity)
return merged_entities
uploaded_file = st.file_uploader(ui_text[lang]["upload"], type="txt")
if uploaded_file is not None:
try:
raw_content = uploaded_file.read()
detected = detect(raw_content)
encoding = detected["encoding"]
if encoding is None:
raise ValueError("Unable to detect file encoding.")
lines = raw_content.decode(encoding).splitlines()
line_results = []
for line in lines:
if line.strip():
results = ner(line)
merged_results = merge_entities(results)
line_results.append(merged_results)
else:
line_results.append([])
anonymize_mode = st.checkbox(ui_text[lang]["anonymize"])
selected_entities = None
if anonymize_mode:
detected_entity_tags = set()
for merged_results in line_results:
for entity in merged_results:
detected_entity_tags.add(entity["label"])
inverse_classes = {v: k for k, v in classes.items()}
detected_options = sorted([classes[tag] for tag in detected_entity_tags])
selected_options = st.multiselect(
ui_text[lang]["select_entities"],
options=detected_options,
default=detected_options,
)
selected_entities = [
inverse_classes[options] for options in selected_options
]
st.markdown(
"<hr style='margin-top: 10px; margin-bottom: 20px;'>",
unsafe_allow_html=True,
)
anonymized_lines = []
displayed_lines = []
for line, merged_results in zip(lines, line_results):
if line.strip():
if anonymize_mode:
anonymized_text = anonymize_text(
line, merged_results, selected_entities=selected_entities
)
displayed_lines.append(anonymized_text)
plain_text = re.sub(r"<.*?>", "", anonymized_text)
anonymized_lines.append(plain_text.strip())
else:
colored_html = color_substrings(line, merged_results)
st.markdown(f"{colored_html}", unsafe_allow_html=True)
else:
# displayed_lines.append("<br>")
anonymized_lines.append("")
if anonymize_mode:
original_file_name = uploaded_file.name
download_file_name = f"Anon_{original_file_name}"
anonymized_content = "\n".join(anonymized_lines)
for displayed_line in displayed_lines:
st.markdown(f"{displayed_line}", unsafe_allow_html=True)
st.markdown("<hr>", unsafe_allow_html=True)
st.download_button(
label=ui_text[lang]["download"],
data=anonymized_content,
file_name=download_file_name,
mime="text/plain",
)
else:
st.markdown("<hr>", unsafe_allow_html=True)
st.markdown(
f'<div class="tip"><strong>{ui_text[lang]["tip"]}</strong></div>',
unsafe_allow_html=True,
)
except Exception as e:
st.error(f"{ui_text[lang]['error']}{e}")