# Juristische-NER / app.py
# Author: Harshil Darji
# Commit: update app
# 4a3aeed
import re
import string
import matplotlib.cm as cm
import streamlit as st
from charset_normalizer import detect
from transformers import (
AutoModelForTokenClassification,
AutoTokenizer,
logging,
pipeline,
)
# Streamlit page setup
st.set_page_config(page_title="Juristische NER", page_icon="⚖️", layout="wide")
logging.set_verbosity(logging.ERROR)
st.markdown(
"""
<style>
.block-container {
padding-top: 1rem;
padding-bottom: 5rem;
padding-left: 3rem;
padding-right: 3rem;
}
header, footer {visibility: hidden;}
.entity {
position: relative;
display: inline-block;
background-color: transparent;
font-weight: normal;
cursor: help;
}
.entity .tooltip {
visibility: hidden;
background-color: #333;
color: #fff;
text-align: center;
border-radius: 4px;
padding: 2px 6px;
position: absolute;
z-index: 1;
bottom: 125%;
left: 50%;
transform: translateX(-50%);
white-space: nowrap;
opacity: 0;
transition: opacity 0.05s;
font-size: 11px;
}
.entity:hover .tooltip {
visibility: visible;
opacity: 1;
}
.entity.marked {
background-color: rgba(255, 230, 0, 0.4);
line-height: 1.3;
padding: 0 1px;
border-radius: 0px;
}
</style>
""",
unsafe_allow_html=True,
)
# Entity label mapping: model tag -> human-readable German description,
# shown in the hover tooltip of each highlighted entity. The tags follow
# the German legal NER label set (court, statute, person, etc.).
entity_labels = {
    "AN": "Rechtsbeistand",
    "EUN": "EUNorm",
    "GRT": "Gericht",
    "GS": "Norm",
    "INN": "Institution",
    "LD": "Land",
    "LDS": "Bezirk",
    "LIT": "Schrifttum",
    "MRK": "Marke",
    "ORG": "Organisation",
    "PER": "Person",
    "RR": "RichterIn",
    "RS": "Entscheidung",
    "ST": "Stadt",
    "STR": "Strasse",
    "UN": "Unternehmen",
    "VO": "Verordnung",
    "VS": "Richtlinie",
    "VT": "Vertrag",
    "RED": "Schwärzung",
}
# Color generator
def generate_fixed_colors(keys, alpha=0.25):
    """Assign each key a distinct CSS rgba() color from the tab20 colormap.

    Args:
        keys: Sequence of label keys to color.
        alpha: Opacity applied to every generated color.

    Returns:
        dict mapping each key to an "rgba(r, g, b, a)" string.
    """
    try:
        # matplotlib < 3.9 API.
        cmap = cm.get_cmap("tab20", len(keys))
    except AttributeError:
        # cm.get_cmap(name, lut) was removed in matplotlib 3.9; use the
        # colormap registry and resample to the number of keys instead.
        import matplotlib

        cmap = matplotlib.colormaps["tab20"].resampled(len(keys))
    rgba_colors = {}
    for i, key in enumerate(keys):
        r, g, b, _ = cmap(i)
        rgba = f"rgba({int(r*255)}, {int(g*255)}, {int(b*255)}, {alpha})"
        rgba_colors[key] = rgba
    return rgba_colors


# Fixed per-label colors, computed once at import time.
ENTITY_COLORS = generate_fixed_colors(list(entity_labels.keys()))
# Caching model
@st.cache_resource
def load_ner_pipeline():
    """Build the JuraNER token-classification pipeline once per session."""
    model_id = "harshildarji/JuraNER"
    model = AutoModelForTokenClassification.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    return pipeline("ner", model=model, tokenizer=tokenizer)
# Caching NER + merge per line
@st.cache_data(show_spinner=False)
def get_ner_merged_lines(text):
    """Run NER on every line of *text*, returning [(line, merged_entities)].

    Blank lines are kept as ("", []) so callers can preserve line structure.
    Results are cached by Streamlit keyed on the input text.
    """
    pipe = load_ner_pipeline()
    output = []
    for raw_line in text.splitlines():
        if raw_line.strip():
            output.append((raw_line, merge_entities(pipe(raw_line))))
        else:
            output.append(("", []))
    return output
