# Legal NER — Streamlit app for German legal named-entity recognition (JuraBERT).
import html
import os
import re
import warnings

import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import streamlit as st
from charset_normalizer import detect
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    logging,
    pipeline,
)
# ---------------------------------------------------------------------------
# Page setup: silence noisy warnings, configure Streamlit, inject page CSS.
# (Fix: stripped stray " | |" extraction residue that broke the file syntax
# and leaked into the CSS string.)
# ---------------------------------------------------------------------------
warnings.simplefilter(action="ignore", category=Warning)
logging.set_verbosity(logging.ERROR)  # `logging` here is transformers' logger

st.set_page_config(page_title="Legal NER", page_icon="⚖️", layout="wide")

# Global CSS: page chrome, hover tooltips for entities, anonymized highlights.
st.markdown(
    """
    <style>
    body {
        font-family: 'Poppins', sans-serif;
        background-color: #f4f4f8;
    }
    .header {
        background-color: rgba(220, 219, 219, 0.25);
        color: #000;
        padding: 5px 0;
        text-align: center;
        border-radius: 7px;
        margin-bottom: 13px;
        border-bottom: 2px solid #333;
    }
    .container {
        background-color: #fff;
        padding: 30px;
        border-radius: 10px;
        box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
        width: 100%;
        max-width: 1000px;
        margin: 0 auto;
        position: absolute;
        top: 50%;
        left: 50%;
        transform: translate(-50%, -50%);
    }
    .btn-primary {
        background-color: #5477d1;
        border: none;
        transition: background-color 0.3s, transform 0.2s;
        border-radius: 25px;
        box-shadow: 0 1px 3px rgba(0, 0, 0, 0.08);
    }
    .btn-primary:hover {
        background-color: #4c6cbe;
        transform: translateY(-1px);
    }
    h2 {
        font-weight: 600;
        font-size: 24px;
        margin-bottom: 20px;
    }
    label {
        font-weight: 500;
    }
    .tip {
        background-color: rgba(180, 47, 109, 0.25);
        padding: 7px;
        border-radius: 7px;
        display: inline-block;
        margin-top: 15px;
        margin-bottom: 15px;
    }
    .sec {
        background-color: rgba(220, 219, 219, 0.10);
        padding: 7px;
        border-radius: 5px;
        display: inline-block;
        margin-top: 15px;
        margin-bottom: 15px;
    }
    .tooltip {
        position: relative;
        display: inline-block;
        cursor: pointer;
    }
    .tooltip .tooltiptext {
        visibility: hidden;
        width: 120px;
        background-color: #6c757d;
        color: #fff;
        text-align: center;
        border-radius: 3px;
        padding: 3px;
        position: absolute;
        z-index: 1;
        bottom: 125%;
        left: 50%;
        margin-left: -60px;
        opacity: 0;
        transition: opacity 0.3s;
    }
    .tooltip:hover .tooltiptext {
        visibility: visible;
        opacity: 1;
    }
    .anonymized {
        background-color: #ffcccb;
        color: #000;
        font-weight: bold;
        border-radius: 3px;
        padding: 2px 4px;
    }
    </style>
    """,
    unsafe_allow_html=True,
)
# ---------------------------------------------------------------------------
# Model initialization for German Legal NER.
# (Fix: stripped stray " | |" extraction residue that broke the file syntax.)
# ---------------------------------------------------------------------------
# Hugging Face access token, read from the environment (set by the host).
tkn = os.getenv("tkn")
# NOTE(review): `use_auth_token` is deprecated in newer transformers releases
# in favour of `token=`; kept as-is to match the pinned dependency — confirm
# the installed transformers version before changing.
tokenizer = AutoTokenizer.from_pretrained("harshildarji/JuraBERT", use_auth_token=tkn)
model = AutoModelForTokenClassification.from_pretrained(
    "harshildarji/JuraBERT", use_auth_token=tkn
)
# Token-classification pipeline used for all per-line NER calls below.
ner = pipeline("ner", model=model, tokenizer=tokenizer)
# Mapping from the model's entity tags to human-readable class names.
# (Fix: stripped stray " | |" extraction residue that broke the file syntax.)
classes = {
    "AN": "Lawyer",
    "EUN": "European legal norm",
    "GRT": "Court",
    "GS": "Law",
    "INN": "Institution",
    "LD": "Country",
    "LDS": "Landscape",
    "LIT": "Legal literature",
    "MRK": "Brand",
    "ORG": "Organization",
    "PER": "Person",
    "RR": "Judge",
    "RS": "Court decision",
    "ST": "City",
    "STR": "Street",
    "UN": "Company",
    "VO": "Ordinance",
    "VS": "Regulation",
    "VT": "Contract",
}
# Stable label ordering, used to assign one color per label in the UI.
ner_labels = list(classes.keys())
def generate_colors(num_colors):
    """Return `num_colors` hex color strings sampled evenly from tab20."""
    cmap = plt.get_cmap("tab20")
    return [mcolors.rgb2hex(cmap(idx / num_colors)) for idx in range(num_colors)]
