import re import os import warnings import matplotlib.colors as mcolors import matplotlib.pyplot as plt import streamlit as st from charset_normalizer import detect from transformers import ( AutoModelForTokenClassification, AutoTokenizer, logging, pipeline, ) warnings.simplefilter(action="ignore", category=Warning) logging.set_verbosity(logging.ERROR) st.set_page_config(page_title="Legal NER", page_icon="⚖️", layout="wide") st.markdown( """ """, unsafe_allow_html=True, ) # UI text for English and German. ui_text = { "EN": { "title": "Legal NER", "upload": "Upload a .txt file", "anonymize": "Anonymize", "select_entities": "Entity types to anonymize:", "download": "Download Anonymized Text", "tip": "Tip: Hover over the colored words to see its class.", "error": "An error occurred while processing the file: ", }, "DE": { "title": "Juristische NER", "upload": "Lade eine .txt-Datei hoch", "anonymize": "Anonymisieren", "select_entities": "Entitätstypen zur Anonymisierung:", "download": "Anonymisierten Text herunterladen", "tip": "Tipp: Fahre mit der Maus über die farbigen Wörter, um deren Klasse zu sehen.", "error": "Beim Verarbeiten der Datei ist ein Fehler aufgetreten: ", }, } col1, col2 = st.columns([4, 1]) with col2: lang = st.radio( "Language:", options=["EN", "DE"], horizontal=True, label_visibility="hidden", key="language_selector", ) with col1: st.title(ui_text[lang]["title"]) # Initialization for German Legal NER tkn = os.getenv("tkn") tokenizer = AutoTokenizer.from_pretrained("harshildarji/JuraNER", use_auth_token=tkn) model = AutoModelForTokenClassification.from_pretrained( "harshildarji/JuraNER", use_auth_token=tkn ) ner = pipeline("ner", model=model, tokenizer=tokenizer) # Define class labels for the model classes = { "AN": "Lawyer", "EUN": "European legal norm", "GRT": "Court", "GS": "Law", "INN": "Institution", "LD": "Country", "LDS": "Landscape", "LIT": "Legal literature", "MRK": "Brand", "ORG": "Organization", "PER": "Person", "RR": "Judge", "RS": "Court decision", "ST": "City", "STR": "Street", "UN": "Company", "VO": "Ordinance", "VS": "Regulation", "VT": "Contract", } ner_labels = list(classes.keys()) # Generate a list of colors for visualization def generate_colors(num_colors): cm = plt.get_cmap("tab20") colors = [mcolors.rgb2hex(cm(1.0 * i / num_colors)) for i in range(num_colors)] return colors # Color substrings based on NER results def color_substrings(input_string, model_output): colors = generate_colors(len(ner_labels)) label_to_color = { label: colors[i % len(colors)] for i, label in enumerate(ner_labels) } last_end = 0 html_output = "" for entity in sorted(model_output, key=lambda x: x["start"]): start, end, label = entity["start"], entity["end"], entity["label"] html_output += input_string[last_end:start] tooltip = classes.get(label, "") html_output += ( f'' f'{input_string[start:end]}{tooltip}' ) last_end = end html_output += input_string[last_end:] return html_output # Selectively anonymize entities def anonymize_text(input_string, model_output, selected_entities=None): merged_model_output = [] sorted_entities = sorted(model_output, key=lambda x: x["start"]) if sorted_entities: current = sorted_entities[0] for entity in sorted_entities[1:]: if ( entity["label"] == current["label"] and input_string[current["end"] : entity["start"]].strip() == "" ): current["end"] = entity["end"] current["word"] = input_string[current["start"] : current["end"]] else: merged_model_output.append(current) current = entity merged_model_output.append(current) else: merged_model_output = sorted_entities anonymized_text = "" last_end = 0 colors = generate_colors(len(ner_labels)) label_to_color = { label: colors[i % len(colors)] for i, label in enumerate(ner_labels) } for entity in merged_model_output: start, end, label = entity["start"], entity["end"], entity["label"] anonymized_text += input_string[last_end:start] if selected_entities is None or label in selected_entities: anonymized_text += ( f'[{classes.get(label, label)}]' ) else: tooltip = classes.get(label, "") anonymized_text += ( f'' f'{input_string[start:end]}{tooltip}' ) last_end = end anonymized_text += input_string[last_end:] return anonymized_text def merge_entities(ner_results): merged_entities = [] current_entity = None for token in ner_results: tag = token["entity"] entity_type = tag.split("-")[-1] if "-" in tag else tag token_start, token_end = token["start"], token["end"] token_word = token["word"].replace("##", "") # Remove subword prefixes if ( tag.startswith("B-") or current_entity is None or current_entity["label"] != entity_type ): if current_entity: merged_entities.append(current_entity) current_entity = { "start": token_start, "end": token_end, "label": entity_type, "word": token_word, } elif ( tag.startswith("I-") and current_entity and current_entity["label"] == entity_type ): current_entity["end"] = token_end current_entity["word"] += token_word else: if ( current_entity and token_start == current_entity["end"] and current_entity["label"] == entity_type ): current_entity["end"] = token_end current_entity["word"] += token_word else: if current_entity: merged_entities.append(current_entity) current_entity = { "start": token_start, "end": token_end, "label": entity_type, "word": token_word, } if current_entity: merged_entities.append(current_entity) return merged_entities uploaded_file = st.file_uploader(ui_text[lang]["upload"], type="txt") if uploaded_file is not None: try: raw_content = uploaded_file.read() detected = detect(raw_content) encoding = detected["encoding"] if encoding is None: raise ValueError("Unable to detect file encoding.") lines = raw_content.decode(encoding).splitlines() line_results = [] for line in lines: if line.strip(): results = ner(line) merged_results = merge_entities(results) line_results.append(merged_results) else: line_results.append([]) anonymize_mode = st.checkbox(ui_text[lang]["anonymize"]) selected_entities = None if anonymize_mode: detected_entity_tags = set() for merged_results in line_results: for entity in merged_results: detected_entity_tags.add(entity["label"]) inverse_classes = {v: k for k, v in classes.items()} detected_options = sorted([classes[tag] for tag in detected_entity_tags]) selected_options = st.multiselect( ui_text[lang]["select_entities"], options=detected_options, default=detected_options, ) selected_entities = [ inverse_classes[options] for options in selected_options ] st.markdown( "