oliviercaron committed on
Commit
c949917
·
verified ·
1 Parent(s): 5ae51f8

Upload app2.py

Browse files
Files changed (1) hide show
  1. app2.py +343 -0
app2.py ADDED
@@ -0,0 +1,343 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import csv
3
+ import streamlit as st
4
+ import polars as pl
5
+ from io import BytesIO, StringIO
6
+ from gliner import GLiNER
7
+ from gliner_file import run_ner
8
+ import time
9
+ import torch
10
+ import platform
11
+ from typing import List
12
+ from streamlit_tags import st_tags # Importing the st_tags component
13
+
14
# Streamlit page configuration — must be the first st.* call in the script.
st.set_page_config(
    page_title="GLiNER",
    page_icon="🔥",
    layout="wide",
    initial_sidebar_state="expanded",
)
21
+
22
# Function to load data from the uploaded file
@st.cache_data
def load_data(file):
    """
    Load an uploaded CSV or Excel file, dispatching on the file extension.

    Returns a polars DataFrame on success, or None on failure (the error
    is reported in the Streamlit UI rather than raised to the caller).
    """
    with st.spinner("Loading data, please wait..."):
        try:
            extension = os.path.splitext(file.name)[1].lower()
            if extension in (".xls", ".xlsx"):
                return load_excel(file)
            if extension == ".csv":
                return load_csv(file)
            raise ValueError("Unsupported file format. Please upload a CSV or Excel file.")
        except Exception as e:
            st.error("Error loading data:")
            st.error(str(e))
            return None
41
+
42
def load_excel(file):
    """
    Read an uploaded Excel file into a polars DataFrame.

    The upload is buffered through BytesIO so polars can read it in one
    pass. Raises ValueError with a descriptive message on failure.
    """
    try:
        buffer = BytesIO(file.read())  # buffer the whole upload for fast reading
        return pl.read_excel(buffer, read_options={"ignore_errors": True})
    except Exception as e:
        raise ValueError(f"Error reading the Excel file: {str(e)}")
55
+
56
def load_csv(file):
    """
    Read an uploaded CSV file into a polars DataFrame.

    The delimiter is detected with csv.Sniffer on a sample of the decoded
    text; if sniffing fails, a list of common delimiters is tried in order.
    Quoted fields are honoured so delimiters inside quotes do not split
    columns.

    Raises ValueError when the file cannot be decoded (neither UTF-8 nor
    Latin-1) or cannot be parsed with any candidate delimiter.
    """
    try:
        file.seek(0)  # Reset file pointer to ensure reading from the beginning
        raw_data = file.read()

        # Try decoding as UTF-8, else as Latin-1
        try:
            file_content = raw_data.decode('utf-8')
        except UnicodeDecodeError:
            try:
                file_content = raw_data.decode('latin1')
            except UnicodeDecodeError:
                raise ValueError("Unable to decode the file. Ensure it is encoded in UTF-8 or Latin-1.")

        # List of common delimiters
        delimiters = [",", ";", "|", "\t", " "]

        # BUG FIX: with ignore_errors=True and truncate_ragged_lines=True,
        # pl.read_csv almost never raises, so blindly trying "," first would
        # silently accept a wrong delimiter and return a mangled frame.
        # Sniff the delimiter from a sample first; keep the try-loop only as
        # a fallback when sniffing is inconclusive.
        try:
            dialect = csv.Sniffer().sniff(file_content[:65536], delimiters="".join(delimiters))
            delimiters = [dialect.delimiter] + [d for d in delimiters if d != dialect.delimiter]
        except csv.Error:
            pass  # sniffing failed; keep the default candidate order

        # Try each delimiter until one works
        for delimiter in delimiters:
            try:
                df = pl.read_csv(
                    StringIO(file_content),
                    separator=delimiter,
                    quote_char='"',  # Handle internal delimiters with quotes
                    try_parse_dates=True,
                    ignore_errors=True,  # Ignore errors for invalid values
                    truncate_ragged_lines=True
                )
                # Return the DataFrame if loading succeeds
                return df
            except Exception:
                continue  # Move to the next delimiter in case of error

        # If no delimiter worked
        raise ValueError("Unable to load the file with common delimiters.")
    except Exception as e:
        raise ValueError(f"Error reading the CSV file: {str(e)}")