# Function to color substrings based on NER results
def color_substrings(input_string, model_output):
    """Render `input_string` as HTML with each detected entity colored.

    Args:
        input_string: One line of the uploaded document.
        model_output: Merged entity dicts with "start", "end", "label" keys.

    Returns:
        An HTML string for st.markdown(..., unsafe_allow_html=True).

    Security fix: the line comes from a user-uploaded file and is rendered
    with unsafe_allow_html=True, so all text segments are passed through
    html.escape() to prevent markup/script injection.
    """
    colors = generate_colors(len(ner_labels))
    label_to_color = {
        label: colors[i % len(colors)] for i, label in enumerate(ner_labels)
    }
    parts = []
    last_end = 0
    for entity in sorted(model_output, key=lambda x: x["start"]):
        start, end, label = entity["start"], entity["end"], entity["label"]
        parts.append(html.escape(input_string[last_end:start]))
        tooltip = classes.get(label, "")
        parts.append(
            f'<span class="tooltip" style="color: {label_to_color.get(label)}; font-weight: bold;">'
            f'{html.escape(input_string[start:end])}<span class="tooltiptext">{tooltip}</span></span>'
        )
        last_end = end
    parts.append(html.escape(input_string[last_end:]))
    return "".join(parts)
# Function to anonymize entities
def anonymize_text(input_string, model_output):
    """Replace each detected entity span with a highlighted [ClassName] tag.

    Entities are processed left-to-right by start offset; the text between
    entities is passed through unchanged, and the result is an HTML string.
    """
    parts = []
    cursor = 0
    for ent in sorted(model_output, key=lambda e: e["start"]):
        parts.append(input_string[cursor:ent["start"]])
        placeholder = classes.get(ent["label"], ent["label"])
        parts.append(f'<span class="anonymized">[{placeholder}]</span>')
        cursor = ent["end"]
    parts.append(input_string[cursor:])
    return "".join(parts)
def merge_entities(ner_results):
    """Merge per-token NER predictions into whole-entity spans.

    Consecutive tokens sharing an entity type are folded into one dict with
    keys "start", "end", "label", and "word" (subword "##" markers removed).
    A "B-" tag, or any change of entity type, starts a new span; "I-" tags
    extend the active span; prefix-less tags extend it only when the token
    is directly adjacent to the active span's end.
    """
    merged = []
    active = None
    for tok in ner_results:
        tag = tok["entity"]
        label = tag.split("-")[-1] if "-" in tag else tag
        start, end = tok["start"], tok["end"]
        word = tok["word"].replace("##", "")  # drop subword prefixes

        starts_new = (
            tag.startswith("B-") or active is None or active["label"] != label
        )
        if starts_new:
            if active is not None:
                merged.append(active)
            active = {"start": start, "end": end, "label": label, "word": word}
        elif tag.startswith("I-"):
            # Active span exists with a matching label (guaranteed above).
            active["end"] = end
            active["word"] += word
        elif start == active["end"]:
            # Prefix-less tag adjacent to the active span: continuation.
            active["end"] = end
            active["word"] += word
        else:
            # Prefix-less tag with a gap: treat as a brand-new entity.
            merged.append(active)
            active = {"start": start, "end": end, "label": label, "word": word}

    if active is not None:
        merged.append(active)
    return merged
# ---------------------------------------------------------------------------
# UI: upload a .txt file, run NER line by line, then show colored entities
# or an anonymized version with a plain-text download.
# (Fix: stripped stray " | |" extraction residue that broke the file syntax.)
# ---------------------------------------------------------------------------
st.title("Legal NER")
st.markdown("<hr>", unsafe_allow_html=True)

uploaded_file = st.file_uploader("Upload a .txt file", type="txt")
if uploaded_file is not None:
    try:
        raw_content = uploaded_file.read()
        # Uploaded files may use any encoding; detect it before decoding.
        detected = detect(raw_content)
        encoding = detected["encoding"]
        if encoding is None:
            raise ValueError("Unable to detect file encoding.")
        lines = raw_content.decode(encoding).splitlines()

        anonymize_mode = st.checkbox("Anonymize")
        st.markdown(
            "<hr style='margin-top: 10px; margin-bottom: 20px;'>",
            unsafe_allow_html=True,
        )

        anonymized_lines = []  # plain-text lines for the download file
        displayed_lines = []  # HTML lines shown on screen in anonymize mode
        for line_number, line in enumerate(lines, start=1):
            if line.strip():
                results = ner(line)
                merged_results = merge_entities(results)
                if anonymize_mode:
                    anonymized_text = anonymize_text(line, merged_results)
                    displayed_lines.append(anonymized_text)
                    # Strip highlight tags so the download is plain text.
                    plain_text = re.sub(r"<.*?>", "", anonymized_text)
                    anonymized_lines.append(plain_text.strip())
                else:
                    colored_html = color_substrings(line, merged_results)
                    st.markdown(f"{colored_html}", unsafe_allow_html=True)
            else:
                # Preserve blank lines in the anonymized display/download.
                displayed_lines.append("<br>")
                anonymized_lines.append("")

        if anonymize_mode:
            original_file_name = uploaded_file.name
            download_file_name = f"Anon_{original_file_name}"
            anonymized_content = "\n".join(anonymized_lines)
            for displayed_line in displayed_lines:
                st.markdown(f"{displayed_line}", unsafe_allow_html=True)
            st.markdown("<hr>", unsafe_allow_html=True)
            st.download_button(
                label="Download Anonymized Text",
                data=anonymized_content,
                file_name=download_file_name,
                mime="text/plain",
            )

        if not anonymize_mode:
            st.markdown(
                '<div class="tip"><strong>Tip:</strong> Hover over the colored words to see its class.</div>',
                unsafe_allow_html=True,
            )
    except Exception as e:
        # Top-level boundary: surface any processing error in the UI.
        st.error(f"An error occurred while processing the file: {e}")