import os
import warnings

import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import streamlit as st
from charset_normalizer import detect
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    logging,
    pipeline,
)

warnings.simplefilter(action="ignore", category=Warning)
logging.set_verbosity(logging.ERROR)

st.set_page_config(page_title="Legal NER", page_icon="⚖️", layout="wide")

# NOTE(review): the original inline CSS appears to have been stripped during
# extraction (the app's entity spans use .tooltip / .tooltiptext classes).
# A minimal reconstruction is provided so hover tooltips render; confirm
# against the original stylesheet.
st.markdown(
    """
    <style>
    .tooltip { position: relative; display: inline-block; }
    .tooltip .tooltiptext {
        visibility: hidden;
        background-color: #555;
        color: #fff;
        text-align: center;
        border-radius: 4px;
        padding: 2px 6px;
        position: absolute;
        z-index: 1;
        bottom: 125%;
        left: 50%;
        transform: translateX(-50%);
        font-size: 0.8em;
        white-space: nowrap;
    }
    .tooltip:hover .tooltiptext { visibility: visible; }
    </style>
    """,
    unsafe_allow_html=True,
)

# Initialization for German Legal NER.
# The HF access token is read from the environment; it may legitimately be
# None for a public model.
tkn = os.getenv("tkn")
# NOTE(review): `use_auth_token` is deprecated in recent transformers releases
# in favor of `token`; kept as-is so older pinned environments keep working.
tokenizer = AutoTokenizer.from_pretrained("harshildarji/JuraBERT", use_auth_token=tkn)
model = AutoModelForTokenClassification.from_pretrained(
    "harshildarji/JuraBERT", use_auth_token=tkn
)
ner = pipeline("ner", model=model, tokenizer=tokenizer)

# Map the model's tag abbreviations to human-readable entity names.
classes = {
    "AN": "Lawyer",
    "EUN": "European legal norm",
    "GRT": "Court",
    "GS": "Law",
    "INN": "Institution",
    "LD": "Country",
    "LDS": "Landscape",
    "LIT": "Legal literature",
    "MRK": "Brand",
    "ORG": "Organization",
    "PER": "Person",
    "RR": "Judge",
    "RS": "Court decision",
    "ST": "City",
    "STR": "Street",
    "UN": "Company",
    "VO": "Ordinance",
    "VS": "Regulation",
    "VT": "Contract",
}
ner_labels = list(classes.keys())


def generate_colors(num_colors):
    """Return `num_colors` hex color strings sampled evenly from tab20.

    Used to give each NER label a stable, distinct highlight color.
    """
    cm = plt.get_cmap("tab20")
    return [mcolors.rgb2hex(cm(1.0 * i / num_colors)) for i in range(num_colors)]


def color_substrings(input_string, model_output):
    """Wrap each detected entity of `input_string` in a colored HTML span.

    `model_output` is a list of dicts with `start`, `end`, `label` keys
    (character offsets into `input_string`). Returns an HTML string; the
    label's full name is shown as a hover tooltip.
    """
    colors = generate_colors(len(ner_labels))
    label_to_color = {
        label: colors[i % len(colors)] for i, label in enumerate(ner_labels)
    }
    last_end = 0
    html_output = ""
    # Process entities left-to-right so untouched text between them is copied
    # through exactly once.
    for entity in sorted(model_output, key=lambda x: x["start"]):
        start, end, label = entity["start"], entity["end"], entity["label"]
        html_output += input_string[last_end:start]
        tooltip = classes.get(label, "")
        # NOTE(review): the span markup was lost in extraction; reconstructed
        # from `label_to_color` (otherwise unused) and the tooltip CSS classes.
        # Confirm against the original template.
        html_output += (
            f'<span class="tooltip" style="background-color: '
            f'{label_to_color.get(label)}; border-radius: 3px; padding: 0 2px;">'
            f"{input_string[start:end]}"
            f'<span class="tooltiptext">{tooltip}</span></span>'
        )
        last_end = end
    html_output += input_string[last_end:]
    return html_output


def anonymize_text(input_string, model_output):
    """Replace each detected entity in `input_string` with `[Label]`.

    Unknown labels fall back to the raw tag abbreviation.
    """
    anonymized_text = ""
    last_end = 0
    for entity in sorted(model_output, key=lambda x: x["start"]):
        start, end, label = entity["start"], entity["end"], entity["label"]
        anonymized_text += input_string[last_end:start]
        anonymized_text += f"[{classes.get(label, label)}]"
        last_end = end
    anonymized_text += input_string[last_end:]
    return anonymized_text


def merge_entities(ner_results):
    """Merge token-level pipeline output into contiguous entity spans.

    Each input token dict carries `entity` (e.g. "B-PER"/"I-PER"), `start`,
    `end`, and `word` (WordPiece token, "##"-prefixed for subwords). Returns
    a list of dicts with `start`, `end`, `label`, `word`.
    """
    merged_entities = []
    current_entity = None
    for token in ner_results:
        tag = token["entity"]
        # Strip the BIO prefix to get the bare entity type.
        entity_type = tag.split("-")[-1] if "-" in tag else tag
        token_start, token_end = token["start"], token["end"]
        token_word = token["word"].replace("##", "")  # Remove subword prefixes

        # Start a new entity if necessary
        if (
            tag.startswith("B-")
            or current_entity is None
            or current_entity["label"] != entity_type
        ):
            if current_entity:
                merged_entities.append(current_entity)
            current_entity = {
                "start": token_start,
                "end": token_end,
                "label": entity_type,
                "word": token_word,
            }
        elif (
            tag.startswith("I-")
            and current_entity
            and current_entity["label"] == entity_type
        ):
            # Extend the current entity.
            # Bug fix: a gap between the previous token's end and this token's
            # start means a new word — without the space, multi-word entities
            # were glued together (e.g. "EuropäischerGerichtshof").
            if token_start > current_entity["end"]:
                current_entity["word"] += " "
            current_entity["end"] = token_end
            current_entity["word"] += token_word
        else:
            # Handle misclassifications or gaps in tokens: only merge when the
            # token is directly adjacent and of the same type.
            if (
                current_entity
                and token_start == current_entity["end"]
                and current_entity["label"] == entity_type
            ):
                current_entity["end"] = token_end
                current_entity["word"] += token_word
            else:
                # Treat it as a new entity if the above conditions aren't met
                if current_entity:
                    merged_entities.append(current_entity)
                current_entity = {
                    "start": token_start,
                    "end": token_end,
                    "label": entity_type,
                    "word": token_word,
                }

    # Append the last entity
    if current_entity:
        merged_entities.append(current_entity)

    return merged_entities


st.title("Legal NER")
# NOTE(review): the file continues past this chunk — the original ended here
# with a truncated `st.markdown("` call whose content is not visible.