97
+
98
# Function to load the GLiNER model
@st.cache_resource
def load_model():
    """
    Load the GLiNER model once and cache it across Streamlit reruns.

    The model is placed on the GPU when CUDA is available, otherwise on
    the CPU, and switched to eval mode. Returns the model, or None when
    loading fails (the error is shown in the UI).
    """
    try:
        use_gpu = torch.cuda.is_available()

        with st.spinner("Loading the GLiNER model... Please wait."):
            target = torch.device("cuda" if use_gpu else "cpu")
            model = GLiNER.from_pretrained("urchade/gliner_multi-v2.1").to(target)
            model.eval()  # inference only

        if use_gpu:
            st.success(f"GPU detected: {torch.cuda.get_device_name(0)}. Model loaded on GPU.")
        else:
            st.warning(f"No GPU detected. Using CPU: {platform.processor()}")

        return model
    except Exception as e:
        st.error("Error loading the model:")
        st.error(str(e))
        return None
126
+
127
# Function to perform NER and update the user interface
def perform_ner(filtered_df, selected_column, labels_list, threshold):
    """
    Execute named entity recognition (NER) on one DataFrame column.

    Parameters
    ----------
    filtered_df : pl.DataFrame
        The (possibly filtered) data to annotate.
    selected_column : str
        Name of the column whose text is analyzed.
    labels_list : list[str]
        Entity labels to detect; one output column is added per label.
    threshold : float
        Confidence threshold forwarded to run_ner.

    Returns
    -------
    pl.DataFrame
        The input frame with one extra column per label holding the
        comma-joined entity texts. If the user stops mid-run, only the
        rows processed so far are returned. On error, the frame is
        returned unmodified.
    """
    try:
        texts_to_analyze = filtered_df[selected_column].to_list()
        total_rows = len(texts_to_analyze)
        ner_results_list = []

        # Initialize progress bar and text
        progress_bar = st.progress(0)
        progress_text = st.empty()
        start_time = time.time()

        # Process one row at a time so the progress bar stays responsive
        # and the user can interrupt between rows.
        for index, text in enumerate(texts_to_analyze, 1):
            if st.session_state.stop_processing:
                progress_text.text("Processing stopped by user.")
                break

            ner_results = run_ner(
                st.session_state.gliner_model,
                [text],
                labels_list,
                threshold=threshold
            )
            ner_results_list.append(ner_results)

            # Update progress bar and text after each row
            progress = index / total_rows
            elapsed_time = time.time() - start_time
            progress_bar.progress(progress)
            progress_text.text(f"Progress: {index}/{total_rows} - {progress * 100:.0f}% (Elapsed time: {elapsed_time:.2f}s)")

        # BUG FIX: when the run is interrupted, ner_results_list is shorter
        # than the frame; attaching full-height columns would raise a length
        # mismatch (previously swallowed by the except, losing all results).
        # Keep only the rows that were actually processed.
        if len(ner_results_list) < total_rows:
            filtered_df = filtered_df.head(len(ner_results_list))

        # Add one NER result column per requested label.
        for label in labels_list:
            extracted_entities = []
            for entities in ner_results_list:
                texts = [entity["text"] for entity in entities[0] if entity["label"] == label]
                extracted_entities.append(", ".join(texts) if texts else "")
            filtered_df = filtered_df.with_columns(pl.Series(name=label, values=extracted_entities))

        end_time = time.time()
        st.success(f"Processing completed in {end_time - start_time:.2f} seconds.")

        return filtered_df
    except Exception as e:
        st.error(f"Error during NER processing: {str(e)}")
        return filtered_df
