# NOTE: "Spaces: Sleeping" is page-status text that was captured along with this file; it is not program text.
import re | |
import os | |
import warnings | |
import matplotlib.colors as mcolors | |
import matplotlib.pyplot as plt | |
import streamlit as st | |
from charset_normalizer import detect | |
from transformers import ( | |
AutoModelForTokenClassification, | |
AutoTokenizer, | |
logging, | |
pipeline, | |
) | |
# Stylesheet injected into the page through st.markdown below. It styles the
# header, container, buttons, tip/section boxes, the hover tooltip used for
# entity classes, the "anonymized" placeholder and the language container.
_PAGE_STYLE = """
<style>
body {
font-family: 'Poppins', sans-serif;
background-color: #f4f4f8;
}
.header {
background-color: rgba(220, 219, 219, 0.25);
color: #000;
padding: 5px 0;
text-align: center;
border-radius: 7px;
margin-bottom: 13px;
border-bottom: 2px solid #333;
}
.container {
background-color: #fff;
padding: 30px;
border-radius: 10px;
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
width: 100%;
max-width: 1000px;
margin: 0 auto;
position: absolute;
top: 50%;
left: 50%;
transform: translate(-50%, -50%);
}
.btn-primary {
background-color: #5477d1;
border: none;
transition: background-color 0.3s, transform 0.2s;
border-radius: 25px;
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.08);
}
.btn-primary:hover {
background-color: #4c6cbe;
transform: translateY(-1px);
}
h2 {
font-weight: 600;
font-size: 24px;
margin-bottom: 20px;
}
label {
font-weight: 500;
}
.tip {
background-color: rgba(180, 47, 109, 0.25);
padding: 7px;
border-radius: 7px;
display: inline-block;
margin-top: 15px;
margin-bottom: 15px;
}
.sec {
background-color: rgba(220, 219, 219, 0.10);
padding: 7px;
border-radius: 5px;
display: inline-block;
margin-top: 15px;
margin-bottom: 15px;
}
.tooltip {
position: relative;
display: inline-block;
cursor: pointer;
}
.tooltip .tooltiptext {
visibility: hidden;
width: 120px;
background-color: #6c757d;
color: #fff;
text-align: center;
border-radius: 3px;
padding: 3px;
position: absolute;
z-index: 1;
bottom: 125%;
left: 50%;
margin-left: -60px;
opacity: 0;
transition: opacity 0.3s;
}
.tooltip:hover .tooltiptext {
visibility: visible;
opacity: 1;
}
.anonymized {
background-color: #ffcccb;
color: #000;
font-weight: bold;
border-radius: 3px;
padding: 2px 4px;
}
#language-container {
position: fixed;
top: 10px;
right: 10px;
z-index: 1000;
}
</style>
"""

# Ignore every warning category and limit transformers logging to errors.
warnings.simplefilter(action="ignore", category=Warning)
logging.set_verbosity(logging.ERROR)

# Configure the page first, then inject the stylesheet defined above.
st.set_page_config(layout="wide", page_title="Legal NER", page_icon="⚖️")
st.markdown(_PAGE_STYLE, unsafe_allow_html=True)
# English interface strings.
_UI_TEXT_EN = {
    "title": "Legal NER",
    "upload": "Upload a .txt file",
    "anonymize": "Anonymize",
    "select_entities": "Entity types to anonymize:",
    "download": "Download Anonymized Text",
    "tip": "Tip: Hover over the colored words to see its class.",
    "error": "An error occurred while processing the file: ",
}

# German interface strings (same keys as the English table).
_UI_TEXT_DE = {
    "title": "Juristische NER",
    "upload": "Lade eine .txt-Datei hoch",
    "anonymize": "Anonymisieren",
    "select_entities": "Entitätstypen zur Anonymisierung:",
    "download": "Anonymisierten Text herunterladen",
    "tip": "Tipp: Fahre mit der Maus über die farbigen Wörter, um deren Klasse zu sehen.",
    "error": "Beim Verarbeiten der Datei ist ein Fehler aufgetreten: ",
}

