# Hugging Face Space status banner at capture time: Running on CPU Upgrade
import html
import re
import string

import matplotlib
import matplotlib.cm as cm
import streamlit as st
from charset_normalizer import detect
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    logging,
    pipeline,
)
# Streamlit page setup
# set_page_config is issued before any other Streamlit call in this script.
st.set_page_config(page_title="Juristische NER", page_icon="⚖️", layout="wide")

# Only show errors from the transformers library.
logging.set_verbosity(logging.ERROR)

# Global CSS: page padding, hidden Streamlit header/footer, and the
# ".entity" / ".tooltip" classes used by the highlighted entity spans.
# The stylesheet starts at column 0 inside the string so Markdown does not
# treat it as an indented code block.
st.markdown(
    """
<style>
.block-container {
  padding-top: 1rem;
  padding-bottom: 5rem;
  padding-left: 3rem;
  padding-right: 3rem;
}
header, footer {visibility: hidden;}
.entity {
  position: relative;
  display: inline-block;
  background-color: transparent;
  font-weight: normal;
  cursor: help;
}
.entity .tooltip {
  visibility: hidden;
  background-color: #333;
  color: #fff;
  text-align: center;
  border-radius: 4px;
  padding: 2px 6px;
  position: absolute;
  z-index: 1;
  bottom: 125%;
  left: 50%;
  transform: translateX(-50%);
  white-space: nowrap;
  opacity: 0;
  transition: opacity 0.05s;
  font-size: 11px;
}
.entity:hover .tooltip {
  visibility: visible;
  opacity: 1;
}
.entity.marked {
  background-color: rgba(255, 230, 0, 0.4);
  line-height: 1.3;
  padding: 0 1px;
  border-radius: 0px;
}
</style>
""",
    unsafe_allow_html=True,
)
# Entity label mapping
# Maps the short entity codes (the part of a model tag after the "B-"/"I-"
# prefix) to the German display names shown in the hover tooltip.
entity_labels = {
    "AN": "Rechtsbeistand",
    "EUN": "EUNorm",
    "GRT": "Gericht",
    "GS": "Norm",
    "INN": "Institution",
    "LD": "Land",
    "LDS": "Bezirk",
    "LIT": "Schrifttum",
    "MRK": "Marke",
    "ORG": "Organisation",
    "PER": "Person",
    "RR": "RichterIn",
    "RS": "Entscheidung",
    "ST": "Stadt",
    "STR": "Strasse",
    "UN": "Unternehmen",
    "VO": "Verordnung",
    "VS": "Richtlinie",
    "VT": "Vertrag",
    "RED": "Schwärzung",
}
# Color generator
def generate_fixed_colors(keys, alpha=0.25):
    """Assign each key a CSS ``rgba(...)`` colour taken from the "tab20" colormap.

    Args:
        keys: Sequence of keys; the i-th key gets the i-th colormap entry.
        alpha: Opacity written into every ``rgba(...)`` string.

    Returns:
        Dict mapping each key to a string like ``"rgba(31, 119, 180, 0.25)"``.
    """
    n_colors = len(keys)
    try:
        # ``cm.get_cmap`` was removed in matplotlib 3.9; the colormap
        # registry plus ``resampled`` is the replacement.
        cmap = matplotlib.colormaps["tab20"].resampled(n_colors)
    except AttributeError:
        # Older matplotlib without ``colormaps`` / ``Colormap.resampled``.
        cmap = cm.get_cmap("tab20", n_colors)

    rgba_colors = {}
    for i, key in enumerate(keys):
        r, g, b, _ = cmap(i)  # colormap alpha is ignored in favour of ``alpha``
        rgba_colors[key] = (
            f"rgba({int(r * 255)}, {int(g * 255)}, {int(b * 255)}, {alpha})"
        )
    return rgba_colors


# One fixed colour per known entity code, computed once at import time.
ENTITY_COLORS = generate_fixed_colors(list(entity_labels.keys()))
# Caching model
@st.cache_resource
def load_ner_pipeline():
    """Build the token-classification pipeline for the JuraNER model.

    Wrapped in ``st.cache_resource`` so the model and tokenizer are loaded
    once and reused, instead of being re-created on every Streamlit rerun
    (each widget interaction reruns the script).

    Returns:
        A transformers ``"ner"`` pipeline using the "harshildarji/JuraNER"
        model and tokenizer.
    """
    model_id = "harshildarji/JuraNER"
    return pipeline(
        "ner",
        model=AutoModelForTokenClassification.from_pretrained(model_id),
        tokenizer=AutoTokenizer.from_pretrained(model_id),
    )
# Caching NER + merge per line
@st.cache_data(show_spinner=False)
def get_ner_merged_lines(text):
    """Run NER on every line of ``text`` and merge token-level hits per line.

    Cached on ``text`` via ``st.cache_data`` so that moving the threshold
    slider (which reruns the script) does not repeat model inference for an
    unchanged document. The caller shows its own spinner, so the built-in
    cache spinner is disabled.

    Args:
        text: The full document as a single string.

    Returns:
        A list with one ``(line, merged_entities)`` tuple per input line.
        Blank or whitespace-only lines are returned as ``("", [])``.
    """
    ner = load_ner_pipeline()
    results = []
    for line in text.splitlines():
        if not line.strip():
            # Keep a placeholder so the result stays aligned with the input lines.
            results.append(("", []))
            continue
        results.append((line, merge_entities(ner(line))))
    return results
