"""Streamlit app: token-level NER for German legal text (JuraNER).

Uploads a .txt file, runs a HuggingFace token-classification pipeline on each
line, merges word-piece tokens into whole entities, and renders the lines with
color-coded, score-filtered entity highlights.
"""

import html
import re
import string

import matplotlib.cm as cm
import streamlit as st
from charset_normalizer import detect
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    logging,
    pipeline,
)

# Streamlit page setup
st.set_page_config(page_title="Juristische NER", page_icon="⚖️", layout="wide")
logging.set_verbosity(logging.ERROR)

# NOTE(review): the original CSS payload of this style block was lost when the
# file was mangled; reconstructed with the minimal rules the highlight markup
# below relies on — confirm against the deployed app.
st.markdown(
    """
    <style>
    .ner-line { line-height: 2.2; margin-bottom: 4px; }
    .ner-entity { border-radius: 4px; padding: 1px 3px; }
    .ner-label { font-size: 0.7em; font-weight: 700; margin-left: 4px; vertical-align: super; }
    </style>
    """,
    unsafe_allow_html=True,
)

# Model tag -> human-readable (German) entity description.
entity_labels = {
    "AN": "Rechtsbeistand",
    "EUN": "EUNorm",
    "GRT": "Gericht",
    "GS": "Norm",
    "INN": "Institution",
    "LD": "Land",
    "LDS": "Bezirk",
    "LIT": "Schrifttum",
    "MRK": "Marke",
    "ORG": "Organisation",
    "PER": "Person",
    "RR": "RichterIn",
    "RS": "Entscheidung",
    "ST": "Stadt",
    "STR": "Strasse",
    "UN": "Unternehmen",
    "VO": "Verordnung",
    "VS": "Richtlinie",
    "VT": "Vertrag",
    "RED": "Schwärzung",
}


def generate_fixed_colors(keys, alpha=0.25):
    """Assign each key a stable translucent RGBA color from the tab20 colormap.

    Parameters
    ----------
    keys : sequence of str
        Entity tags to color; order determines the colormap index.
    alpha : float
        Opacity used for every color (background highlight).

    Returns
    -------
    dict mapping each key to an ``"rgba(r, g, b, a)"`` CSS string.
    """
    try:
        cmap = cm.get_cmap("tab20", len(keys))
    except AttributeError:
        # matplotlib >= 3.9 removed cm.get_cmap; use the colormap registry.
        from matplotlib import colormaps

        cmap = colormaps["tab20"].resampled(len(keys))
    rgba_colors = {}
    for i, key in enumerate(keys):
        r, g, b, _ = cmap(i)
        rgba_colors[key] = (
            f"rgba({int(r * 255)}, {int(g * 255)}, {int(b * 255)}, {alpha})"
        )
    return rgba_colors


ENTITY_COLORS = generate_fixed_colors(list(entity_labels.keys()))


# Caching model
@st.cache_resource
def load_ner_pipeline():
    """Load (once per server process) the JuraNER token-classification pipeline."""
    return pipeline(
        "ner",
        model=AutoModelForTokenClassification.from_pretrained("harshildarji/JuraNER"),
        tokenizer=AutoTokenizer.from_pretrained("harshildarji/JuraNER"),
    )


# Caching NER + merge per line
@st.cache_data(show_spinner=False)
def get_ner_merged_lines(text):
    """Run NER on every line of *text* and merge word-piece tokens.

    Returns a list of ``(line, merged_entities)`` pairs, one per input line;
    blank lines yield ``("", [])`` so line structure is preserved.
    """
    ner = load_ner_pipeline()
    results = []
    for line in text.splitlines():
        if not line.strip():
            results.append(("", []))
            continue
        tokens = ner(line)
        merged = merge_entities(tokens)
        results.append((line, merged))
    return results


