# Legal-NER-Demo / app.py
import re
import os
import warnings
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import streamlit as st
from charset_normalizer import detect
from transformers import (
AutoModelForTokenClassification,
AutoTokenizer,
logging,
pipeline,
)
warnings.simplefilter(action="ignore", category=Warning)
logging.set_verbosity(logging.ERROR)
st.set_page_config(page_title="Legal NER", page_icon="⚖️", layout="wide")
st.markdown(
"""
<style>
body {
font-family: 'Poppins', sans-serif;
background-color: #f4f4f8;
}
.header {
background-color: rgba(220, 219, 219, 0.25);
color: #000;
padding: 5px 0;
text-align: center;
border-radius: 7px;
margin-bottom: 13px;
border-bottom: 2px solid #333;
}
.container {
background-color: #fff;
padding: 30px;
border-radius: 10px;
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
width: 100%;
max-width: 1000px;
margin: 0 auto;
position: absolute;
top: 50%;
left: 50%;
transform: translate(-50%, -50%);
}
.btn-primary {
background-color: #5477d1;
border: none;
transition: background-color 0.3s, transform 0.2s;
border-radius: 25px;
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.08);
}
.btn-primary:hover {
background-color: #4c6cbe;
transform: translateY(-1px);
}
h2 {
font-weight: 600;
font-size: 24px;
margin-bottom: 20px;
}
label {
font-weight: 500;
}
.tip {
background-color: rgba(180, 47, 109, 0.25);
padding: 7px;
border-radius: 7px;
display: inline-block;
margin-top: 15px;
margin-bottom: 15px;
}
.sec {
background-color: rgba(220, 219, 219, 0.10);
padding: 7px;
border-radius: 5px;
display: inline-block;
margin-top: 15px;
margin-bottom: 15px;
}
.tooltip {
position: relative;
display: inline-block;
cursor: pointer;
}
.tooltip .tooltiptext {
visibility: hidden;
width: 120px;
background-color: #6c757d;
color: #fff;
text-align: center;
border-radius: 3px;
padding: 3px;
position: absolute;
z-index: 1;
bottom: 125%;
left: 50%;
margin-left: -60px;
opacity: 0;
transition: opacity 0.3s;
}
.tooltip:hover .tooltiptext {
visibility: visible;
opacity: 1;
}
.anonymized {
background-color: #ffcccb;
color: #000;
font-weight: bold;
border-radius: 3px;
padding: 2px 4px;
}
</style>
""",
unsafe_allow_html=True,
)
# Initialization for German Legal NER
tkn = os.getenv("tkn")
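# "tkn" is assumed to hold a Hugging Face access token (e.g. set as a Space secret)
# that grants read access to the gated "harshildarji/JuraBERT" model.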
tokenizer = AutoTokenizer.from_pretrained("harshildarji/JuraBERT", use_auth_token=tkn)
model = AutoModelForTokenClassification.from_pretrained(
"harshildarji/JuraBERT", use_auth_token=tkn
)
ner = pipeline("ner", model=model, tokenizer=tokenizer)
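# Note: without an aggregation strategy, the pipeline emits one prediction per
# sub-word token with BIO-style tags. Illustrative shape (values are hypothetical):
#   {"entity": "B-GRT", "word": "Bundes", "score": 0.99, "start": 4, "end": 10}
#   {"entity": "I-GRT", "word": "##gericht", "score": 0.98, "start": 10, "end": 17}
# merge_entities() below stitches such fragments back into full entity spans.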
# Define class labels for the model
classes = {
"AN": "Lawyer",
"EUN": "European legal norm",
"GRT": "Court",
"GS": "Law",
"INN": "Institution",
"LD": "Country",
"LDS": "Landscape",
"LIT": "Legal literature",
"MRK": "Brand",
"ORG": "Organization",
"PER": "Person",
"RR": "Judge",
"RS": "Court decision",
"ST": "City",
"STR": "Street",
"UN": "Company",
"VO": "Ordinance",
"VS": "Regulation",
"VT": "Contract",
}
ner_labels = list(classes.keys())
# Function to generate a list of colors for visualization
def generate_colors(num_colors):
cm = plt.get_cmap("tab20")
colors = [mcolors.rgb2hex(cm(1.0 * i / num_colors)) for i in range(num_colors)]
return colors
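# Example: generate_colors(len(ner_labels)) returns one hex color string per NER
# class, sampled evenly from matplotlib's "tab20" colormap.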
# Function to color substrings based on NER results
def color_substrings(input_string, model_output):
colors = generate_colors(len(ner_labels))
label_to_color = {
label: colors[i % len(colors)] for i, label in enumerate(ner_labels)
}
last_end = 0
html_output = ""
for entity in sorted(model_output, key=lambda x: x["start"]):
start, end, label = entity["start"], entity["end"], entity["label"]
html_output += input_string[last_end:start]
tooltip = classes.get(label, "")
html_output += f'<span class="tooltip" style="color: {label_to_color.get(label)}; font-weight: bold;">{input_string[start:end]}<span class="tooltiptext">{tooltip}</span></span>'
last_end = end
html_output += input_string[last_end:]
return html_output
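# Illustrative output: each detected span is wrapped in a colored
# <span class="tooltip">...</span> whose nested "tooltiptext" carries the
# human-readable class name (e.g. "Court" for GRT), revealed on hover.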
# Function to anonymize entities
def anonymize_text(input_string, model_output):
anonymized_text = ""
last_end = 0
for entity in sorted(model_output, key=lambda x: x["start"]):
start, end, label = entity["start"], entity["end"], entity["label"]
anonymized_text += input_string[last_end:start]
anonymized_text += (
f'<span class="anonymized">[{classes.get(label, label)}]</span>'
)
last_end = end
anonymized_text += input_string[last_end:]
return anonymized_text
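# Illustrative output (hypothetical sentence): "Der Bundesgerichtshof entschied ..."
# becomes 'Der <span class="anonymized">[Court]</span> entschied ...', i.e. each
# entity span is replaced by its bracketed class label.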
def merge_entities(ner_results):
merged_entities = []
current_entity = None
for token in ner_results:
tag = token["entity"]
entity_type = tag.split("-")[-1] if "-" in tag else tag
token_start, token_end = token["start"], token["end"]
token_word = token["word"].replace("##", "") # Remove subword prefixes
# Start a new entity if necessary
if (
tag.startswith("B-")
or current_entity is None
or current_entity["label"] != entity_type
):
if current_entity:
merged_entities.append(current_entity)
current_entity = {
"start": token_start,
"end": token_end,
"label": entity_type,
"word": token_word,
}
elif (
tag.startswith("I-")
and current_entity
and current_entity["label"] == entity_type
):
# Extend the current entity
current_entity["end"] = token_end
current_entity["word"] += token_word
else:
# Handle misclassifications or gaps in tokens
if (
current_entity
and token_start == current_entity["end"]
and current_entity["label"] == entity_type
):
current_entity["end"] = token_end
current_entity["word"] += token_word
else:
# Treat it as a new entity if the above conditions aren't met
if current_entity:
merged_entities.append(current_entity)
current_entity = {
"start": token_start,
"end": token_end,
"label": entity_type,
"word": token_word,
}
# Append the last entity
if current_entity:
merged_entities.append(current_entity)
return merged_entities
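# Illustrative merge (hypothetical token offsets):
#   B-GRT "Bundes" (4-10), I-GRT "##gerichts" (10-18), I-GRT "##hof" (18-21)
# collapse into a single entity:
#   {"start": 4, "end": 21, "label": "GRT", "word": "Bundesgerichtshof"}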
st.title("Legal NER")
st.markdown("<hr>", unsafe_allow_html=True)
uploaded_file = st.file_uploader("Upload a .txt file", type="txt")
if uploaded_file is not None:
try:
raw_content = uploaded_file.read()
detected = detect(raw_content)
encoding = detected["encoding"]
if encoding is None:
raise ValueError("Unable to detect file encoding.")
lines = raw_content.decode(encoding).splitlines()
anonymize_mode = st.checkbox("Anonymize")
st.markdown(
"<hr style='margin-top: 10px; margin-bottom: 20px;'>",
unsafe_allow_html=True,
)
anonymized_lines = []
displayed_lines = []
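        # Process the file line by line: in anonymize mode, collect rendered HTML and
        # plain text for a single display pass plus a download; otherwise, render each
        # annotated line immediately.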
for line_number, line in enumerate(lines, start=1):
if line.strip():
results = ner(line)
merged_results = merge_entities(results)
if anonymize_mode:
anonymized_text = anonymize_text(line, merged_results)
displayed_lines.append(anonymized_text)
plain_text = re.sub(r"<.*?>", "", anonymized_text)
anonymized_lines.append(plain_text.strip())
else:
colored_html = color_substrings(line, merged_results)
st.markdown(f"{colored_html}", unsafe_allow_html=True)
else:
displayed_lines.append("<br>")
anonymized_lines.append("")
if anonymize_mode:
original_file_name = uploaded_file.name
download_file_name = f"Anon_{original_file_name}"
anonymized_content = "\n".join(anonymized_lines)
for displayed_line in displayed_lines:
st.markdown(f"{displayed_line}", unsafe_allow_html=True)
st.markdown("<hr>", unsafe_allow_html=True)
st.download_button(
label="Download Anonymized Text",
data=anonymized_content,
file_name=download_file_name,
mime="text/plain",
)
if not anonymize_mode:
st.markdown(
'<div class="tip"><strong>Tip:</strong> Hover over the colored words to see its class.</div>',
unsafe_allow_html=True,
)
except Exception as e:
st.error(f"An error occurred while processing the file: {e}")