# UI text for English and German, looked up as ui_text[<language code>][<key>].
ui_text = {"EN": _UI_TEXT_EN, "DE": _UI_TEXT_DE}
# Header row: a wide column for the page title and a narrow one for the
# language switch (column widths in a 4:1 ratio).
col1, col2 = st.columns([4, 1])

# The radio is created first because the title text depends on the selection.
lang = col2.radio(
    "Language:",
    options=["EN", "DE"],
    horizontal=True,
    label_visibility="hidden",
    key="language_selector",
)
col1.title(ui_text[lang]["title"])
# Initialization for German Legal NER
tkn = os.getenv("tkn")


@st.cache_resource
def _load_ner_pipeline(auth_token):
    """Build the JuraNER tokenizer, model and "ner" pipeline.

    Streamlit re-executes the whole script on every widget interaction, so
    loading the model at module level repeats the expensive
    ``from_pretrained`` calls on each rerun. ``st.cache_resource`` keeps the
    returned objects and hands back the same instances on later reruns.

    Args:
        auth_token: Value passed as ``use_auth_token`` to both
            ``from_pretrained`` calls (read from the ``tkn`` environment
            variable by the caller; may be ``None``).

    Returns:
        A ``(tokenizer, model, pipeline)`` tuple.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(
        "harshildarji/JuraNER", use_auth_token=auth_token
    )
    loaded_model = AutoModelForTokenClassification.from_pretrained(
        "harshildarji/JuraNER", use_auth_token=auth_token
    )
    ner_pipeline = pipeline("ner", model=loaded_model, tokenizer=loaded_tokenizer)
    return loaded_tokenizer, loaded_model, ner_pipeline


# Same module-level names as before, now backed by the cached loader.
tokenizer, model, ner = _load_ner_pipeline(tkn)
# Define class labels for the model: (tag, display name) pairs, kept in the
# order that later determines each tag's position in ner_labels.
_CLASS_TABLE = (
    ("AN", "Lawyer"),
    ("EUN", "European legal norm"),
    ("GRT", "Court"),
    ("GS", "Law"),
    ("INN", "Institution"),
    ("LD", "Country"),
    ("LDS", "Landscape"),
    ("LIT", "Legal literature"),
    ("MRK", "Brand"),
    ("ORG", "Organization"),
    ("PER", "Person"),
    ("RR", "Judge"),
    ("RS", "Court decision"),
    ("ST", "City"),
    ("STR", "Street"),
    ("UN", "Company"),
    ("VO", "Ordinance"),
    ("VS", "Regulation"),
    ("VT", "Contract"),
)

# Tag -> display name lookup.
classes = dict(_CLASS_TABLE)

# All tags, in the same order as the table above.
ner_labels = list(classes)
# Generate a list of colors for visualization
def generate_colors(num_colors):
    """Return ``num_colors`` hex color strings sampled from the "tab20" colormap.

    The colormap is evaluated at ``i / num_colors`` for ``i`` in
    ``range(num_colors)`` and each sample is converted with ``rgb2hex``.
    """
    palette = plt.get_cmap("tab20")
    hex_codes = []
    for index in range(num_colors):
        position = 1.0 * index / num_colors
        hex_codes.append(mcolors.rgb2hex(palette(position)))
    return hex_codes
# Color substrings based on NER results
def color_substrings(input_string, model_output):
    """Return ``input_string`` as HTML with every entity colored by its label.

    Each entity becomes a ``<span class="tooltip">`` in the color assigned to
    its label, with a nested ``<span class="tooltiptext">`` holding the
    display name from ``classes``. Text between entities is copied through.

    All text taken from ``input_string`` is HTML-escaped: the result is
    rendered with ``unsafe_allow_html=True`` and the text comes from an
    uploaded file, so raw ``<``, ``>`` or ``&`` would otherwise be interpreted
    as markup.

    Args:
        input_string: Text that the entity offsets refer to.
        model_output: Iterable of dicts with "start", "end" and "label" keys.

    Returns:
        The HTML string.
    """
    # Function-scope import keeps this block self-contained.
    import html

    colors = generate_colors(len(ner_labels))
    label_to_color = dict(zip(ner_labels, colors))

    pieces = []
    last_end = 0
    for entity in sorted(model_output, key=lambda x: x["start"]):
        start, end, label = entity["start"], entity["end"], entity["label"]
        # Plain text between the previous entity and this one.
        pieces.append(html.escape(input_string[last_end:start], quote=False))
        tooltip = classes.get(label, "")
        entity_text = html.escape(input_string[start:end], quote=False)
        pieces.append(
            f'<span class="tooltip" style="color: {label_to_color.get(label)}; font-weight: bold;">'
            f'{entity_text}<span class="tooltiptext">{tooltip}</span></span>'
        )
        last_end = end
    # Trailing text after the last entity.
    pieces.append(html.escape(input_string[last_end:], quote=False))
    return "".join(pieces)
# Selectively anonymize entities
def anonymize_text(input_string, model_output, selected_entities=None):
    """Return ``input_string`` as HTML with selected entity types replaced.

    Entities are first sorted by start offset, and neighbours that share a
    label and are separated only by whitespace (or nothing) are merged into
    one span. Each merged entity is then rendered as either:

    * ``<span class="anonymized">[<display name>]</span>`` when
      ``selected_entities`` is ``None`` or contains the entity's label, or
    * the colored tooltip span (original text kept) otherwise.

    The merge works on copies, so the dicts in ``model_output`` are not
    modified.

    Args:
        input_string: Text that the entity offsets refer to.
        model_output: Iterable of dicts with "start", "end" and "label" keys.
        selected_entities: Labels to anonymize; ``None`` anonymizes all.

    Returns:
        The HTML string.
    """
    # NOTE(review): text from input_string is inserted unescaped. The caller in
    # this file builds the plain-text download by stripping tags from this
    # HTML, so escaping here alone would put HTML entities into the download;
    # change both places together.
    merged_model_output = []
    for entity in sorted(model_output, key=lambda x: x["start"]):
        previous = merged_model_output[-1] if merged_model_output else None
        if (
            previous is not None
            and entity["label"] == previous["label"]
            and input_string[previous["end"] : entity["start"]].strip() == ""
        ):
            # Extend the previous span over this entity and refresh its text.
            previous["end"] = entity["end"]
            previous["word"] = input_string[previous["start"] : previous["end"]]
        else:
            # Copy so later extensions never touch the caller's dict.
            merged_model_output.append(dict(entity))

    colors = generate_colors(len(ner_labels))
    label_to_color = dict(zip(ner_labels, colors))

    pieces = []
    last_end = 0
    for entity in merged_model_output:
        start, end, label = entity["start"], entity["end"], entity["label"]
        pieces.append(input_string[last_end:start])
        if selected_entities is None or label in selected_entities:
            pieces.append(
                f'<span class="anonymized">[{classes.get(label, label)}]</span>'
            )
        else:
            tooltip = classes.get(label, "")
            pieces.append(
                f'<span class="tooltip" style="color: {label_to_color.get(label)}; font-weight: bold;">'
                f'{input_string[start:end]}<span class="tooltiptext">{tooltip}</span></span>'
            )
        last_end = end
    pieces.append(input_string[last_end:])
    return "".join(pieces)
def merge_entities(ner_results):
    """Merge token-level NER predictions into entity spans.

    A token starts a new entity when its tag begins with "B-", when there is
    no open entity, or when its label differs from the open entity's label.
    Otherwise it extends the open entity if its tag begins with "I-", or, for
    a tag with neither prefix, if it starts exactly where the open entity
    ends.

    Args:
        ner_results: Iterable of dicts with "entity" (tag such as "B-PER"),
            "start", "end" and "word" keys.

    Returns:
        A list of dicts with "start", "end", "label" and "word" keys, in
        input order. In "word", "##" markers are removed and a single space
        is inserted where the offsets show a gap between joined tokens.
    """
    merged_entities = []
    current_entity = None
    for token in ner_results:
        tag = token["entity"]
        # Drop only the leading "B"/"I" part so labels containing "-" survive.
        _, separator, remainder = tag.partition("-")
        entity_type = remainder if separator else tag
        token_start, token_end = token["start"], token["end"]
        token_word = token["word"].replace("##", "")  # Remove subword prefixes

        same_label = (
            current_entity is not None and current_entity["label"] == entity_type
        )
        if tag.startswith("B-") or not same_label:
            extends_current = False
        elif tag.startswith("I-"):
            extends_current = True
        else:
            # No B-/I- prefix: only glue directly adjacent tokens together.
            extends_current = token_start == current_entity["end"]

        if extends_current:
            previous_end = current_entity["end"]
            # Tokens separated in the text must not be glued in "word".
            if (
                token_start is not None
                and previous_end is not None
                and token_start > previous_end
            ):
                current_entity["word"] += " "
            current_entity["end"] = token_end
            current_entity["word"] += token_word
        else:
            if current_entity is not None:
                merged_entities.append(current_entity)
            current_entity = {
                "start": token_start,
                "end": token_end,
                "label": entity_type,
                "word": token_word,
            }
    if current_entity is not None:
        merged_entities.append(current_entity)
    return merged_entities
# Markup emitted by anonymize_text that must not reach the plain-text download:
# a whole tooltip bubble (its tags AND the class name inside it), or any other
# opening/closing <span> tag. Stripping every "<...>" instead would leave the
# class name glued to the entity text and would also delete angle-bracketed
# passages that belong to the document itself.
_ANONYMIZER_MARKUP = re.compile(
    r'<span class="tooltiptext">[^<]*</span>|</?span\b[^>]*>'
)

uploaded_file = st.file_uploader(ui_text[lang]["upload"], type="txt")
if uploaded_file is not None:
    try:
        # Decode the upload using the detected character encoding.
        raw_content = uploaded_file.read()
        detected = detect(raw_content)
        encoding = detected["encoding"]
        if encoding is None:
            raise ValueError("Unable to detect file encoding.")
        lines = raw_content.decode(encoding).splitlines()

        # Run NER line by line; blank lines get an empty result so that
        # line_results stays aligned with lines.
        line_results = [
            merge_entities(ner(line)) if line.strip() else [] for line in lines
        ]

        anonymize_mode = st.checkbox(ui_text[lang]["anonymize"])
        selected_entities = None
        if anonymize_mode:
            detected_entity_tags = {
                entity["label"]
                for merged_results in line_results
                for entity in merged_results
            }
            # Display name -> tag for the detected tags only. A tag missing
            # from classes falls back to the tag itself, matching how
            # anonymize_text labels it, instead of raising KeyError.
            option_to_tag = {
                classes.get(tag, tag): tag for tag in detected_entity_tags
            }
            detected_options = sorted(option_to_tag)
            selected_options = st.multiselect(
                ui_text[lang]["select_entities"],
                options=detected_options,
                default=detected_options,
            )
            selected_entities = [option_to_tag[option] for option in selected_options]

        st.markdown(
            "<hr style='margin-top: 10px; margin-bottom: 20px;'>",
            unsafe_allow_html=True,
        )

        anonymized_lines = []
        displayed_lines = []
        for line, merged_results in zip(lines, line_results):
            if not line.strip():
                # Blank lines are kept in the download but not rendered.
                anonymized_lines.append("")
                continue
            if anonymize_mode:
                anonymized_text = anonymize_text(
                    line, merged_results, selected_entities=selected_entities
                )
                displayed_lines.append(anonymized_text)
                plain_text = _ANONYMIZER_MARKUP.sub("", anonymized_text)
                anonymized_lines.append(plain_text.strip())
            else:
                colored_html = color_substrings(line, merged_results)
                st.markdown(colored_html, unsafe_allow_html=True)

        if anonymize_mode:
            download_file_name = f"Anon_{uploaded_file.name}"
            anonymized_content = "\n".join(anonymized_lines)
            for displayed_line in displayed_lines:
                st.markdown(displayed_line, unsafe_allow_html=True)
            st.markdown("<hr>", unsafe_allow_html=True)
            st.download_button(
                label=ui_text[lang]["download"],
                data=anonymized_content,
                file_name=download_file_name,
                mime="text/plain",
            )
        else:
            st.markdown("<hr>", unsafe_allow_html=True)

        st.markdown(
            f'<div class="tip"><strong>{ui_text[lang]["tip"]}</strong></div>',
            unsafe_allow_html=True,
        )
    except Exception as e:
        # Top-level boundary: report any processing failure in the UI.
        st.error(f"{ui_text[lang]['error']}{e}")