norygano commited on
Commit
9525dec
·
1 Parent(s): 4f0f736
Files changed (3) hide show
  1. app.py +26 -4
  2. data/indicator_cause_sentence_metadata.tsv +0 -0
  3. plot.py +99 -15
app.py CHANGED
@@ -4,9 +4,13 @@ from transformers import AutoTokenizer, AutoModelForTokenClassification
4
  from annotated_text import annotated_text
5
  import pandas as pd
6
  import plotly.express as px
7
- from plot import indicator_chart, causes_chart, scatter_plot
8
  import os
9
 
 
 
 
 
10
  # Load the trained model and tokenizer
11
  model_directory = "norygano/causalBERT"
12
  tokenizer = AutoTokenizer.from_pretrained(model_directory, add_prefix_space=True)
@@ -30,7 +34,7 @@ st.markdown("[Model](https://huggingface.co/norygano/causalBERT) | [Data](https:
30
  st.write("Tags indicators and causes in explicit attributions of causality. GER only (atm)")
31
 
32
  # Create tabs
33
- tab1, tab2, tab3, tab4 = st.tabs(["Prompt", "Indicators", "Causes", "Scatter"])
34
 
35
  # Prompt Tab
36
  with tab1:
@@ -99,5 +103,23 @@ with tab3:
99
 
100
  with tab4:
101
  st.write("## Scatter")
102
- fig_scatter = scatter_plot()
103
- st.plotly_chart(fig_scatter, use_container_width=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  from annotated_text import annotated_text
5
  import pandas as pd
6
  import plotly.express as px
7
+ from plot import indicator_chart, causes_chart, scatter, sankey
8
  import os
9
 
10
+ # Define initial threshold values at the top of the script
11
+ default_cause_threshold = 20
12
+ default_indicator_threshold = 3
13
+
14
  # Load the trained model and tokenizer
15
  model_directory = "norygano/causalBERT"
16
  tokenizer = AutoTokenizer.from_pretrained(model_directory, add_prefix_space=True)
 
34
  st.write("Tags indicators and causes in explicit attributions of causality. GER only (atm)")
35
 
36
  # Create tabs
37
+ tab1, tab2, tab3, tab4, tab5 = st.tabs(["Prompt", "Indicators", "Causes", "Scatter", "Sankey"])
38
 
39
  # Prompt Tab
40
  with tab1:
 
103
 
104
  with tab4:
105
  st.write("## Scatter")
106
+ fig_scatter = scatter()
107
+ st.plotly_chart(fig_scatter, use_container_width=True)
108
+
109
+ with tab5:
110
+ st.write("## Sankey")
111
+
112
+ # Fixed height for the Sankey chart container
113
+ with st.container():
114
+ # Retrieve slider values and generate the diagram
115
+ cause_threshold = st.session_state.get("cause_threshold", default_cause_threshold)
116
+ indicator_threshold = st.session_state.get("indicator_threshold", default_indicator_threshold)
117
+
118
+ fig_sankey = sankey(cause_threshold=cause_threshold, indicator_threshold=indicator_threshold)
119
+ st.plotly_chart(fig_sankey, use_container_width=True)
120
+
121
+ # Place sliders below the chart container
122
+ with st.container():
123
+ st.write("Adjust thresholds for Sankey diagram:")
124
+ cause_threshold = st.slider("Cause Threshold", min_value=1, max_value=100, value=default_cause_threshold, key="cause_threshold")
125
+ indicator_threshold = st.slider("Indicator Threshold", min_value=1, max_value=100, value=default_indicator_threshold, key="indicator_threshold")
data/indicator_cause_sentence_metadata.tsv CHANGED
The diff for this file is too large to render. See raw diff
 
plot.py CHANGED
@@ -1,5 +1,6 @@
1
  import pandas as pd
2
  import plotly.express as px
 
3
  import os
4
  import umap
5
 
@@ -64,7 +65,7 @@ def indicator_chart(chart_type='overall'):
64
 
65
  fig.update_layout(
66
  xaxis=dict(showline=True),
67
- yaxis=dict(showticklabels=True, title=''),
68
  bargap=0.05,
69
  showlegend=(chart_type == 'individual')
70
  )
@@ -107,17 +108,17 @@ def causes_chart():
107
 
108
  return fig
109
 
110
- def scatter_plot(include_modality=False):
111
  data_file = os.path.join('data', 'feature_matrix.tsv')
112
  df = pd.read_csv(data_file, sep='\t')
113
 
114
- # Exclude sentences without any indicators (all indicator columns are 0), causes, or modalities (if included)
115
  indicator_columns = [col for col in df.columns if col.startswith('indicator_')]
116
  cause_columns = [col for col in df.columns if col.startswith('cause_')]
117
  modality_columns = [col for col in df.columns if col.startswith('modality_')]
118
 
119
  df_filtered = df[(df[indicator_columns].sum(axis=1) > 0) |
120
- (df[cause_columns].sum(axis=1) > 0)]
121
 
122
  # Exclude indicator '!besprechen'
123
  indicator_columns = [col for col in indicator_columns if 'indicator_!besprechen' not in col]
@@ -129,30 +130,26 @@ def scatter_plot(include_modality=False):
129
  # Further filter to exclude entries without any valid indicators
130
  df_filtered = df_filtered[df_filtered[indicators_to_keep].sum(axis=1) > 0]
131
 
132
- # Exclude non-feature columns (metadata and sentence text) for dimensionality reduction
133
  columns_to_drop = ['subfolder']
134
  if not include_modality:
135
  columns_to_drop += modality_columns # Drop modality columns if not included
136
 
137
  features = df_filtered.drop(columns=columns_to_drop)
138
-
139
- # Fill NaN values with 0 for the feature matrix
140
  features_clean = features.fillna(0)
141
 
142
- # Store the relevant metadata separately to ensure it is aligned correctly with the dimensionality reduction results
143
  metadata = df_filtered[['subfolder']].copy()
144
- # Remove the 'indicator_' prefix for indicators and ensure only indicators with at least 10 occurrences are included
145
  metadata['indicator'] = df_filtered[indicators_to_keep].apply(lambda row: ', '.join([indicator.replace('indicator_', '') for indicator in indicators_to_keep if row[indicator] > 0]), axis=1)
146
- # Collect all non-zero causes as a string (multiple causes per sentence)
147
  metadata['cause'] = df_filtered[cause_columns].apply(lambda row: ', '.join([cause.replace('cause_', '') for cause in cause_columns if row[cause] > 0]), axis=1)
148
 
149
- # Perform UMAP dimensionality reduction
150
- reducer = umap.UMAP(n_components=2, random_state=42, n_neighbors=50, metric='cosine')
151
  reduced_features = reducer.fit_transform(features_clean)
152
  df_reduced = pd.DataFrame(reduced_features, columns=['Component 1', 'Component 2'])
153
  df_reduced = pd.concat([df_reduced, metadata.reset_index(drop=True)], axis=1)
154
 
155
- # Plotting the scatter plot with Plotly Express
156
  hover_data = {'cause'}
157
  if include_modality:
158
  hover_data['Modality'] = True
@@ -161,17 +158,104 @@ def scatter_plot(include_modality=False):
161
  df_reduced,
162
  x='Component 1',
163
  y='Component 2',
164
- color='subfolder',
 
165
  hover_data=hover_data,
166
  labels={'Component 1': 'UMAP Dim 1', 'Component 2': 'UMAP Dim 2'},
167
  color_discrete_sequence=px.colors.qualitative.D3
168
  )
169
 
 
 
 
 
 
170
  fig.update_layout(
171
  xaxis=dict(showgrid=False),
172
  yaxis=dict(showgrid=False),
173
- showlegend=True
 
 
 
 
 
 
174
  )
175
 
176
  return fig
177
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import pandas as pd
2
  import plotly.express as px
3
+ import plotly.graph_objects as go
4
  import os
5
  import umap
6
 
 
65
 
66
  fig.update_layout(
67
  xaxis=dict(showline=True),
68
+ yaxis=dict(showticklabels=True, title='', tickformat=".0%" if chart_type == 'overall' else None),
69
  bargap=0.05,
70
  showlegend=(chart_type == 'individual')
71
  )
 
108
 
109
  return fig
110
 
111
+ def scatter(include_modality=False):
112
  data_file = os.path.join('data', 'feature_matrix.tsv')
113
  df = pd.read_csv(data_file, sep='\t')
114
 
115
+ # Exclude sentences without any indicators, causes, or modalities (if included)
116
  indicator_columns = [col for col in df.columns if col.startswith('indicator_')]
117
  cause_columns = [col for col in df.columns if col.startswith('cause_')]
118
  modality_columns = [col for col in df.columns if col.startswith('modality_')]
119
 
120
  df_filtered = df[(df[indicator_columns].sum(axis=1) > 0) |
121
+ (df[cause_columns].sum(axis=1) > 0)]
122
 
123
  # Exclude indicator '!besprechen'
124
  indicator_columns = [col for col in indicator_columns if 'indicator_!besprechen' not in col]
 
130
  # Further filter to exclude entries without any valid indicators
131
  df_filtered = df_filtered[df_filtered[indicators_to_keep].sum(axis=1) > 0]
132
 
133
+ # Exclude non-feature columns for dimensionality reduction
134
  columns_to_drop = ['subfolder']
135
  if not include_modality:
136
  columns_to_drop += modality_columns # Drop modality columns if not included
137
 
138
  features = df_filtered.drop(columns=columns_to_drop)
 
 
139
  features_clean = features.fillna(0)
140
 
141
+ # Prepare metadata
142
  metadata = df_filtered[['subfolder']].copy()
 
143
  metadata['indicator'] = df_filtered[indicators_to_keep].apply(lambda row: ', '.join([indicator.replace('indicator_', '') for indicator in indicators_to_keep if row[indicator] > 0]), axis=1)
 
144
  metadata['cause'] = df_filtered[cause_columns].apply(lambda row: ', '.join([cause.replace('cause_', '') for cause in cause_columns if row[cause] > 0]), axis=1)
145
 
146
+ # UMAP dimensionality reduction
147
+ reducer = umap.UMAP(n_components=2, random_state=42, n_neighbors=50, n_jobs=1, metric='cosine')
148
  reduced_features = reducer.fit_transform(features_clean)
149
  df_reduced = pd.DataFrame(reduced_features, columns=['Component 1', 'Component 2'])
150
  df_reduced = pd.concat([df_reduced, metadata.reset_index(drop=True)], axis=1)
151
 
152
+ # Plotting the scatter plot
153
  hover_data = {'cause'}
154
  if include_modality:
155
  hover_data['Modality'] = True
 
158
  df_reduced,
159
  x='Component 1',
160
  y='Component 2',
161
+ color='subfolder', # Only subfolder colors will show in the legend
162
+ symbol='indicator', # Symbols for indicators, without showing in legend
163
  hover_data=hover_data,
164
  labels={'Component 1': 'UMAP Dim 1', 'Component 2': 'UMAP Dim 2'},
165
  color_discrete_sequence=px.colors.qualitative.D3
166
  )
