Harshil Darji commited on
Commit
56fe3c0
·
1 Parent(s): 347d235

Add app file

Browse files
Files changed (2) hide show
  1. app.py +308 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import warnings
3
+
4
+ import matplotlib.colors as mcolors
5
+ import matplotlib.pyplot as plt
6
+ import streamlit as st
7
+ from charset_normalizer import detect
8
+ from transformers import (
9
+ AutoModelForTokenClassification,
10
+ AutoTokenizer,
11
+ logging,
12
+ pipeline,
13
+ )
14
+
15
+ warnings.simplefilter(action="ignore", category=Warning)
16
+ logging.set_verbosity(logging.ERROR)
17
+
18
+ st.set_page_config(page_title="Legal NER", page_icon="⚖️", layout="wide")
19
+
20
+ st.markdown(
21
+ """
22
+ <style>
23
+ body {
24
+ font-family: 'Poppins', sans-serif;
25
+ background-color: #f4f4f8;
26
+ }
27
+ .header {
28
+ background-color: rgba(220, 219, 219, 0.25);
29
+ color: #000;
30
+ padding: 5px 0;
31
+ text-align: center;
32
+ border-radius: 7px;
33
+ margin-bottom: 13px;
34
+ border-bottom: 2px solid #333;
35
+ }
36
+ .container {
37
+ background-color: #fff;
38
+ padding: 30px;
39
+ border-radius: 10px;
40
+ box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
41
+ width: 100%;
42
+ max-width: 1000px;
43
+ margin: 0 auto;
44
+ position: absolute;
45
+ top: 50%;
46
+ left: 50%;
47
+ transform: translate(-50%, -50%);
48
+ }
49
+ .btn-primary {
50
+ background-color: #5477d1;
51
+ border: none;
52
+ transition: background-color 0.3s, transform 0.2s;
53
+ border-radius: 25px;
54
+ box-shadow: 0 1px 3px rgba(0, 0, 0, 0.08);
55
+ }
56
+ .btn-primary:hover {
57
+ background-color: #4c6cbe;
58
+ transform: translateY(-1px);
59
+ }
60
+ h2 {
61
+ font-weight: 600;
62
+ font-size: 24px;
63
+ margin-bottom: 20px;
64
+ }
65
+ label {
66
+ font-weight: 500;
67
+ }
68
+ .tip {
69
+ background-color: rgba(180, 47, 109, 0.25);
70
+ padding: 7px;
71
+ border-radius: 7px;
72
+ display: inline-block;
73
+ margin-top: 15px;
74
+ margin-bottom: 15px;
75
+ }
76
+ .sec {
77
+ background-color: rgba(220, 219, 219, 0.10);
78
+ padding: 7px;
79
+ border-radius: 5px;
80
+ display: inline-block;
81
+ margin-top: 15px;
82
+ margin-bottom: 15px;
83
+ }
84
+ .tooltip {
85
+ position: relative;
86
+ display: inline-block;
87
+ cursor: pointer;
88
+ }
89
+ .tooltip .tooltiptext {
90
+ visibility: hidden;
91
+ width: 120px;
92
+ background-color: #6c757d;
93
+ color: #fff;
94
+ text-align: center;
95
+ border-radius: 3px;
96
+ padding: 3px;
97
+ position: absolute;
98
+ z-index: 1;
99
+ bottom: 125%;
100
+ left: 50%;
101
+ margin-left: -60px;
102
+ opacity: 0;
103
+ transition: opacity 0.3s;
104
+ }
105
+ .tooltip:hover .tooltiptext {
106
+ visibility: visible;
107
+ opacity: 1;
108
+ }
109
+ .anonymized {
110
+ background-color: #ffcccb;
111
+ color: #000;
112
+ font-weight: bold;
113
+ border-radius: 3px;
114
+ padding: 2px 4px;
115
+ }
116
+ </style>
117
+ """,
118
+ unsafe_allow_html=True,
119
+ )
120
+
121
+ # Initialization for German Legal NER
122
+ tkn = os.getenv("tkn")
123
+ tokenizer = AutoTokenizer.from_pretrained("harshildarji/JuraBERT", use_auth_token=tkn)
124
+ model = AutoModelForTokenClassification.from_pretrained(
125
+ "harshildarji/JuraBERT", use_auth_token=tkn
126
+ )
127
+ ner = pipeline("ner", model=model, tokenizer=tokenizer)
128
+
129
+ # Define class labels for the model
130
+ classes = {
131
+ "AN": "Lawyer",
132
+ "EUN": "European legal norm",
133
+ "GRT": "Court",
134
+ "GS": "Law",
135
+ "INN": "Institution",
136
+ "LD": "Country",
137
+ "LDS": "Landscape",
138
+ "LIT": "Legal literature",
139
+ "MRK": "Brand",
140
+ "ORG": "Organization",
141
+ "PER": "Person",
142
+ "RR": "Judge",
143
+ "RS": "Court decision",
144
+ "ST": "City",
145
+ "STR": "Street",
146
+ "UN": "Company",
147
+ "VO": "Ordinance",
148
+ "VS": "Regulation",
149
+ "VT": "Contract",
150
+ }
151
+ ner_labels = list(classes.keys())
152
+
153
+
154
+ # Function to generate a list of colors for visualization
155
+ def generate_colors(num_colors):
156
+ cm = plt.get_cmap("tab20")
157
+ colors = [mcolors.rgb2hex(cm(1.0 * i / num_colors)) for i in range(num_colors)]
158
+ return colors
159
+
160
+
161
+ # Function to color substrings based on NER results
162
+ def color_substrings(input_string, model_output):
163
+ colors = generate_colors(len(ner_labels))
164
+ label_to_color = {
165
+ label: colors[i % len(colors)] for i, label in enumerate(ner_labels)
166
+ }
167
+
168
+ last_end = 0
169
+ html_output = ""
170
+
171
+ for entity in sorted(model_output, key=lambda x: x["start"]):
172
+ start, end, label = entity["start"], entity["end"], entity["label"]
173
+ html_output += input_string[last_end:start]
174
+ tooltip = classes.get(label, "")
175
+ html_output += f'<span class="tooltip" style="color: {label_to_color.get(label)}; font-weight: bold;">{input_string[start:end]}<span class="tooltiptext">{tooltip}</span></span>'
176
+ last_end = end
177
+
178
+ html_output += input_string[last_end:]
179
+
180
+ return html_output
181
+
182
+
183
+ # Function to anonymize entities
184
+ def anonymize_text(input_string, model_output):
185
+ anonymized_text = ""
186
+ last_end = 0
187
+
188
+ for entity in sorted(model_output, key=lambda x: x["start"]):
189
+ start, end, label = entity["start"], entity["end"], entity["label"]
190
+ anonymized_text += input_string[last_end:start]
191
+ anonymized_text += (
192
+ f'<span class="anonymized">[{classes.get(label, label)}]</span>'
193
+ )
194
+ last_end = end
195
+
196
+ anonymized_text += input_string[last_end:]
197
+
198
+ return anonymized_text
199
+
200
+
201
+ def merge_entities(ner_results):
202
+ merged_entities = []
203
+ current_entity = None
204
+
205
+ for token in ner_results:
206
+ tag = token["entity"]
207
+ entity_type = tag.split("-")[-1] if "-" in tag else tag
208
+ token_start, token_end = token["start"], token["end"]
209
+ token_word = token["word"].replace("##", "") # Remove subword prefixes
210
+
211
+ # Start a new entity if necessary
212
+ if (
213
+ tag.startswith("B-")
214
+ or current_entity is None
215
+ or current_entity["label"] != entity_type
216
+ ):
217
+ if current_entity:
218
+ merged_entities.append(current_entity)
219
+ current_entity = {
220
+ "start": token_start,
221
+ "end": token_end,
222
+ "label": entity_type,
223
+ "word": token_word,
224
+ }
225
+ elif (
226
+ tag.startswith("I-")
227
+ and current_entity
228
+ and current_entity["label"] == entity_type
229
+ ):
230
+ # Extend the current entity
231
+ current_entity["end"] = token_end
232
+ current_entity["word"] += token_word
233
+ else:
234
+ # Handle misclassifications or gaps in tokens
235
+ if (
236
+ current_entity
237
+ and token_start == current_entity["end"]
238
+ and current_entity["label"] == entity_type
239
+ ):
240
+ current_entity["end"] = token_end
241
+ current_entity["word"] += token_word
242
+ else:
243
+ # Treat it as a new entity if the above conditions aren't met
244
+ if current_entity:
245
+ merged_entities.append(current_entity)
246
+ current_entity = {
247
+ "start": token_start,
248
+ "end": token_end,
249
+ "label": entity_type,
250
+ "word": token_word,
251
+ }
252
+
253
+ # Append the last entity
254
+ if current_entity:
255
+ merged_entities.append(current_entity)
256
+
257
+ return merged_entities
258
+
259
+
260
+ st.title("Legal NER")
261
+ st.markdown("<hr>", unsafe_allow_html=True)
262
+
263
+ uploaded_file = st.file_uploader("Upload a .txt file", type="txt")
264
+
265
+ if uploaded_file is not None:
266
+ try:
267
+ # Read raw content of the file
268
+ raw_content = uploaded_file.read()
269
+
270
+ # Dynamically detect encoding
271
+ detected = detect(raw_content)
272
+ encoding = detected["encoding"]
273
+
274
+ if encoding is None:
275
+ raise ValueError("Unable to detect file encoding.")
276
+
277
+ # Decode file content with the detected encoding
278
+ lines = raw_content.decode(encoding).splitlines()
279
+
280
+ anonymize_mode = st.checkbox("Anonymize")
281
+ st.markdown(
282
+ "<hr style='margin-top: 10px; margin-bottom: 20px;'>",
283
+ unsafe_allow_html=True,
284
+ )
285
+
286
+ for line_number, line in enumerate(lines, start=1):
287
+ if line.strip():
288
+ results = ner(line)
289
+ merged_results = merge_entities(results)
290
+
291
+ if anonymize_mode:
292
+ anonymized_text = anonymize_text(line, merged_results)
293
+ st.markdown(f"{anonymized_text}", unsafe_allow_html=True)
294
+ else:
295
+ colored_html = color_substrings(line, merged_results)
296
+ st.markdown(f"{colored_html}", unsafe_allow_html=True)
297
+
298
+ else:
299
+ st.markdown("<br>", unsafe_allow_html=True)
300
+
301
+ if not anonymize_mode:
302
+ st.markdown(
303
+ '<div class="tip"><strong>Tip:</strong> Hover over the colored words to see its class.</div>',
304
+ unsafe_allow_html=True,
305
+ )
306
+
307
+ except Exception as e:
308
+ st.error(f"An error occurred while processing the file: {e}")
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ transformers
2
+ torch
3
+ matplotlib