178
+
179
# Main function to run the Streamlit application
def main():
    """
    Streamlit entry point: upload a file, pick a text column, optionally
    filter and paginate it, collect NER labels, then run GLiNER and offer
    the annotated results for download.
    """
    st.title("Use NER with GliNER on your data file")
    st.markdown("Prototype v0.1")

    # User instructions
    st.write("""
    This application performs named entity recognition (NER) on your text data using GLiNER.

    **Instructions:**
    1. Upload a CSV or Excel file.
    2. Select the column containing the text to analyze.
    3. Filter the data if necessary.
    4. Enter the NER labels you wish to detect.
    5. Click "Start NER" to begin processing.
    """)

    # Initializing session state variables
    if "stop_processing" not in st.session_state:
        st.session_state.stop_processing = False
    if "threshold" not in st.session_state:
        st.session_state.threshold = 0.4
    if "labels_list" not in st.session_state:
        st.session_state.labels_list = []

    # Load the model (cached across reruns)
    st.session_state.gliner_model = load_model()
    if st.session_state.gliner_model is None:
        return

    # File upload
    uploaded_file = st.sidebar.file_uploader("Choose a file (CSV or Excel)")
    if uploaded_file is None:
        st.warning("Please upload a file to continue.")
        return

    # Loading data
    df = load_data(uploaded_file)
    if df is None:
        return

    # Column selection
    selected_column = st.selectbox("Select the column containing the text:", df.columns)

    # Data filtering (case-insensitive substring match)
    filter_text = st.text_input("Filter the column by text", "")
    if filter_text:
        filtered_df = df.filter(pl.col(selected_column).str.contains(f"(?i).*{filter_text}.*"))
    else:
        filtered_df = df

    st.write("Filtered data preview:")

    # Pagination setup
    rows_per_page = 100
    total_rows = len(filtered_df)
    total_pages = (total_rows - 1) // rows_per_page + 1

    if "current_page" not in st.session_state:
        st.session_state.current_page = 1

    def update_page(new_page):
        # Setter only; callers are responsible for keeping the page in range.
        st.session_state.current_page = new_page

    # Pagination buttons
    col1, col2, col3, col4, col5 = st.columns(5)
    with col1:
        first = st.button("⏮️ First")
    with col2:
        previous = st.button("⬅️ Previous")
    with col3:
        pass  # page indicator is rendered after the buttons update the page
    with col4:
        next_page = st.button("Next ➡️")  # renamed: `next` shadowed the builtin
    with col5:
        last = st.button("Last ⏭️")

    # Button clicks management
    if first:
        update_page(1)
    elif previous:
        if st.session_state.current_page > 1:
            update_page(st.session_state.current_page - 1)
    elif next_page:
        if st.session_state.current_page < total_pages:
            update_page(st.session_state.current_page + 1)
    elif last:
        update_page(total_pages)

    # Display the page number after any update so it is never stale
    with col3:
        st.markdown(f"Page **{st.session_state.current_page}** of **{total_pages}**")

    # Calculate indices for pagination
    start_idx = (st.session_state.current_page - 1) * rows_per_page
    end_idx = min(start_idx + rows_per_page, total_rows)

    if not filtered_df.is_empty():
        current_page_data = filtered_df.slice(start_idx, end_idx - start_idx)
        st.write(f"Displaying {start_idx + 1} to {end_idx} of {total_rows} rows")
        st.dataframe(current_page_data.to_pandas(), use_container_width=True)
    else:
        st.warning("The filtered DataFrame is empty. Please check your filters.")

    # BUG FIX: st_tags was imported and labels_list session-initialized, but
    # no widget ever collected labels, so labels_list stayed empty and NER
    # could never start. Render the tag input here.
    st.session_state.labels_list = st_tags(
        label="Enter the NER labels to detect:",
        text="Press enter to add a label",
        value=st.session_state.labels_list,
        key="ner_labels",
    )

    # Confidence threshold slider. The "threshold" key is pre-seeded in
    # session state, so no explicit default is passed (avoids Streamlit's
    # duplicated-default warning).
    st.slider("Set confidence threshold", 0.0, 1.0, step=0.01, key="threshold")

    # Buttons to start and stop NER
    col1, col2 = st.columns(2)
    with col1:
        start_button = st.button("Start NER")
    with col2:
        stop_button = st.button("Stop")

    if start_button:
        st.session_state.stop_processing = False

        if not st.session_state.labels_list:
            st.warning("Please enter labels for NER.")
        else:
            # Run NER
            updated_df = perform_ner(filtered_df, selected_column, st.session_state.labels_list, st.session_state.threshold)
            st.write("**NER Results:**")
            st.dataframe(updated_df.to_pandas(), use_container_width=True)

            def to_excel(df):
                # Serialize the polars frame to an in-memory .xlsx payload.
                output = BytesIO()
                df.write_excel(output)
                return output.getvalue()

            def to_csv(df):
                # Serialize the polars frame to UTF-8 CSV bytes.
                return df.write_csv().encode('utf-8')

            # Download buttons for results
            download_col1, download_col2 = st.columns(2)
            with download_col1:
                st.download_button(
                    label="📥 Download as Excel",
                    data=to_excel(updated_df),
                    file_name="ner_results.xlsx",
                    mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
                )
            with download_col2:
                st.download_button(
                    label="📥 Download as CSV",
                    data=to_csv(updated_df),
                    file_name="ner_results.csv",
                    mime="text/csv",
                )

    if stop_button:
        st.session_state.stop_processing = True
        st.warning("Processing stopped by user.")
341
+
342
# Run the application only when executed as a script.
if __name__ == "__main__":
    main()