167
 
168
+ # Hide the legend for all symbol traces (indicator-based traces)
169
+ for trace in fig.data:
170
+ if trace.marker.symbol is not None: # This targets symbol traces
171
+ trace.showlegend = False
172
+
173
  fig.update_layout(
174
  xaxis=dict(showgrid=False),
175
  yaxis=dict(showgrid=False),
176
+ showlegend=True, # Show only the subfolder legend
177
+ legend=dict(
178
+ title="Term", # Adjust title to indicate the subfolder legend
179
+ yanchor="top",
180
+ xanchor="left",
181
+ borderwidth=1,
182
+ ),
183
  )
184
 
185
  return fig
186
 
187
+ def sankey(cause_threshold=10, indicator_threshold=5):
188
+ # Load the data
189
+ data_file = os.path.join('data', 'indicator_cause_sentence_metadata.tsv')
190
+ df = pd.read_csv(data_file, sep='\t')
191
+
192
+ # Remove rows with NaN values in 'cause', 'indicator', or 'subfolder' columns
193
+ df = df.dropna(subset=['cause', 'indicator', 'subfolder'])
194
+
195
+ # Strip '_nk' from 'subfolder' values
196
+ df['subfolder'] = df['subfolder'].str.replace('_nk', '')
197
+
198
+ # Calculate overall counts for each cause and indicator
199
+ cause_counts = df['cause'].value_counts()
200
+ indicator_counts = df['indicator'].value_counts()
201
+
202
+ # Filter causes and indicators that meet their respective thresholds
203
+ valid_causes = cause_counts[cause_counts >= cause_threshold].index
204
+ valid_indicators = indicator_counts[indicator_counts >= indicator_threshold].index
205
+
206
+ # Filter the DataFrame to include only rows with causes and indicators that meet the thresholds
207
+ df_filtered = df[(df['cause'].isin(valid_causes)) & (df['indicator'].isin(valid_indicators))]
208
+
209
+ # Calculate pair counts for cause -> indicator and indicator -> subfolder
210
+ cause_indicator_counts = df_filtered.groupby(['cause', 'indicator']).size().reset_index(name='count')
211
+ indicator_subfolder_counts = df_filtered.groupby(['indicator', 'subfolder']).size().reset_index(name='count')
212
+
213
+ # Generate unique labels for Sankey nodes, including all causes, indicators, and subfolders
214
+ causes = df_filtered['cause'].unique()
215
+ indicators = df_filtered['indicator'].unique()
216
+ subfolders = df_filtered['subfolder'].unique()
217
+ all_labels = list(causes) + list(indicators) + list(subfolders)
218
+
219
+ # Mapping of each label to an index for Sankey node
220
+ label_to_index = {label: idx for idx, label in enumerate(all_labels)}
221
+
222
+ # Define sources, targets, and values for the Sankey diagram
223
+ sources = []
224
+ targets = []
225
+ values = []
226
+
227
+ # Add cause -> indicator links
228
+ for _, row in cause_indicator_counts.iterrows():
229
+ if row['cause'] in label_to_index and row['indicator'] in label_to_index:
230
+ sources.append(label_to_index[row['cause']])
231
+ targets.append(label_to_index[row['indicator']])
232
+ values.append(row['count'])
233
+
234
+ # Add indicator -> subfolder links
235
+ for _, row in indicator_subfolder_counts.iterrows():
236
+ if row['indicator'] in label_to_index and row['subfolder'] in label_to_index:
237
+ sources.append(label_to_index[row['indicator']])
238
+ targets.append(label_to_index[row['subfolder']])
239
+ values.append(row['count'])
240
+
241
+ fig = go.Figure(data=[go.Sankey(
242
+ node=dict(
243
+ pad=15,
244
+ thickness=20,
245
+ line=dict(color="black", width=0.5),
246
+ label=all_labels,
247
+ ),
248
+ link=dict(
249
+ source=sources,
250
+ target=targets,
251
+ value=values
252
+ )
253
+ )])
254
+
255
+ fig.update_layout(
256
+ autosize=False, # Disable automatic resizing
257
+ width=500, # Fixed width
258
+ height=500, # Fixed height
259
+ )
260
+
261
+ return fig