"""Streamlit app: German Legal NER.

Loads the `harshildarji/JuraBERT` token-classification model, lets the user
upload a plain-text file, and renders each line with either color-highlighted
legal entities (with hover tooltips) or anonymized entity placeholders.
"""

import os
import warnings

import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import streamlit as st
from charset_normalizer import detect
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    logging,
    pipeline,
)

warnings.simplefilter(action="ignore", category=Warning)
logging.set_verbosity(logging.ERROR)

st.set_page_config(page_title="Legal NER", page_icon="⚖️", layout="wide")

# NOTE(review): the original inline <style> payload was lost when this file
# was mangled; the CSS below reconstructs the tooltip behavior that
# color_substrings() depends on (hover to reveal the entity class).
st.markdown(
    """
    <style>
    .tooltip {
        position: relative;
        display: inline-block;
        border-radius: 4px;
        padding: 0 2px;
        cursor: help;
    }
    .tooltip .tooltiptext {
        visibility: hidden;
        background-color: #333;
        color: #fff;
        text-align: center;
        border-radius: 4px;
        padding: 2px 6px;
        position: absolute;
        z-index: 1;
        bottom: 125%;
        left: 50%;
        transform: translateX(-50%);
        white-space: nowrap;
        font-size: 0.8em;
    }
    .tooltip:hover .tooltiptext {
        visibility: visible;
    }
    </style>
    """,
    unsafe_allow_html=True,
)

# Initialization for German Legal NER.
# The HF access token is read from the environment; `use_auth_token` is the
# parameter name this file was written against (newer transformers releases
# prefer `token=` — keep in mind when upgrading).
tkn = os.getenv("tkn")
tokenizer = AutoTokenizer.from_pretrained("harshildarji/JuraBERT", use_auth_token=tkn)
model = AutoModelForTokenClassification.from_pretrained(
    "harshildarji/JuraBERT", use_auth_token=tkn
)
ner = pipeline("ner", model=model, tokenizer=tokenizer)

# Map of model tag suffixes (B-XX / I-XX) to human-readable class names.
classes = {
    "AN": "Lawyer",
    "EUN": "European legal norm",
    "GRT": "Court",
    "GS": "Law",
    "INN": "Institution",
    "LD": "Country",
    "LDS": "Landscape",
    "LIT": "Legal literature",
    "MRK": "Brand",
    "ORG": "Organization",
    "PER": "Person",
    "RR": "Judge",
    "RS": "Court decision",
    "ST": "City",
    "STR": "Street",
    "UN": "Company",
    "VO": "Ordinance",
    "VS": "Regulation",
    "VT": "Contract",
}
ner_labels = list(classes.keys())


def generate_colors(num_colors):
    """Return `num_colors` hex color strings sampled evenly from tab20."""
    cm = plt.get_cmap("tab20")
    return [mcolors.rgb2hex(cm(1.0 * i / num_colors)) for i in range(num_colors)]


def color_substrings(input_string, model_output):
    """Wrap each detected entity in a colored, tooltipped HTML span.

    `model_output` is a list of merged entities with `start`, `end`, `label`
    keys (character offsets into `input_string`). Text between entities is
    passed through unchanged.
    """
    colors = generate_colors(len(ner_labels))
    label_to_color = {
        label: colors[i % len(colors)] for i, label in enumerate(ner_labels)
    }
    last_end = 0
    html_output = ""
    # Process entities left-to-right so plain-text gaps are emitted in order.
    for entity in sorted(model_output, key=lambda x: x["start"]):
        start, end, label = entity["start"], entity["end"], entity["label"]
        html_output += input_string[last_end:start]
        tooltip = classes.get(label, "")
        # NOTE(review): the original span markup was stripped from this file;
        # reconstructed to use the per-label color and hover tooltip that the
        # surrounding code computes.
        html_output += (
            f'<span class="tooltip" '
            f'style="background-color: {label_to_color.get(label)};">'
            f"{input_string[start:end]}"
            f'<span class="tooltiptext">{tooltip}</span></span>'
        )
        last_end = end
    html_output += input_string[last_end:]
    return html_output