def merge_entities(entities):
    """Merge consecutive word-piece tokens into whole-entity spans.

    Tokens with consecutive ``index`` values are fused: ``##``-prefixed pieces
    are glued on directly, others joined with a space. Each merged span gets a
    ``score`` equal to the mean of its member token scores. Merged words are
    then cleaned (spacing around ``.``/``,``/``/`` normalised, surrounding
    whitespace/punctuation stripped) and spans that reduce to a single
    character or contain no word character are dropped.

    NOTE(review): adjacency alone triggers a merge — two different entity
    types on consecutive token indices would be fused under the first span's
    label. Confirm whether that is intended for this model's output.
    """
    if not entities:
        return []
    ents = sorted(entities, key=lambda e: e["index"])
    merged = [ents[0].copy()]
    merged[0]["score_sum"] = ents[0]["score"]
    merged[0]["count"] = 1
    for ent in ents[1:]:
        prev = merged[-1]
        if ent["index"] == prev["index"] + 1:
            tok = ent["word"]
            prev["word"] += tok[2:] if tok.startswith("##") else " " + tok
            prev["end"] = ent["end"]
            prev["index"] = ent["index"]
            prev["score_sum"] += ent["score"]
            prev["count"] += 1
        else:
            # Close out the previous span: replace running sum with the mean.
            prev["score"] = prev["score_sum"] / prev["count"]
            del prev["score_sum"]
            del prev["count"]
            new_ent = ent.copy()
            new_ent["score_sum"] = ent["score"]
            new_ent["count"] = 1
            merged.append(new_ent)
    # The last span never hit the else-branch; finalise its score here.
    if "score_sum" in merged[-1]:
        merged[-1]["score"] = merged[-1]["score_sum"] / merged[-1]["count"]
        del merged[-1]["score_sum"]
        del merged[-1]["count"]

    final = []
    for ent in merged:
        w = ent["word"].strip()
        w = re.sub(r"\s*\.\s*", ".", w)
        w = re.sub(r"\s*,\s*", ", ", w)
        w = re.sub(r"\s*/\s*", "/", w)
        w = w.strip(string.whitespace + string.punctuation)
        if len(w) > 1 and re.search(r"\w", w):
            cleaned = ent.copy()
            cleaned["word"] = w
            final.append(cleaned)
    return final


def highlight_entities(line, merged_entities, threshold):
    """Render *line* as HTML with entities at/above *threshold* highlighted.

    Non-entity text is HTML-escaped (the result is fed to st.markdown with
    ``unsafe_allow_html=True``, so raw uploaded text must not be interpolated
    unescaped). Entities below the confidence threshold are rendered as plain
    text.
    """
    out = ""
    last_end = 0
    for ent in merged_entities:
        if ent["score"] < threshold:
            continue
        start, end = ent["start"], ent["end"]
        label = ent["entity"].split("-")[-1]
        label_desc = entity_labels.get(label, label)
        color = ENTITY_COLORS.get(label, "#cccccc")
        out += html.escape(line[last_end:start])
        highlight_style = f"background-color:{color}; font-weight:600;"
        # NOTE(review): the original <span> markup was lost in extraction;
        # reconstructed so the computed style and label are actually emitted.
        out += (
            f'<span class="ner-entity" style="{highlight_style}" title="{label_desc}">'
            f'{html.escape(ent["word"])}'
            f'<span class="ner-label">{label_desc}</span></span>'
        )
        last_end = end
    out += html.escape(line[last_end:])
    return out


# UI
st.markdown("#### Juristische Named Entity Recognition (NER)")
uploaded_file = st.file_uploader("Bitte laden Sie eine .txt-Datei hoch:", type="txt")
threshold = st.slider("Schwellenwert für das Modellvertrauen:", 0.0, 1.0, 0.8, 0.01)
st.markdown("---")

if uploaded_file:
    raw_bytes = uploaded_file.read()
    # Uploaded legal texts come in varying encodings; detect before decoding.
    encoding = detect(raw_bytes)["encoding"]
    if encoding is None:
        st.error("Zeichenkodierung konnte nicht erkannt werden.")
    else:
        text = raw_bytes.decode(encoding)
        with st.spinner("Modell wird auf jede Zeile angewendet..."):
            merged_all_lines = get_ner_merged_lines(text)
        for line, merged in merged_all_lines:
            if not line.strip():
                continue
            html_line = highlight_entities(line, merged, threshold)
            # NOTE(review): wrapper element lost in extraction; reconstructed
            # as a div matching the reconstructed CSS above.
            st.markdown(
                f'<div class="ner-line">{html_line}</div>',
                unsafe_allow_html=True,
            )