# Entity merging
def merge_entities(entities):
    """Merge consecutive word-piece predictions into whole-entity spans.

    Tokens whose ``index`` values are consecutive are glued together
    ("##" pieces join without a space), their scores are averaged, and
    each merged surface form is normalised (spacing around ".", "," and
    "/") before short or punctuation-only spans are discarded.

    Args:
        entities: Raw token dicts from the HF "ner" pipeline.

    Returns:
        List of merged entity dicts with an averaged "score".
    """
    if not entities:
        return []

    def _finalize(span):
        # Swap the running sum/count bookkeeping for the mean score.
        span["score"] = span["score_sum"] / span["count"]
        del span["score_sum"]
        del span["count"]

    spans = []
    for token in sorted(entities, key=lambda item: item["index"]):
        if spans and token["index"] == spans[-1]["index"] + 1:
            current = spans[-1]
            piece = token["word"]
            # "##" marks a word-piece continuation: append without a space.
            current["word"] += piece[2:] if piece.startswith("##") else " " + piece
            current["end"] = token["end"]
            current["index"] = token["index"]
            current["score_sum"] += token["score"]
            current["count"] += 1
        else:
            if spans:
                _finalize(spans[-1])
            fresh = token.copy()
            fresh["score_sum"] = token["score"]
            fresh["count"] = 1
            spans.append(fresh)
    _finalize(spans[-1])

    kept = []
    for span in spans:
        surface = span["word"].strip()
        surface = re.sub(r"\s*\.\s*", ".", surface)
        surface = re.sub(r"\s*,\s*", ", ", surface)
        surface = re.sub(r"\s*/\s*", "/", surface)
        surface = surface.strip(string.whitespace + string.punctuation)
        # Drop spans that are a single character or contain no word character.
        if len(surface) > 1 and re.search(r"\w", surface):
            cleaned = span.copy()
            cleaned["word"] = surface
            kept.append(cleaned)
    return kept
# Highlighting
def highlight_entities(line, merged_entities, threshold):
    """Return *line* as HTML with confident entities wrapped in tooltip spans.

    Args:
        line: Original text line from the uploaded file.
        merged_entities: Output of merge_entities() for this line.
        threshold: Minimum model confidence for an entity to be shown.

    Returns:
        An HTML string safe to render with unsafe_allow_html=True.
    """
    # Local import: bind only the helper to avoid clashing with local names.
    from html import escape

    rendered = ""
    last_end = 0
    for ent in merged_entities:
        if ent["score"] < threshold:
            continue
        start, end = ent["start"], ent["end"]
        if start < last_end:
            # Guard against overlapping spans, which would otherwise slice
            # the line with a negative range and garble the output.
            continue
        label = ent["entity"].split("-")[-1]
        label_desc = entity_labels.get(label, label)
        color = ENTITY_COLORS.get(label, "#cccccc")
        # Escape all user-derived text: the result is rendered with
        # unsafe_allow_html=True, so a raw "<" or "&" in the uploaded file
        # would be interpreted as markup (and could inject HTML/JS).
        rendered += escape(line[last_end:start])
        highlight_style = f"background-color:{color}; font-weight:600;"
        rendered += (
            f'<span class="entity marked" style="{highlight_style}">'
            f'{escape(ent["word"])}<span class="tooltip">{escape(label_desc)}</span></span>'
        )
        last_end = end
    rendered += escape(line[last_end:])
    return rendered
# UI
st.markdown("#### Juristische Named Entity Recognition (NER)")

# Input widgets: a .txt upload plus a confidence cutoff for displayed entities.
uploaded_file = st.file_uploader("Bitte laden Sie eine .txt-Datei hoch:", type="txt")
threshold = st.slider("Schwellenwert für das Modellvertrauen:", 0.0, 1.0, 0.8, 0.01)
st.markdown("---")

if uploaded_file:
    raw_bytes = uploaded_file.read()
    # Uploaded files may use any encoding; sniff it before decoding.
    encoding = detect(raw_bytes)["encoding"]
    if encoding is None:
        st.error("Zeichenkodierung konnte nicht erkannt werden.")
    else:
        text = raw_bytes.decode(encoding)
        with st.spinner("Modell wird auf jede Zeile angewendet..."):
            # Cached: re-running with the same file skips the model pass.
            merged_all_lines = get_ner_merged_lines(text)
        for line, merged in merged_all_lines:
            if not line.strip():
                # Blank lines carry no entities; skip rendering them.
                continue
            html_line = highlight_entities(line, merged, threshold)
            st.markdown(
                f'<div style="margin-bottom:0.8rem; line-height:1.7;">{html_line}</div>',
                unsafe_allow_html=True,
            )