def anonymize_text(input_string, model_output):
    """Replace each detected entity with a `[Class name]` placeholder."""
    anonymized_text = ""
    last_end = 0
    for entity in sorted(model_output, key=lambda x: x["start"]):
        start, end, label = entity["start"], entity["end"], entity["label"]
        anonymized_text += input_string[last_end:start]
        # Fall back to the raw tag if the label is not in `classes`.
        anonymized_text += f"[{classes.get(label, label)}]"
        last_end = end
    anonymized_text += input_string[last_end:]
    return anonymized_text


def merge_entities(ner_results):
    """Merge per-token NER predictions into whole-entity spans.

    The pipeline emits one record per (sub)word token tagged B-XX / I-XX.
    Consecutive tokens of the same entity type are merged into a single
    record with combined character span and concatenated word (with the
    WordPiece `##` prefixes removed).
    """
    merged_entities = []
    current_entity = None

    for token in ner_results:
        tag = token["entity"]
        entity_type = tag.split("-")[-1] if "-" in tag else tag
        token_start, token_end = token["start"], token["end"]
        token_word = token["word"].replace("##", "")  # Remove subword prefixes

        # Start a new entity if necessary
        if (
            tag.startswith("B-")
            or current_entity is None
            or current_entity["label"] != entity_type
        ):
            if current_entity:
                merged_entities.append(current_entity)
            current_entity = {
                "start": token_start,
                "end": token_end,
                "label": entity_type,
                "word": token_word,
            }
        elif (
            tag.startswith("I-")
            and current_entity
            and current_entity["label"] == entity_type
        ):
            # Extend the current entity
            current_entity["end"] = token_end
            current_entity["word"] += token_word
        else:
            # Handle misclassifications or gaps in tokens: merge only when the
            # token is directly adjacent to the current entity and same-typed.
            if (
                current_entity
                and token_start == current_entity["end"]
                and current_entity["label"] == entity_type
            ):
                current_entity["end"] = token_end
                current_entity["word"] += token_word
            else:
                # Treat it as a new entity if the above conditions aren't met
                if current_entity:
                    merged_entities.append(current_entity)
                current_entity = {
                    "start": token_start,
                    "end": token_end,
                    "label": entity_type,
                    "word": token_word,
                }

    # Append the last entity
    if current_entity:
        merged_entities.append(current_entity)

    return merged_entities


st.title("Legal NER")
# NOTE(review): original markdown payload was stripped; a divider is the most
# plausible reconstruction here.
st.markdown("<hr>", unsafe_allow_html=True)

uploaded_file = st.file_uploader("Upload a .txt file", type="txt")

if uploaded_file is not None:
    try:
        # Read raw content of the file
        raw_content = uploaded_file.read()

        # Dynamically detect encoding instead of assuming UTF-8.
        detected = detect(raw_content)
        encoding = detected["encoding"]
        if encoding is None:
            raise ValueError("Unable to detect file encoding.")

        # Decode file content with the detected encoding
        lines = raw_content.decode(encoding).splitlines()

        anonymize_mode = st.checkbox("Anonymize")
        st.markdown("<hr>", unsafe_allow_html=True)

        for line_number, line in enumerate(lines, start=1):
            if line.strip():
                results = ner(line)
                merged_results = merge_entities(results)
                if anonymize_mode:
                    anonymized_text = anonymize_text(line, merged_results)
                    st.markdown(f"{anonymized_text}", unsafe_allow_html=True)
                else:
                    colored_html = color_substrings(line, merged_results)
                    st.markdown(f"{colored_html}", unsafe_allow_html=True)
            else:
                # Preserve blank lines in the rendered output.
                st.markdown("<br>", unsafe_allow_html=True)

        if not anonymize_mode:
            st.markdown(
                '<div style="margin-top: 1em; color: gray; font-size: 0.9em;">'
                "Tip: Hover over the colored words to see its class."
                "</div>",
                unsafe_allow_html=True,
            )
    except Exception as e:
        # Top-level UI boundary: surface any failure to the user.
        st.error(f"An error occurred while processing the file: {e}")