Sankey
Browse files- app.py +26 -4
- data/indicator_cause_sentence_metadata.tsv +0 -0
- plot.py +99 -15
app.py
CHANGED
@@ -4,9 +4,13 @@ from transformers import AutoTokenizer, AutoModelForTokenClassification
|
|
4 |
from annotated_text import annotated_text
|
5 |
import pandas as pd
|
6 |
import plotly.express as px
|
7 |
-
from plot import indicator_chart, causes_chart,
|
8 |
import os
|
9 |
|
|
|
|
|
|
|
|
|
10 |
# Load the trained model and tokenizer
|
11 |
model_directory = "norygano/causalBERT"
|
12 |
tokenizer = AutoTokenizer.from_pretrained(model_directory, add_prefix_space=True)
|
@@ -30,7 +34,7 @@ st.markdown("[Model](https://huggingface.co/norygano/causalBERT) | [Data](https:
|
|
30 |
st.write("Tags indicators and causes in explicit attributions of causality. GER only (atm)")
|
31 |
|
32 |
# Create tabs
|
33 |
-
tab1, tab2, tab3, tab4 = st.tabs(["Prompt", "Indicators", "Causes", "Scatter"])
|
34 |
|
35 |
# Prompt Tab
|
36 |
with tab1:
|
@@ -99,5 +103,23 @@ with tab3:
|
|
99 |
|
100 |
with tab4:
|
101 |
st.write("## Scatter")
|
102 |
-
fig_scatter =
|
103 |
-
st.plotly_chart(fig_scatter, use_container_width=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
from annotated_text import annotated_text
|
5 |
import pandas as pd
|
6 |
import plotly.express as px
|
7 |
+
from plot import indicator_chart, causes_chart, scatter, sankey
|
8 |
import os
|
9 |
|
10 |
+
# Define initial threshold values at the top of the script
|
11 |
+
default_cause_threshold = 20
|
12 |
+
default_indicator_threshold = 3
|
13 |
+
|
14 |
# Load the trained model and tokenizer
|
15 |
model_directory = "norygano/causalBERT"
|
16 |
tokenizer = AutoTokenizer.from_pretrained(model_directory, add_prefix_space=True)
|
|
|
34 |
st.write("Tags indicators and causes in explicit attributions of causality. GER only (atm)")
|
35 |
|
36 |
# Create tabs
|
37 |
+
tab1, tab2, tab3, tab4, tab5 = st.tabs(["Prompt", "Indicators", "Causes", "Scatter", "Sankey"])
|
38 |
|
39 |
# Prompt Tab
|
40 |
with tab1:
|
|
|
103 |
|
104 |
with tab4:
|
105 |
st.write("## Scatter")
|
106 |
+
fig_scatter = scatter()
|
107 |
+
st.plotly_chart(fig_scatter, use_container_width=True)
|
108 |
+
|
109 |
+
with tab5:
|
110 |
+
st.write("## Sankey")
|
111 |
+
|
112 |
+
# Fixed height for the Sankey chart container
|
113 |
+
with st.container():
|
114 |
+
# Retrieve slider values and generate the diagram
|
115 |
+
cause_threshold = st.session_state.get("cause_threshold", default_cause_threshold)
|
116 |
+
indicator_threshold = st.session_state.get("indicator_threshold", default_indicator_threshold)
|
117 |
+
|
118 |
+
fig_sankey = sankey(cause_threshold=cause_threshold, indicator_threshold=indicator_threshold)
|
119 |
+
st.plotly_chart(fig_sankey, use_container_width=True)
|
120 |
+
|
121 |
+
# Place sliders below the chart container
|
122 |
+
with st.container():
|
123 |
+
st.write("Adjust thresholds for Sankey diagram:")
|
124 |
+
cause_threshold = st.slider("Cause Threshold", min_value=1, max_value=100, value=default_cause_threshold, key="cause_threshold")
|
125 |
+
indicator_threshold = st.slider("Indicator Threshold", min_value=1, max_value=100, value=default_indicator_threshold, key="indicator_threshold")
|
data/indicator_cause_sentence_metadata.tsv
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
plot.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import pandas as pd
|
2 |
import plotly.express as px
|
|
|
3 |
import os
|
4 |
import umap
|
5 |
|
@@ -64,7 +65,7 @@ def indicator_chart(chart_type='overall'):
|
|
64 |
|
65 |
fig.update_layout(
|
66 |
xaxis=dict(showline=True),
|
67 |
-
yaxis=dict(showticklabels=True, title=''),
|
68 |
bargap=0.05,
|
69 |
showlegend=(chart_type == 'individual')
|
70 |
)
|
@@ -107,17 +108,17 @@ def causes_chart():
|
|
107 |
|
108 |
return fig
|
109 |
|
110 |
-
def
|
111 |
data_file = os.path.join('data', 'feature_matrix.tsv')
|
112 |
df = pd.read_csv(data_file, sep='\t')
|
113 |
|
114 |
-
# Exclude sentences without any indicators
|
115 |
indicator_columns = [col for col in df.columns if col.startswith('indicator_')]
|
116 |
cause_columns = [col for col in df.columns if col.startswith('cause_')]
|
117 |
modality_columns = [col for col in df.columns if col.startswith('modality_')]
|
118 |
|
119 |
df_filtered = df[(df[indicator_columns].sum(axis=1) > 0) |
|
120 |
-
|
121 |
|
122 |
# Exclude indicator '!besprechen'
|
123 |
indicator_columns = [col for col in indicator_columns if 'indicator_!besprechen' not in col]
|
@@ -129,30 +130,26 @@ def scatter_plot(include_modality=False):
|
|
129 |
# Further filter to exclude entries without any valid indicators
|
130 |
df_filtered = df_filtered[df_filtered[indicators_to_keep].sum(axis=1) > 0]
|
131 |
|
132 |
-
# Exclude non-feature columns
|
133 |
columns_to_drop = ['subfolder']
|
134 |
if not include_modality:
|
135 |
columns_to_drop += modality_columns # Drop modality columns if not included
|
136 |
|
137 |
features = df_filtered.drop(columns=columns_to_drop)
|
138 |
-
|
139 |
-
# Fill NaN values with 0 for the feature matrix
|
140 |
features_clean = features.fillna(0)
|
141 |
|
142 |
-
#
|
143 |
metadata = df_filtered[['subfolder']].copy()
|
144 |
-
# Remove the 'indicator_' prefix for indicators and ensure only indicators with at least 10 occurrences are included
|
145 |
metadata['indicator'] = df_filtered[indicators_to_keep].apply(lambda row: ', '.join([indicator.replace('indicator_', '') for indicator in indicators_to_keep if row[indicator] > 0]), axis=1)
|
146 |
-
# Collect all non-zero causes as a string (multiple causes per sentence)
|
147 |
metadata['cause'] = df_filtered[cause_columns].apply(lambda row: ', '.join([cause.replace('cause_', '') for cause in cause_columns if row[cause] > 0]), axis=1)
|
148 |
|
149 |
-
#
|
150 |
-
reducer = umap.UMAP(n_components=2, random_state=42, n_neighbors=50, metric='cosine')
|
151 |
reduced_features = reducer.fit_transform(features_clean)
|
152 |
df_reduced = pd.DataFrame(reduced_features, columns=['Component 1', 'Component 2'])
|
153 |
df_reduced = pd.concat([df_reduced, metadata.reset_index(drop=True)], axis=1)
|
154 |
|
155 |
-
# Plotting the scatter plot
|
156 |
hover_data = {'cause'}
|
157 |
if include_modality:
|
158 |
hover_data['Modality'] = True
|
@@ -161,17 +158,104 @@ def scatter_plot(include_modality=False):
|
|
161 |
df_reduced,
|
162 |
x='Component 1',
|
163 |
y='Component 2',
|
164 |
-
color='subfolder',
|
|
|
165 |
hover_data=hover_data,
|
166 |
labels={'Component 1': 'UMAP Dim 1', 'Component 2': 'UMAP Dim 2'},
|
167 |
color_discrete_sequence=px.colors.qualitative.D3
|
168 |
)
|
169 |
|
|
|
|
|
|
|
|
|
|
|
170 |
fig.update_layout(
|
171 |
xaxis=dict(showgrid=False),
|
172 |
yaxis=dict(showgrid=False),
|
173 |
-
showlegend=True
|
|
|
|
|
|
|
|
|
|
|
|
|
174 |
)
|
175 |
|
176 |
return fig
|
177 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import pandas as pd
|
2 |
import plotly.express as px
|
3 |
+
import plotly.graph_objects as go
|
4 |
import os
|
5 |
import umap
|
6 |
|
|
|
65 |
|
66 |
fig.update_layout(
|
67 |
xaxis=dict(showline=True),
|
68 |
+
yaxis=dict(showticklabels=True, title='', tickformat=".0%" if chart_type == 'overall' else None),
|
69 |
bargap=0.05,
|
70 |
showlegend=(chart_type == 'individual')
|
71 |
)
|
|
|
108 |
|
109 |
return fig
|
110 |
|
111 |
+
def scatter(include_modality=False):
|
112 |
data_file = os.path.join('data', 'feature_matrix.tsv')
|
113 |
df = pd.read_csv(data_file, sep='\t')
|
114 |
|
115 |
+
# Exclude sentences without any indicators, causes, or modalities (if included)
|
116 |
indicator_columns = [col for col in df.columns if col.startswith('indicator_')]
|
117 |
cause_columns = [col for col in df.columns if col.startswith('cause_')]
|
118 |
modality_columns = [col for col in df.columns if col.startswith('modality_')]
|
119 |
|
120 |
df_filtered = df[(df[indicator_columns].sum(axis=1) > 0) |
|
121 |
+
(df[cause_columns].sum(axis=1) > 0)]
|
122 |
|
123 |
# Exclude indicator '!besprechen'
|
124 |
indicator_columns = [col for col in indicator_columns if 'indicator_!besprechen' not in col]
|
|
|
130 |
# Further filter to exclude entries without any valid indicators
|
131 |
df_filtered = df_filtered[df_filtered[indicators_to_keep].sum(axis=1) > 0]
|
132 |
|
133 |
+
# Exclude non-feature columns for dimensionality reduction
|
134 |
columns_to_drop = ['subfolder']
|
135 |
if not include_modality:
|
136 |
columns_to_drop += modality_columns # Drop modality columns if not included
|
137 |
|
138 |
features = df_filtered.drop(columns=columns_to_drop)
|
|
|
|
|
139 |
features_clean = features.fillna(0)
|
140 |
|
141 |
+
# Prepare metadata
|
142 |
metadata = df_filtered[['subfolder']].copy()
|
|
|
143 |
metadata['indicator'] = df_filtered[indicators_to_keep].apply(lambda row: ', '.join([indicator.replace('indicator_', '') for indicator in indicators_to_keep if row[indicator] > 0]), axis=1)
|
|
|
144 |
metadata['cause'] = df_filtered[cause_columns].apply(lambda row: ', '.join([cause.replace('cause_', '') for cause in cause_columns if row[cause] > 0]), axis=1)
|
145 |
|
146 |
+
# UMAP dimensionality reduction
|
147 |
+
reducer = umap.UMAP(n_components=2, random_state=42, n_neighbors=50, n_jobs=1, metric='cosine')
|
148 |
reduced_features = reducer.fit_transform(features_clean)
|
149 |
df_reduced = pd.DataFrame(reduced_features, columns=['Component 1', 'Component 2'])
|
150 |
df_reduced = pd.concat([df_reduced, metadata.reset_index(drop=True)], axis=1)
|
151 |
|
152 |
+
# Plotting the scatter plot
|
153 |
hover_data = {'cause'}
|
154 |
if include_modality:
|
155 |
hover_data['Modality'] = True
|
|
|
158 |
df_reduced,
|
159 |
x='Component 1',
|
160 |
y='Component 2',
|
161 |
+
color='subfolder', # Only subfolder colors will show in the legend
|
162 |
+
symbol='indicator', # Symbols for indicators, without showing in legend
|
163 |
hover_data=hover_data,
|
164 |
labels={'Component 1': 'UMAP Dim 1', 'Component 2': 'UMAP Dim 2'},
|
165 |
color_discrete_sequence=px.colors.qualitative.D3
|
166 |
)
|
167 |
|
168 |
+
# Hide the legend for all symbol traces (indicator-based traces)
|
169 |
+
for trace in fig.data:
|
170 |
+
if trace.marker.symbol is not None: # This targets symbol traces
|
171 |
+
trace.showlegend = False
|
172 |
+
|
173 |
fig.update_layout(
|
174 |
xaxis=dict(showgrid=False),
|
175 |
yaxis=dict(showgrid=False),
|
176 |
+
showlegend=True, # Show only the subfolder legend
|
177 |
+
legend=dict(
|
178 |
+
title="Term", # Adjust title to indicate the subfolder legend
|
179 |
+
yanchor="top",
|
180 |
+
xanchor="left",
|
181 |
+
borderwidth=1,
|
182 |
+
),
|
183 |
)
|
184 |
|
185 |
return fig
|
186 |
|
187 |
+
def sankey(cause_threshold=10, indicator_threshold=5):
|
188 |
+
# Load the data
|
189 |
+
data_file = os.path.join('data', 'indicator_cause_sentence_metadata.tsv')
|
190 |
+
df = pd.read_csv(data_file, sep='\t')
|
191 |
+
|
192 |
+
# Remove rows with NaN values in 'cause', 'indicator', or 'subfolder' columns
|
193 |
+
df = df.dropna(subset=['cause', 'indicator', 'subfolder'])
|
194 |
+
|
195 |
+
# Strip '_nk' from 'subfolder' values
|
196 |
+
df['subfolder'] = df['subfolder'].str.replace('_nk', '')
|
197 |
+
|
198 |
+
# Calculate overall counts for each cause and indicator
|
199 |
+
cause_counts = df['cause'].value_counts()
|
200 |
+
indicator_counts = df['indicator'].value_counts()
|
201 |
+
|
202 |
+
# Filter causes and indicators that meet their respective thresholds
|
203 |
+
valid_causes = cause_counts[cause_counts >= cause_threshold].index
|
204 |
+
valid_indicators = indicator_counts[indicator_counts >= indicator_threshold].index
|
205 |
+
|
206 |
+
# Filter the DataFrame to include only rows with causes and indicators that meet the thresholds
|
207 |
+
df_filtered = df[(df['cause'].isin(valid_causes)) & (df['indicator'].isin(valid_indicators))]
|
208 |
+
|
209 |
+
# Calculate pair counts for cause -> indicator and indicator -> subfolder
|
210 |
+
cause_indicator_counts = df_filtered.groupby(['cause', 'indicator']).size().reset_index(name='count')
|
211 |
+
indicator_subfolder_counts = df_filtered.groupby(['indicator', 'subfolder']).size().reset_index(name='count')
|
212 |
+
|
213 |
+
# Generate unique labels for Sankey nodes, including all causes, indicators, and subfolders
|
214 |
+
causes = df_filtered['cause'].unique()
|
215 |
+
indicators = df_filtered['indicator'].unique()
|
216 |
+
subfolders = df_filtered['subfolder'].unique()
|
217 |
+
all_labels = list(causes) + list(indicators) + list(subfolders)
|
218 |
+
|
219 |
+
# Mapping of each label to an index for Sankey node
|
220 |
+
label_to_index = {label: idx for idx, label in enumerate(all_labels)}
|
221 |
+
|
222 |
+
# Define sources, targets, and values for the Sankey diagram
|
223 |
+
sources = []
|
224 |
+
targets = []
|
225 |
+
values = []
|
226 |
+
|
227 |
+
# Add cause -> indicator links
|
228 |
+
for _, row in cause_indicator_counts.iterrows():
|
229 |
+
if row['cause'] in label_to_index and row['indicator'] in label_to_index:
|
230 |
+
sources.append(label_to_index[row['cause']])
|
231 |
+
targets.append(label_to_index[row['indicator']])
|
232 |
+
values.append(row['count'])
|
233 |
+
|
234 |
+
# Add indicator -> subfolder links
|
235 |
+
for _, row in indicator_subfolder_counts.iterrows():
|
236 |
+
if row['indicator'] in label_to_index and row['subfolder'] in label_to_index:
|
237 |
+
sources.append(label_to_index[row['indicator']])
|
238 |
+
targets.append(label_to_index[row['subfolder']])
|
239 |
+
values.append(row['count'])
|
240 |
+
|
241 |
+
fig = go.Figure(data=[go.Sankey(
|
242 |
+
node=dict(
|
243 |
+
pad=15,
|
244 |
+
thickness=20,
|
245 |
+
line=dict(color="black", width=0.5),
|
246 |
+
label=all_labels,
|
247 |
+
),
|
248 |
+
link=dict(
|
249 |
+
source=sources,
|
250 |
+
target=targets,
|
251 |
+
value=values
|
252 |
+
)
|
253 |
+
)])
|
254 |
+
|
255 |
+
fig.update_layout(
|
256 |
+
autosize=False, # Disable automatic resizing
|
257 |
+
width=500, # Fixed width
|
258 |
+
height=500, # Fixed height
|
259 |
+
)
|
260 |
+
|
261 |
+
return fig
|