# Entity merging
def _entity_type(tag):
    """Return the entity type of a tag, e.g. ``"B-PER"`` -> ``"PER"``."""
    return tag.split("-")[-1]


def merge_entities(entities):
    """Merge token-level NER hits into entity spans.

    Tokens are ordered by ``index``. A token is appended to the previous span
    when its ``index`` directly follows that span's last token AND it either
    is a WordPiece continuation (``word`` starts with ``"##"``) or has the
    same entity type as the span. Adjacent tokens of a different type start a
    new span, so two back-to-back entities are not collapsed into one
    mislabelled span.

    Args:
        entities: List of token dicts with at least the keys ``index``,
            ``word``, ``entity``, ``score``, ``start`` and ``end``. The input
            dicts are not modified.

    Returns:
        List of merged entity dicts. Each keeps the keys of its first token,
        with ``word`` set to the cleaned joined text, ``end`` and ``index``
        taken from the last merged token, and ``score`` set to the mean token
        score. Spans whose cleaned text is a single character or contains no
        word character are dropped.
    """
    if not entities:
        return []

    groups = []
    for tok in sorted(entities, key=lambda e: e["index"]):
        current = groups[-1] if groups else None
        word = tok["word"]
        is_subword = word.startswith("##")
        if (
            current is not None
            and tok["index"] == current["index"] + 1
            and (
                is_subword
                or _entity_type(tok["entity"]) == _entity_type(current["entity"])
            )
        ):
            # Sub-word pieces are glued on; whole words are space-separated.
            current["word"] += word[2:] if is_subword else " " + word
            current["end"] = tok["end"]
            current["index"] = tok["index"]
            current["score_sum"] += tok["score"]
            current["count"] += 1
        else:
            fresh = tok.copy()
            fresh["score_sum"] = tok["score"]
            fresh["count"] = 1
            groups.append(fresh)

    final = []
    for group in groups:
        # Replace the running sum/count bookkeeping with the mean score.
        group["score"] = group.pop("score_sum") / group.pop("count")

        # Undo the spacing that token joining put around punctuation.
        w = group["word"].strip()
        w = re.sub(r"\s*\.\s*", ".", w)
        w = re.sub(r"\s*,\s*", ", ", w)
        w = re.sub(r"\s*/\s*", "/", w)
        w = w.strip(string.whitespace + string.punctuation)

        # Drop punctuation-only and single-character leftovers.
        if len(w) > 1 and re.search(r"\w", w):
            group["word"] = w
            final.append(group)
    return final
# Highlighting
def highlight_entities(line, merged_entities, threshold):
    """Render ``line`` as an HTML fragment with confident entities highlighted.

    The highlighted text is the original slice ``line[start:end]`` rather
    than the cleaned ``word``, so no character of the document is dropped or
    altered. All document text is HTML-escaped because the result is rendered
    with ``unsafe_allow_html=True`` and comes from an uploaded file.

    Args:
        line: The text line the entities were detected in.
        merged_entities: Entity dicts with ``score``, ``start``, ``end`` and
            ``entity`` keys; ``start``/``end`` are used as character offsets
            into ``line`` and are expected in ascending order.
        threshold: Entities with ``score`` below this value are not marked.

    Returns:
        The escaped line, with each accepted entity wrapped in a
        ``<span class="entity marked">`` that carries a tooltip span holding
        the entity's display name.
    """
    parts = []
    last_end = 0
    for ent in merged_entities:
        if ent["score"] < threshold:
            continue
        start, end = ent["start"], ent["end"]
        label = ent["entity"].split("-")[-1]
        # Unknown labels fall back to the raw code and a neutral grey.
        label_desc = entity_labels.get(label, label)
        color = ENTITY_COLORS.get(label, "#cccccc")
        highlight_style = f"background-color:{color}; font-weight:600;"

        parts.append(html.escape(line[last_end:start]))
        parts.append(
            f'<span class="entity marked" style="{highlight_style}">'
            f"{html.escape(line[start:end])}"
            f'<span class="tooltip">{html.escape(label_desc)}</span></span>'
        )
        last_end = end
    parts.append(html.escape(line[last_end:]))
    return "".join(parts)
# UI
st.markdown("#### Juristische Named Entity Recognition (NER)")
uploaded_file = st.file_uploader("Bitte laden Sie eine .txt-Datei hoch:", type="txt")
threshold = st.slider("Schwellenwert für das Modellvertrauen:", 0.0, 1.0, 0.8, 0.01)
st.markdown("---")

if uploaded_file:
    raw_bytes = uploaded_file.read()
    # The upload carries no encoding information, so it is detected from the bytes.
    encoding = detect(raw_bytes)["encoding"]
    if encoding is None:
        st.error("Zeichenkodierung konnte nicht erkannt werden.")
    else:
        # Detection is a guess: replace undecodable bytes rather than crash.
        text = raw_bytes.decode(encoding, errors="replace")
        with st.spinner("Modell wird auf jede Zeile angewendet..."):
            merged_all_lines = get_ner_merged_lines(text)

        # Render each non-empty line as its own block.
        for line, merged in merged_all_lines:
            if not line.strip():
                continue
            html_line = highlight_entities(line, merged, threshold)
            st.markdown(
                f'<div style="margin-bottom:0.8rem; line-height:1.7;">{html_line}</div>',
                unsafe_allow_html=True,
            )