Upload app2.py
app2.py
ADDED
@@ -0,0 +1,343 @@
import os
import platform
import time
from io import BytesIO, StringIO

import polars as pl
import streamlit as st
import torch
from gliner import GLiNER
from streamlit_tags import st_tags  # Tag-input component used for the NER labels

from gliner_file import run_ner

# Streamlit page configuration
st.set_page_config(
    page_title="GLiNER",
    page_icon="🔥",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Function to load data from the uploaded file
@st.cache_data
def load_data(file):
    """
    Loads an uploaded CSV or Excel file with resilient detection of delimiters and types.
    """
    with st.spinner("Loading data, please wait..."):
        try:
            _, file_ext = os.path.splitext(file.name)
            if file_ext.lower() in [".xls", ".xlsx"]:
                return load_excel(file)
            elif file_ext.lower() == ".csv":
                return load_csv(file)
            else:
                raise ValueError("Unsupported file format. Please upload a CSV or Excel file.")
        except Exception as e:
            st.error("Error loading data:")
            st.error(str(e))
            return None

def load_excel(file):
    """
    Loads an Excel file using `BytesIO` and `polars` for reduced latency.
    """
    try:
        # Load the file into BytesIO for faster reading
        file_bytes = BytesIO(file.read())

        # Load the Excel file using `polars`
        df = pl.read_excel(file_bytes, read_options={"ignore_errors": True})
        return df
    except Exception as e:
        raise ValueError(f"Error reading the Excel file: {str(e)}")

def load_csv(file):
    """
    Loads a CSV file by detecting the delimiter and using the quote character to handle internal delimiters.
    """
    try:
        file.seek(0)  # Reset file pointer to ensure reading from the beginning
        raw_data = file.read()

        # Try decoding as UTF-8, else as Latin-1
        try:
            file_content = raw_data.decode('utf-8')
        except UnicodeDecodeError:
            try:
                file_content = raw_data.decode('latin1')
            except UnicodeDecodeError:
                raise ValueError("Unable to decode the file. Ensure it is encoded in UTF-8 or Latin-1.")

        # List of common delimiters
        delimiters = [",", ";", "|", "\t", " "]

        # Try each delimiter until one parses. Note that with ignore_errors=True a wrong
        # delimiter can still "succeed" (typically yielding a single column), so the
        # order of the candidates matters.
        for delimiter in delimiters:
            try:
                # Read CSV with the current delimiter and handle quoted fields
                df = pl.read_csv(
                    StringIO(file_content),
                    separator=delimiter,
                    quote_char='"',  # Handle internal delimiters with quotes
                    try_parse_dates=True,
                    ignore_errors=True,  # Ignore invalid values
                    truncate_ragged_lines=True
                )
                # Return the DataFrame if loading succeeds
                return df
            except Exception:
                continue  # Move to the next delimiter in case of error

        # If no delimiter worked
        raise ValueError("Unable to load the file with common delimiters.")
    except Exception as e:
        raise ValueError(f"Error reading the CSV file: {str(e)}")
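
Review note: instead of trying delimiters in a fixed order (where `ignore_errors=True` can let the wrong one "succeed" as a single-column frame), the standard library's `csv.Sniffer` can often detect the delimiter in one pass. A minimal sketch of the idea; the sample size and fallback are assumptions, not part of this app:

```python
import csv
from io import StringIO
import polars as pl

def detect_delimiter(file_content: str, default: str = ",") -> str:
    """Guess the delimiter from a sample of the file, falling back to a default."""
    sample = file_content[:4096]  # a few KB is usually enough for sniffing
    try:
        return csv.Sniffer().sniff(sample, delimiters=",;|\t ").delimiter
    except csv.Error:
        return default  # Sniffer could not decide; fall back

# df = pl.read_csv(StringIO(file_content), separator=detect_delimiter(file_content))
```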

# Function to load the GLiNER model
@st.cache_resource
def load_model():
    """
    Loads the GLiNER model into memory to avoid multiple reloads.
    """
    try:
        gpu_available = torch.cuda.is_available()

        with st.spinner("Loading the GLiNER model... Please wait."):
            device = torch.device("cuda" if gpu_available else "cpu")
            model = GLiNER.from_pretrained(
                "urchade/gliner_multi-v2.1"
            ).to(device)
            model.eval()

        if gpu_available:
            device_name = torch.cuda.get_device_name(0)
            st.success(f"GPU detected: {device_name}. Model loaded on GPU.")
        else:
            cpu_name = platform.processor()
            st.warning(f"No GPU detected. Using CPU: {cpu_name}")

        return model
    except Exception as e:
        st.error("Error loading the model:")
        st.error(str(e))
        return None
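
Review note: `run_ner` comes from the local `gliner_file` module, which is not part of this commit. For reference, the `gliner` package itself exposes `predict_entities`, which is presumably what that helper wraps; a standalone sketch:

```python
from gliner import GLiNER

model = GLiNER.from_pretrained("urchade/gliner_multi-v2.1")
labels = ["person", "organization", "location"]

# Returns a list of dicts with "text", "label", "score", "start", "end" keys
entities = model.predict_entities(
    "Marie Curie worked at the University of Paris.",
    labels,
    threshold=0.4,
)
for ent in entities:
    print(ent["text"], "->", ent["label"])
```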

# Function to perform NER and update the user interface
def perform_ner(filtered_df, selected_column, labels_list, threshold):
    """
    Executes named entity recognition (NER) on the filtered data.
    """
    try:
        texts_to_analyze = filtered_df[selected_column].to_list()
        total_rows = len(texts_to_analyze)
        ner_results_list = []

        # Initialize progress bar and text
        progress_bar = st.progress(0)
        progress_text = st.empty()
        start_time = time.time()

        # Process each row individually to keep progress updates responsive
        for index, text in enumerate(texts_to_analyze, 1):
            if st.session_state.stop_processing:
                progress_text.text("Processing stopped by user.")
                break

            ner_results = run_ner(
                st.session_state.gliner_model,
                [text],
                labels_list,
                threshold=threshold
            )
            ner_results_list.append(ner_results)

            # Update progress bar and text after each row
            progress = index / total_rows
            elapsed_time = time.time() - start_time
            progress_bar.progress(progress)
            progress_text.text(f"Progress: {index}/{total_rows} - {progress * 100:.0f}% (Elapsed time: {elapsed_time:.2f}s)")

        # If processing was stopped early, keep only the rows that were actually
        # analyzed, so the new columns line up with the DataFrame height
        if len(ner_results_list) < total_rows:
            filtered_df = filtered_df.head(len(ner_results_list))

        # Add NER results to the DataFrame: one column per label, entities joined with ", "
        for label in labels_list:
            extracted_entities = []
            for entities in ner_results_list:
                texts = [entity["text"] for entity in entities[0] if entity["label"] == label]
                concatenated_texts = ", ".join(texts) if texts else ""
                extracted_entities.append(concatenated_texts)
            filtered_df = filtered_df.with_columns(pl.Series(name=label, values=extracted_entities))

        end_time = time.time()
        st.success(f"Processing completed in {end_time - start_time:.2f} seconds.")

        return filtered_df
    except Exception as e:
        st.error(f"Error during NER processing: {str(e)}")
        return filtered_df
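
Review note: the column-building loop assumes each element of `ner_results_list` is a one-element batch (hence `entities[0]`), which matches the single-text calls to `run_ner` above. A self-contained illustration of the flattening step with dummy entities:

```python
import polars as pl

# Two rows of fake NER output, one inner batch list per analyzed text
ner_results_list = [
    [[{"text": "Alice", "label": "person"}, {"text": "Paris", "label": "location"}]],
    [[{"text": "Bob", "label": "person"}]],
]
df = pl.DataFrame({"text": ["Alice went to Paris.", "Bob stayed home."]})

for label in ["person", "location"]:
    column = [
        ", ".join(e["text"] for e in batch[0] if e["label"] == label)
        for batch in ner_results_list
    ]
    df = df.with_columns(pl.Series(name=label, values=column))

print(df)  # adds a "person" and a "location" column
```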

# Main function to run the Streamlit application
def main():
    st.title("Use NER with GLiNER on your data file")
    st.markdown("Prototype v0.1")

    # User instructions
    st.write("""
    This application performs named entity recognition (NER) on your text data using GLiNER.

    **Instructions:**
    1. Upload a CSV or Excel file.
    2. Select the column containing the text to analyze.
    3. Filter the data if necessary.
    4. Enter the NER labels you wish to detect.
    5. Click "Start NER" to begin processing.
    """)

    # Initialize session state variables
    if "stop_processing" not in st.session_state:
        st.session_state.stop_processing = False
    if "threshold" not in st.session_state:
        st.session_state.threshold = 0.4
    if "labels_list" not in st.session_state:
        st.session_state.labels_list = []
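
Review note: `threshold` is seeded into `st.session_state` here, and further down the confidence slider uses `key="threshold"`. Streamlit warns when a keyed widget is also given an explicit default while its key already holds a value set through the Session State API, so the slider below relies on the key alone; a minimal sketch of the pattern:

```python
# Seed the state once ...
if "threshold" not in st.session_state:
    st.session_state.threshold = 0.4

# ... then let the key drive the widget: no default argument needed, and
# st.session_state.threshold always mirrors the slider position.
st.slider("Set confidence threshold", 0.0, 1.0, step=0.01, key="threshold")
```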

    # Load the model
    st.session_state.gliner_model = load_model()
    if st.session_state.gliner_model is None:
        return

    # File upload
    uploaded_file = st.sidebar.file_uploader("Choose a file (CSV or Excel)")
    if uploaded_file is None:
        st.warning("Please upload a file to continue.")
        return

    # Load the data
    df = load_data(uploaded_file)
    if df is None:
        return

    # Column selection
    selected_column = st.selectbox("Select the column containing the text:", df.columns)

    # Data filtering
    filter_text = st.text_input("Filter the column by text", "")
    if filter_text:
        filtered_df = df.filter(pl.col(selected_column).str.contains(f"(?i).*{filter_text}.*"))
    else:
        filtered_df = df
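
Review note: `str.contains` treats `filter_text` as a regular expression, so user input containing metacharacters such as `(` or `*` will raise or mismatch. A minimal hardening sketch using the standard-library `re.escape` (the `safe_pattern` name is illustrative):

```python
import re

# Escape regex metacharacters so the user's text is matched literally;
# (?i) keeps the match case-insensitive. The ".*" padding in the original
# is redundant for a contains() check and can be dropped.
safe_pattern = f"(?i){re.escape(filter_text)}"
filtered_df = df.filter(pl.col(selected_column).str.contains(safe_pattern))
```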

    st.write("Filtered data preview:")

    # Rows per page
    rows_per_page = 100

    # Calculate total rows and pages (at least one page, even for an empty DataFrame)
    total_rows = len(filtered_df)
    total_pages = max(1, (total_rows - 1) // rows_per_page + 1)

    # Initialize current page in session_state
    if "current_page" not in st.session_state:
        st.session_state.current_page = 1

    # Clamp the page in case a new filter shrank the result set
    if st.session_state.current_page > total_pages:
        st.session_state.current_page = total_pages

    # Function to update the page
    def update_page(new_page):
        st.session_state.current_page = new_page

    # Pagination buttons
    col1, col2, col3, col4, col5 = st.columns(5)

    with col1:
        first_clicked = st.button("⏮️ First")
    with col2:
        previous_clicked = st.button("⬅️ Previous")
    with col3:
        pass  # Page number is displayed after the clicks are handled
    with col4:
        next_clicked = st.button("Next ➡️")  # renamed so the built-in next() is not shadowed
    with col5:
        last_clicked = st.button("Last ⏭️")

    # Handle button clicks
    if first_clicked:
        update_page(1)
    elif previous_clicked:
        if st.session_state.current_page > 1:
            update_page(st.session_state.current_page - 1)
    elif next_clicked:
        if st.session_state.current_page < total_pages:
            update_page(st.session_state.current_page + 1)
    elif last_clicked:
        update_page(total_pages)

    # Display the page number only after it has been updated
    with col3:
        st.markdown(f"Page **{st.session_state.current_page}** of **{total_pages}**")

    # Calculate indices for pagination
    start_idx = (st.session_state.current_page - 1) * rows_per_page
    end_idx = min(start_idx + rows_per_page, total_rows)
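
Review note, checking the arithmetic: with 250 filtered rows and 100 rows per page, `total_pages = (250 - 1) // 100 + 1 = 3`; on page 3, `start_idx = 200` and `end_idx = min(300, 250) = 250`, so the last page shows rows 201 to 250. For an empty DataFrame the bare formula gives `(0 - 1) // 100 + 1 = 0` under Python's floor division, which is why it is wrapped in `max(1, ...)` above and the empty case is still handled separately below.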

    # Check if the filtered DataFrame is empty
    if not filtered_df.is_empty():
        # Retrieve current page data
        current_page_data = filtered_df.slice(start_idx, end_idx - start_idx)
        st.write(f"Displaying {start_idx + 1} to {end_idx} of {total_rows} rows")
        st.dataframe(current_page_data.to_pandas(), use_container_width=True)
    else:
        st.warning("The filtered DataFrame is empty. Please check your filters.")

    # NER labels input (st_tags was imported but never called; without this widget
    # labels_list stays empty and NER can never start)
    st.session_state.labels_list = st_tags(
        label="Enter the NER labels to detect:",
        text="Press Enter to add a label",
        value=st.session_state.labels_list,
        key="labels_input"
    )

    # Confidence threshold slider; the key alone drives the value, avoiding the
    # Streamlit warning about a keyed widget that also receives a default
    st.slider("Set confidence threshold", 0.0, 1.0, step=0.01, key="threshold")

    # Buttons to start and stop NER
    col1, col2 = st.columns(2)
    with col1:
        start_button = st.button("Start NER")
    with col2:
        stop_button = st.button("Stop")

    if start_button:
        st.session_state.stop_processing = False

        if not st.session_state.labels_list:
            st.warning("Please enter labels for NER.")
        else:
            # Run NER
            updated_df = perform_ner(filtered_df, selected_column, st.session_state.labels_list, st.session_state.threshold)
            st.write("**NER Results:**")
            st.dataframe(updated_df.to_pandas(), use_container_width=True)

            # Convert a DataFrame to Excel bytes for download
            def to_excel(df):
                output = BytesIO()
                df.write_excel(output)
                return output.getvalue()

            # Convert a DataFrame to CSV bytes for download
            def to_csv(df):
                return df.write_csv().encode('utf-8')

            # Download buttons for results
            download_col1, download_col2 = st.columns(2)
            with download_col1:
                st.download_button(
                    label="📥 Download as Excel",
                    data=to_excel(updated_df),
                    file_name="ner_results.xlsx",
                    mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
                )
            with download_col2:
                st.download_button(
                    label="📥 Download as CSV",
                    data=to_csv(updated_df),
                    file_name="ner_results.csv",
                    mime="text/csv",
                )

    if stop_button:
        st.session_state.stop_processing = True
        st.warning("Processing stopped by user.")

if __name__ == "__main__":
    main()
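
Review note: because the results block lives inside `if start_button:`, clicking either download button triggers a rerun in which `start_button` is False, so the results table and the other download button vanish from the page (the download itself still completes). Persisting the result in session state would keep them visible across reruns; a sketch under that assumption (the `ner_df` key is hypothetical):

```python
if start_button and st.session_state.labels_list:
    st.session_state.ner_df = perform_ner(
        filtered_df, selected_column,
        st.session_state.labels_list, st.session_state.threshold,
    )

# Rendered on every rerun, independent of which button was clicked last
if "ner_df" in st.session_state:
    st.dataframe(st.session_state.ner_df.to_pandas(), use_container_width=True)
```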