norygano commited on
Commit
7e996cc
·
1 Parent(s): 6b22889

Added plot.py

Browse files
Files changed (3) hide show
  1. app.py +65 -29
  2. data/feature_matrix.tsv +0 -0
  3. plot.py +265 -238
app.py CHANGED
@@ -2,14 +2,17 @@ import streamlit as st
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForTokenClassification
4
  from annotated_text import annotated_text
5
- import pandas as pd
6
- import plotly.express as px
7
- from plot import indicator_chart, causes_chart, scatter, sankey
8
  import os
 
9
 
10
  # Define initial threshold values at the top of the script
11
  default_cause_threshold = 20
12
  default_indicator_threshold = 15
 
 
 
 
 
13
 
14
  # Load the trained model and tokenizer
15
  model_directory = "norygano/causalBERT"
@@ -30,16 +33,16 @@ st.markdown(
30
  """,
31
  unsafe_allow_html=True
32
  )
33
- st.markdown("[Model](https://huggingface.co/norygano/causalBERT) | [Data](https://huggingface.co/datasets/norygano/causenv) | [Project](https://www.uni-trier.de/universitaet/fachbereiche-faecher/fachbereich-ii/faecher/germanistik/professurenfachteile/germanistische-linguistik/professoren/prof-dr-martin-wengeler/kontroverse-diskurse/individium-gesellschaft)")
34
- st.write("Tags indicators and causes in explicit attributions of causality.")
35
 
36
  # Create tabs
37
- tab1, tab2, tab3, tab4, tab5 = st.tabs(["Prompt", "Indicators", "Causes", "Scatter", "Sankey"])
38
 
39
  # Prompt Tab
40
  with tab1:
41
  sentences_input = st.text_area("*Sentences (one per line)*", "\n".join([
42
- "Autos stehen im Verdacht, Waldsterben zu verursachen.",
43
  "Fußball führt zu Waldschäden.",
44
  "Haustüren tragen zum Betonsterben bei.",
45
  ]), placeholder="German only (currently)")
@@ -82,36 +85,69 @@ with tab1:
82
  annotated_text(*annotations)
83
  st.write("---")
84
 
85
- # Research Insights Tab
 
86
  with tab2:
87
- # Overall
88
- st.subheader("Overall")
89
- fig_overall = indicator_chart(chart_type='overall')
90
- st.plotly_chart(fig_overall, use_container_width=True)
91
-
92
- # Individual Indicators Chart
93
- st.subheader("Individual")
94
- fig_individual = indicator_chart(chart_type='individual')
95
- st.plotly_chart(fig_individual, use_container_width=True)
 
 
 
 
 
 
96
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  with tab3:
98
- fig_causes = causes_chart()
99
- st.plotly_chart(fig_causes, use_container_width=True)
 
 
 
 
 
 
 
 
100
 
 
101
  with tab4:
102
- fig_scatter = scatter()
103
  st.plotly_chart(fig_scatter, use_container_width=True)
104
 
 
105
  with tab5:
106
- # Fixed height for the Sankey chart container
107
  with st.container():
108
- # Retrieve slider values and generate the diagram
109
- cause_threshold = st.session_state.get("cause_threshold", default_cause_threshold)
110
- indicator_threshold = st.session_state.get("indicator_threshold", default_indicator_threshold)
111
- fig_sankey = sankey(cause_threshold=cause_threshold, indicator_threshold=indicator_threshold)
 
 
112
  st.plotly_chart(fig_sankey, use_container_width=True)
113
-
114
- # Place sliders below the chart container
115
  with st.container():
116
- cause_threshold = st.slider("Cause >", min_value=1, max_value=100, value=default_cause_threshold, key="cause_threshold")
117
- indicator_threshold = st.slider("Indicator >", min_value=1, max_value=100, value=default_indicator_threshold, key="indicator_threshold")
 
 
 
 
 
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForTokenClassification
4
  from annotated_text import annotated_text
 
 
 
5
  import os
6
+ from plot import Plot # Assuming the class is saved in diagram_generator.py
7
 
8
  # Define initial threshold values at the top of the script
9
  default_cause_threshold = 20
10
  default_indicator_threshold = 15
11
+ default_cause_threshold_sankey = 20
12
+ default_indicator_threshold_sankey = 15
13
+
14
+ # Initialize Plots
15
+ plot = Plot()
16
 
17
  # Load the trained model and tokenizer
18
  model_directory = "norygano/causalBERT"
 
33
  """,
34
  unsafe_allow_html=True
35
  )
36
+ st.markdown("[Weights](https://huggingface.co/norygano/causalBERT) | [Data](https://huggingface.co/datasets/norygano/causenv) | [Project](https://www.uni-trier.de/universitaet/fachbereiche-faecher/fachbereich-ii/faecher/germanistik/professurenfachteile/germanistische-linguistik/professoren/prof-dr-martin-wengeler/kontroverse-diskurse/individium-gesellschaft)")
37
+ st.write("Indicators and causes in explicit attributions of causality.")
38
 
39
  # Create tabs
40
+ tab1, tab2, tab3, tab4, tab5 = st.tabs(["Model", "Indicators", "Causes", "Scatter", "Sankey"])
41
 
42
  # Prompt Tab
43
  with tab1:
44
  sentences_input = st.text_area("*Sentences (one per line)*", "\n".join([
45
+ "Autos stehen im Verdacht, Waldsterben verursacht zu haben.",
46
  "Fußball führt zu Waldschäden.",
47
  "Haustüren tragen zum Betonsterben bei.",
48
  ]), placeholder="German only (currently)")
 
85
  annotated_text(*annotations)
86
  st.write("---")
87
 
88
+
89
+ # Indicator Tab
90
  with tab2:
91
+ selected_chart_type = st.radio(
92
+ "Type",
93
+ options=['Total', 'Year', 'Individual'],
94
+ horizontal=True,
95
+ )
96
+
97
+ # Display the chart in a container
98
+ with st.container():
99
+ if selected_chart_type == 'Individual':
100
+ # Retrieve slider value from session state or use default
101
+ individual_threshold = st.session_state.get("individual_threshold", default_indicator_threshold)
102
+ fig = plot.get_indicator_chart(chart_type=selected_chart_type.lower(), individual_threshold=individual_threshold)
103
+ else:
104
+ fig = plot.get_indicator_chart(chart_type=selected_chart_type.lower())
105
+ st.plotly_chart(fig, use_container_width=True)
106
 
107
+ # Display the slider below the chart container for 'Individual' type
108
+ if selected_chart_type == 'Individual':
109
+ with st.container():
110
+ individual_threshold = st.slider(
111
+ "Indicator >=",
112
+ min_value=1,
113
+ max_value=95,
114
+ value=default_indicator_threshold,
115
+ key="individual_threshold"
116
+ )
117
+
118
+ # Causes Tab
119
  with tab3:
120
+ # Create a container for the chart and place the slider below it
121
+ with st.container():
122
+ # Display the chart first
123
+ fig_causes = plot.get_causes_chart(min_value=st.session_state.get("cause_threshold_causes", default_cause_threshold))
124
+ st.plotly_chart(fig_causes, use_container_width=True)
125
+
126
+ # Place the slider below the chart with a unique key
127
+ cause_threshold_causes = st.slider(
128
+ "Cause >=", min_value=1, max_value=75, value=default_cause_threshold, key="cause_threshold_causes"
129
+ )
130
 
131
+ # Scatter Tab
132
  with tab4:
133
+ fig_scatter = plot.scatter()
134
  st.plotly_chart(fig_scatter, use_container_width=True)
135
 
136
+ # Sankey Tab
137
  with tab5:
 
138
  with st.container():
139
+ # Use the unique Sankey threshold variables in session state
140
+ cause_threshold_sankey = st.session_state.get("cause_threshold_sankey", default_cause_threshold_sankey)
141
+ indicator_threshold_sankey = st.session_state.get("indicator_threshold_sankey", default_indicator_threshold_sankey)
142
+
143
+ # Generate the Sankey diagram with the new Sankey-specific thresholds
144
+ fig_sankey = plot.sankey(cause_threshold=cause_threshold_sankey, indicator_threshold=indicator_threshold_sankey)
145
  st.plotly_chart(fig_sankey, use_container_width=True)
146
+ # Place sliders below the chart container with unique keys for the Sankey tab
 
147
  with st.container():
148
+ cause_threshold_sankey = st.slider(
149
+ "Cause >=", min_value=1, max_value=100, value=default_cause_threshold_sankey, key="cause_threshold_sankey"
150
+ )
151
+ indicator_threshold_sankey = st.slider(
152
+ "Indicator >=", min_value=1, max_value=100, value=default_indicator_threshold_sankey, key="indicator_threshold_sankey"
153
+ )
data/feature_matrix.tsv CHANGED
The diff for this file is too large to render. See raw diff
 
plot.py CHANGED
@@ -3,259 +3,286 @@ import plotly.express as px
3
  import plotly.graph_objects as go
4
  import os
5
  import umap
6
-
7
- def indicator_chart(chart_type='overall'):
8
- data_file = os.path.join('data', 'indicator_overview.tsv')
9
- df = pd.read_csv(data_file, sep='\t')
10
-
11
- if chart_type == 'overall':
12
- df_filtered = df[df['Indicator'] == 'Total with Indicators'].copy()
13
- total_sentences_per_subfolder = df.groupby('Subfolder')['Total Sentences'].first().to_dict()
14
- df_filtered['Total Sentences'] = df_filtered['Subfolder'].map(total_sentences_per_subfolder)
15
- df_filtered['Indicator_Share'] = df_filtered['Count'] / df_filtered['Total Sentences']
16
- df_filtered['Indicator_Share_Text'] = (df_filtered['Indicator_Share'] * 100).round(2).astype(str) + '%'
 
17
 
18
- fig = px.bar(
19
- df_filtered,
20
- x='Subfolder',
21
- y='Indicator_Share',
22
- labels={'Indicator_Share': 'Share of Sentences with Indicators', 'Subfolder': ''},
23
- color='Subfolder',
24
- text='Indicator_Share_Text',
25
- color_discrete_sequence=px.colors.qualitative.D3,
26
- custom_data=['Total Sentences', 'Count']
27
- )
28
 
29
- fig.update_traces(
30
- hovertemplate=(
31
- '<b>%{x}</b><br>' +
32
- 'Share with Indicators: %{y:.1%}<br>' +
33
- 'Total Sentences: %{customdata[0]}<br>' +
34
- 'Sentences with Indicators: %{customdata[1]}<extra></extra>'
35
- ),
36
- textposition='inside',
37
- texttemplate='%{text}',
38
- textfont=dict(color='rgb(255, 255, 255)'),
39
- insidetextanchor='middle',
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  )
 
41
 
42
- elif chart_type == 'individual':
43
- min_value = 5
44
- exclude_indicators = ['!besprechen']
45
- df_filtered = df[~df['Indicator'].isin(['Total with Indicators', 'None'] + exclude_indicators)].copy()
46
- indicators_meeting_threshold = df_filtered[df_filtered['Count'] >= min_value]['Indicator'].unique()
47
- df_filtered = df_filtered[df_filtered['Indicator'].isin(indicators_meeting_threshold)]
48
- df_filtered['Indicator'] = df_filtered['Indicator'].str.capitalize()
49
 
50
  fig = px.bar(
51
- df_filtered,
52
- x='Subfolder',
53
  y='Count',
54
- color='Indicator',
55
  barmode='group',
56
- labels={'Count': 'Occurrences', 'Subfolder': '', 'Indicator': ' <b>INDICATOR</b>'},
57
  color_discrete_sequence=px.colors.qualitative.D3
58
  )
59
-
60
  fig.update_traces(
61
  texttemplate='%{y}',
62
  textposition='inside',
63
- textfont=dict(color='rgb(255, 255, 255)'),
64
- insidetextanchor='middle'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  )
66
 
67
- fig.update_layout(
68
- xaxis=dict(showline=True),
69
- yaxis=dict(showticklabels=True, title='', tickformat=".0%" if chart_type == 'overall' else None),
70
- bargap=0.05,
71
- showlegend=(chart_type == 'individual')
72
- )
73
-
74
- return fig
75
-
76
- def causes_chart():
77
- data_file = os.path.join('data', 'indicator_cause_sentence_metadata.tsv')
78
- df = pd.read_csv(data_file, sep='\t')
79
-
80
- # Threshold
81
- min_value = 30
82
- df_filtered = df[df['cause'] != 'N/A'].copy()
83
- causes_meeting_threshold = df_filtered.groupby('cause')['cause'].count()[lambda x: x >= min_value].index
84
- df_filtered = df_filtered[df_filtered['cause'].isin(causes_meeting_threshold)]
85
- df_filtered['cause'] = df_filtered['cause'].str.capitalize()
86
-
87
- fig = px.bar(
88
- df_filtered.groupby(['subfolder', 'cause']).size().reset_index(name='Count'),
89
- x='subfolder',
90
- y='Count',
91
- color='cause',
92
- barmode='group',
93
- labels={'Count': 'Occurrences', 'subfolder': '', 'cause': '<b>CAUSE</b>'},
94
- color_discrete_sequence=px.colors.qualitative.D3,
95
- )
96
-
97
- fig.update_layout(
98
- xaxis=dict(showline=True),
99
- yaxis=dict(showticklabels=True, title=''),
100
-
101
- )
102
-
103
- fig.update_traces(
104
- texttemplate='%{y}',
105
- textposition='inside',
106
- textfont=dict(color='rgb(255, 255, 255)'),
107
- insidetextanchor='middle',
108
- )
109
-
110
- return fig
111
-
112
- def scatter(include_modality=False):
113
- data_file = os.path.join('data', 'feature_matrix.tsv')
114
- df = pd.read_csv(data_file, sep='\t')
115
-
116
- # Exclude sentences without any indicators, causes, or modalities (if included)
117
- indicator_columns = [col for col in df.columns if col.startswith('indicator_')]
118
- cause_columns = [col for col in df.columns if col.startswith('cause_')]
119
- modality_columns = [col for col in df.columns if col.startswith('modality_')]
120
-
121
- df_filtered = df[(df[indicator_columns].sum(axis=1) > 0) |
122
- (df[cause_columns].sum(axis=1) > 0)]
123
-
124
- # Exclude indicator '!besprechen'
125
- indicator_columns = [col for col in indicator_columns if 'indicator_!besprechen' not in col]
126
-
127
- # Limit indicators to those that occur at least 10 times
128
- indicator_counts = df_filtered[indicator_columns].sum()
129
- indicators_to_keep = indicator_counts[indicator_counts >= 10].index.tolist()
130
-
131
- # Further filter to exclude entries without any valid indicators
132
- df_filtered = df_filtered[df_filtered[indicators_to_keep].sum(axis=1) > 0]
133
-
134
- # Exclude non-feature columns for dimensionality reduction
135
- columns_to_drop = ['subfolder']
136
- if not include_modality:
137
- columns_to_drop += modality_columns # Drop modality columns if not included
138
-
139
- features = df_filtered.drop(columns=columns_to_drop)
140
- features_clean = features.fillna(0)
141
-
142
- # Prepare metadata
143
- metadata = df_filtered[['subfolder']].copy()
144
- metadata['indicator'] = df_filtered[indicators_to_keep].apply(lambda row: ', '.join([indicator.replace('indicator_', '') for indicator in indicators_to_keep if row[indicator] > 0]), axis=1)
145
- metadata['cause'] = df_filtered[cause_columns].apply(lambda row: ', '.join([cause.replace('cause_', '') for cause in cause_columns if row[cause] > 0]), axis=1)
146
-
147
- # UMAP dimensionality reduction
148
- reducer = umap.UMAP(n_components=2, random_state=42, n_neighbors=50, n_jobs=1, metric='cosine')
149
- reduced_features = reducer.fit_transform(features_clean)
150
- df_reduced = pd.DataFrame(reduced_features, columns=['Component 1', 'Component 2'])
151
- df_reduced = pd.concat([df_reduced, metadata.reset_index(drop=True)], axis=1)
152
-
153
- # Plotting the scatter plot
154
- hover_data = {'cause': True, 'Component 1': False, 'Component 2': False}
155
- if include_modality:
156
- hover_data['Modality'] = True
157
-
158
- custom_labels = {
159
- 'subfolder': 'Effect', # Renaming 'subfolder' to 'Category'
160
- }
161
-
162
- fig = px.scatter(
163
- df_reduced,
164
- x='Component 1',
165
- y='Component 2',
166
- color='subfolder', # Only subfolder colors will show in the legend
167
- symbol='indicator', # Symbols for indicators, without showing in legend
168
- labels=custom_labels,
169
- hover_data=hover_data,
170
- color_discrete_sequence=px.colors.qualitative.D3
171
- )
172
-
173
- fig.update_layout(
174
- xaxis=dict(showgrid=True),
175
- yaxis=dict(showgrid=True),
176
- showlegend=True, # Show only the subfolder legend
177
- legend=dict(
178
- title="Effect, Indicator", # Adjust title to indicate the subfolder legend
179
- yanchor="top",
180
- xanchor="left",
181
- borderwidth=1,
182
- ),
183
- )
184
-
185
- return fig
186
-
187
- def sankey(cause_threshold=10, indicator_threshold=5):
188
- # Load the data
189
- data_file = os.path.join('data', 'indicator_cause_sentence_metadata.tsv')
190
- df = pd.read_csv(data_file, sep='\t')
191
-
192
- # Remove rows with NaN values in 'cause', 'indicator', or 'subfolder' columns
193
- df = df.dropna(subset=['cause', 'indicator', 'subfolder'])
194
-
195
- # Strip '_nk' from 'subfolder' values
196
- df['subfolder'] = df['subfolder'].str.replace('_nk', '')
197
-
198
- # Calculate overall counts for each cause and indicator
199
- cause_counts = df['cause'].value_counts()
200
- indicator_counts = df['indicator'].value_counts()
201
-
202
- # Filter causes and indicators that meet their respective thresholds
203
- valid_causes = cause_counts[cause_counts >= cause_threshold].index
204
- valid_indicators = indicator_counts[indicator_counts >= indicator_threshold].index
205
-
206
- # Filter the DataFrame to include only rows with causes and indicators that meet the thresholds
207
- df_filtered = df[(df['cause'].isin(valid_causes)) & (df['indicator'].isin(valid_indicators))]
208
-
209
- # Calculate pair counts for cause -> indicator and indicator -> subfolder
210
- cause_indicator_counts = df_filtered.groupby(['cause', 'indicator']).size().reset_index(name='count')
211
- indicator_subfolder_counts = df_filtered.groupby(['indicator', 'subfolder']).size().reset_index(name='count')
212
-
213
- # Generate unique labels for Sankey nodes, including all causes, indicators, and subfolders
214
- causes = df_filtered['cause'].unique()
215
- indicators = df_filtered['indicator'].unique()
216
- subfolders = df_filtered['subfolder'].unique()
217
- all_labels = list(causes) + list(indicators) + list(subfolders)
218
-
219
- # Mapping of each label to an index for Sankey node
220
- label_to_index = {label: idx for idx, label in enumerate(all_labels)}
221
-
222
- # Define sources, targets, and values for the Sankey diagram
223
- sources = []
224
- targets = []
225
- values = []
226
-
227
- # Add cause -> indicator links
228
- for _, row in cause_indicator_counts.iterrows():
229
- if row['cause'] in label_to_index and row['indicator'] in label_to_index:
230
- sources.append(label_to_index[row['cause']])
231
- targets.append(label_to_index[row['indicator']])
232
- values.append(row['count'])
233
-
234
- # Add indicator -> subfolder links
235
- for _, row in indicator_subfolder_counts.iterrows():
236
- if row['indicator'] in label_to_index and row['subfolder'] in label_to_index:
237
- sources.append(label_to_index[row['indicator']])
238
- targets.append(label_to_index[row['subfolder']])
239
- values.append(row['count'])
240
-
241
- fig = go.Figure(data=[go.Sankey(
242
- node=dict(
243
- pad=15,
244
- thickness=20,
245
- line=dict(color="black", width=0.5),
246
- label=all_labels,
247
- ),
248
- link=dict(
249
- source=sources,
250
- target=targets,
251
- value=values
252
  )
253
- )])
254
 
255
- fig.update_layout(
256
- autosize=False, # Disable automatic resizing
257
- width=500, # Fixed width
258
- height=500, # Fixed height
259
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
 
261
- return fig
 
3
  import plotly.graph_objects as go
4
  import os
5
  import umap
6
+ import streamlit as st
7
+
8
+ @st.cache_data
9
+ def load_data(file_path):
10
+ return pd.read_csv(file_path, sep='\t')
11
+
12
+ class Plot:
13
+ def __init__(self, data_file='data/feature_matrix.tsv', metadata_file='data/indicator_cause_sentence_metadata.tsv'):
14
+ self.data_file = data_file
15
+ self.metadata_file = metadata_file
16
+ self.df = load_data(self.data_file) # Cached data loading
17
+ self.metadata_df = load_data(self.metadata_file)
18
 
19
+ # Cache and compute necessary columns once
20
+ self.indicator_columns = [col for col in self.df.columns if col.startswith('indicator_')]
21
+ self.cause_columns = [col for col in self.df.columns if col.startswith('cause_')]
 
 
 
 
 
 
 
22
 
23
+ self.df['Year'] = self.df['text_date'].astype(str).str[:4]
24
+ self.df['Has_Indicator'] = self.df[self.indicator_columns].sum(axis=1) > 0
25
+
26
+ # Precompute totals for faster use in chart functions
27
+ self.total_sentences_per_year = self.df.groupby(['Year', 'subfolder']).size().reset_index(name='Total Sentences')
28
+ self.total_sentences_per_subfolder = self.df.groupby('subfolder').size().reset_index(name='Total Sentences')
29
+
30
+ def get_indicator_chart(self, chart_type='total', individual_threshold=5):
31
+ if chart_type == 'total':
32
+ # Summarize indicator share per subfolder
33
+ indicator_counts = self.df[self.df['Has_Indicator']].groupby('subfolder').size().reset_index(name='Indicator Count')
34
+ total_counts = indicator_counts.merge(self.total_sentences_per_subfolder, on='subfolder')
35
+ total_counts['Indicator_Share'] = total_counts['Indicator Count'] / total_counts['Total Sentences']
36
+ total_counts['Indicator_Share_Text'] = (total_counts['Indicator_Share'] * 100).round(2).astype(str) + '%'
37
+
38
+ fig = px.bar(
39
+ total_counts,
40
+ x='subfolder',
41
+ y='Indicator_Share',
42
+ labels={'Indicator_Share': 'Share of Sentences with Indicators', 'subfolder': ''},
43
+ color='subfolder',
44
+ text='Indicator_Share_Text',
45
+ color_discrete_sequence=px.colors.qualitative.D3
46
+ )
47
+ fig.update_traces(
48
+ textposition='inside',
49
+ texttemplate='%{text}',
50
+ textfont=dict(color='rgb(255, 255, 255)')
51
+ )
52
+
53
+ elif chart_type == 'individual':
54
+ # Melt the dataframe to long format
55
+ df_melted = self.df.melt(id_vars=['subfolder'], value_vars=self.indicator_columns, var_name='Indicator', value_name='Count')
56
+ df_melted = df_melted[df_melted['Count'] > 0]
57
+
58
+ # Group by Indicator only to calculate total counts across all subfolders
59
+ total_indicator_counts = df_melted.groupby('Indicator').size().reset_index(name='Total Count')
60
+ indicators_meeting_threshold = total_indicator_counts[total_indicator_counts['Total Count'] >= individual_threshold]['Indicator'].unique()
61
+
62
+ # Filter df_melted to include only indicators that meet the threshold overall
63
+ df_melted = df_melted[df_melted['Indicator'].isin(indicators_meeting_threshold)]
64
+ df_melted['Indicator'] = df_melted['Indicator'].str.replace('indicator_', '').str.capitalize()
65
+
66
+ # Re-aggregate counts by subfolder and indicator for the filtered indicators
67
+ df_melted = df_melted.groupby(['subfolder', 'Indicator']).size().reset_index(name='Count')
68
+
69
+ # Create the bar chart
70
+ fig = px.bar(
71
+ df_melted,
72
+ x='subfolder',
73
+ y='Count',
74
+ color='Indicator',
75
+ barmode='group',
76
+ labels={'Count': 'Occurrences', 'subfolder': '', 'Indicator': 'Indicator'},
77
+ color_discrete_sequence=px.colors.qualitative.D3
78
+ )
79
+ fig.update_traces(
80
+ texttemplate='%{y}',
81
+ textposition='inside',
82
+ textfont=dict(color='rgb(255, 255, 255)')
83
+ )
84
+
85
+ elif chart_type == 'year':
86
+ indicator_counts_per_year = self.df[self.df['Has_Indicator']].groupby(['Year', 'subfolder']).size().reset_index(name='Indicator Count')
87
+ df_summary = pd.merge(self.total_sentences_per_year, indicator_counts_per_year, on=['Year', 'subfolder'], how='left')
88
+ df_summary['Indicator_Share_Text'] = (df_summary['Indicator Count'] / df_summary['Total Sentences'] * 100).round(2).astype(str) + '%'
89
+
90
+ fig = px.bar(
91
+ df_summary,
92
+ x='Year',
93
+ y='Total Sentences',
94
+ color='subfolder',
95
+ labels={'Total Sentences': 'Total Number of Sentences', 'Year': 'Year'},
96
+ text='Indicator_Share_Text',
97
+ color_discrete_sequence=px.colors.qualitative.D3
98
+ )
99
+ fig.update_traces(
100
+ textposition='inside',
101
+ texttemplate='%{text}',
102
+ textfont=dict(color='rgb(255, 255, 255)')
103
+ )
104
+
105
+ fig.update_layout(
106
+ xaxis=dict(showline=True),
107
+ yaxis=dict(title='Indicator Sentences' if chart_type != 'year' else 'Total Sentences'),
108
+ bargap=0.05,
109
+ showlegend=(chart_type != 'total')
110
  )
111
+ return fig
112
 
113
+ def get_causes_chart(self, min_value=30):
114
+ df_filtered = self.metadata_df[self.metadata_df['cause'] != 'N/A']
115
+ causes_meeting_threshold = df_filtered.groupby('cause')['cause'].count()[lambda x: x >= min_value].index
116
+ df_filtered = df_filtered[df_filtered['cause'].isin(causes_meeting_threshold)]
117
+ df_filtered['cause'] = df_filtered['cause'].str.capitalize()
 
 
118
 
119
  fig = px.bar(
120
+ df_filtered.groupby(['subfolder', 'cause']).size().reset_index(name='Count'),
121
+ x='subfolder',
122
  y='Count',
123
+ color='cause',
124
  barmode='group',
125
+ labels={'Count': 'Occurrences', 'subfolder': '', 'cause': 'Cause'},
126
  color_discrete_sequence=px.colors.qualitative.D3
127
  )
128
+ fig.update_layout(xaxis=dict(showline=True), yaxis=dict(showticklabels=True, title=''))
129
  fig.update_traces(
130
  texttemplate='%{y}',
131
  textposition='inside',
132
+ textfont=dict(color='rgb(255, 255, 255)')
133
+ )
134
+ return fig
135
+
136
+ def scatter(self, include_modality=False):
137
+ # Use self.df to avoid reloading data
138
+ df_filtered = self.df[(self.df[self.indicator_columns].sum(axis=1) > 0) |
139
+ (self.df[self.cause_columns].sum(axis=1) > 0)]
140
+
141
+ # Exclude specific indicators and filter based on count threshold
142
+ indicator_columns = [col for col in self.indicator_columns if 'indicator_!besprechen' not in col]
143
+ indicator_counts = df_filtered[indicator_columns].sum()
144
+ indicators_to_keep = indicator_counts[indicator_counts >= 10].index.tolist()
145
+ df_filtered = df_filtered[df_filtered[indicators_to_keep].sum(axis=1) > 0]
146
+
147
+ # Exclude non-feature columns for dimensionality reduction
148
+ columns_to_drop = ['subfolder', 'text_id', 'sentence_id', 'text_date', 'text_source', 'text_text_type']
149
+ if not include_modality:
150
+ columns_to_drop += [col for col in self.df.columns if col.startswith('modality_')]
151
+
152
+ features = df_filtered.drop(columns=columns_to_drop, errors='ignore').select_dtypes(include=[float, int])
153
+ features_clean = features.fillna(0)
154
+
155
+ # Prepare metadata for plotting
156
+ metadata = df_filtered[['subfolder']].copy()
157
+ metadata['indicator'] = df_filtered[indicators_to_keep].apply(
158
+ lambda row: ', '.join([indicator.replace('indicator_', '') for indicator in indicators_to_keep if row[indicator] > 0]),
159
+ axis=1
160
+ )
161
+ metadata['cause'] = df_filtered[self.cause_columns].apply(
162
+ lambda row: ', '.join([cause.replace('cause_', '') for cause in self.cause_columns if row[cause] > 0]),
163
+ axis=1
164
+ )
165
+
166
+ # Perform UMAP dimensionality reduction
167
+ reducer = umap.UMAP(n_components=2, random_state=42, n_neighbors=50, n_jobs=1, metric='cosine')
168
+ reduced_features = reducer.fit_transform(features_clean)
169
+ df_reduced = pd.DataFrame(reduced_features, columns=['Component 1', 'Component 2'])
170
+ df_reduced = pd.concat([df_reduced, metadata.reset_index(drop=True)], axis=1)
171
+
172
+ # Plotting the scatter plot
173
+ hover_data = {'cause': True, 'Component 1': False, 'Component 2': False}
174
+ if include_modality:
175
+ hover_data['Modality'] = True
176
+
177
+ fig = px.scatter(
178
+ df_reduced,
179
+ x='Component 1',
180
+ y='Component 2',
181
+ color='subfolder',
182
+ symbol='indicator',
183
+ labels={'subfolder': 'Effect'},
184
+ hover_data=hover_data,
185
+ color_discrete_sequence=px.colors.qualitative.D3
186
  )
187
 
188
+ fig.update_layout(
189
+ xaxis=dict(showgrid=True),
190
+ yaxis=dict(showgrid=True),
191
+ showlegend=True,
192
+ legend=dict(title="Effect, Indicator", yanchor="top", xanchor="left", borderwidth=1),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
  )
 
194
 
195
+ return fig
196
+
197
+ def sankey(self, cause_threshold=10, indicator_threshold=5, link_opacity=0.4):
198
+ # Use self.df to avoid reloading data
199
+ df_filtered = self.df[(self.df[self.cause_columns].sum(axis=1) > 0) &
200
+ (self.df[self.indicator_columns].sum(axis=1) > 0)]
201
+
202
+ # Melt causes and indicators separately, ensuring unique sentence IDs
203
+ cause_data = df_filtered[['text_id', 'subfolder'] + self.cause_columns].melt(
204
+ id_vars=['text_id', 'subfolder'], var_name='cause', value_name='count'
205
+ ).query("count > 0").drop_duplicates(['text_id', 'cause'])
206
+
207
+ indicator_data = df_filtered[['text_id', 'subfolder'] + self.indicator_columns].melt(
208
+ id_vars=['text_id', 'subfolder'], var_name='indicator', value_name='count'
209
+ ).query("count > 0").drop_duplicates(['text_id', 'indicator'])
210
+
211
+ # Apply threshold filters
212
+ valid_causes = cause_data['cause'].value_counts()[lambda x: x >= cause_threshold].index
213
+ valid_indicators = indicator_data['indicator'].value_counts()[lambda x: x >= indicator_threshold].index
214
+ cause_data = cause_data[cause_data['cause'].isin(valid_causes)]
215
+ indicator_data = indicator_data[indicator_data['indicator'].isin(valid_indicators)]
216
+
217
+ # Create unique cause-indicator-subfolder links by merging cause and indicator data on 'text_id' and 'subfolder'
218
+ cause_indicator_links = (
219
+ cause_data.merge(indicator_data, on=['text_id', 'subfolder'])
220
+ .groupby(['cause', 'indicator']).size().reset_index(name='count')
221
+ )
222
+
223
+ # Aggregate indicator-subfolder counts
224
+ indicator_subfolder_links = (
225
+ indicator_data.groupby(['indicator', 'subfolder']).size().reset_index(name='count')
226
+ )
227
+
228
+ # Define unique labels and their order
229
+ all_labels = list(valid_causes) + list(valid_indicators) + self.df['subfolder'].unique().tolist()
230
+
231
+ # Remove prefixes for cleaner labels
232
+ all_labels_cleaned = [label.replace("cause_", "").replace("indicator_", "") for label in all_labels]
233
+ label_to_index = {label: idx for idx, label in enumerate(all_labels)}
234
+
235
+ # Define a color palette from Plotly's D3 color sequence
236
+ color_palette = px.colors.qualitative.D3
237
+ node_colors = [color_palette[i % len(color_palette)] for i in range(len(all_labels))]
238
+
239
+ # Define sources, targets, values, and link colors with RGBA opacity
240
+ sources, targets, values, link_colors = [], [], [], []
241
+
242
+ def hex_to_rgba(hex_color, opacity):
243
+ return f'rgba({int(hex_color[1:3], 16)}, {int(hex_color[3:5], 16)}, {int(hex_color[5:], 16)}, {opacity})'
244
+
245
+ # Cause -> Indicator links
246
+ for _, row in cause_indicator_links.iterrows():
247
+ if row['cause'] in label_to_index and row['indicator'] in label_to_index:
248
+ source_idx = label_to_index[row['cause']]
249
+ target_idx = label_to_index[row['indicator']]
250
+ sources.append(source_idx)
251
+ targets.append(target_idx)
252
+ values.append(row['count'])
253
+ link_colors.append(hex_to_rgba(node_colors[source_idx], link_opacity))
254
+
255
+ # Indicator -> Subfolder links
256
+ for _, row in indicator_subfolder_links.iterrows():
257
+ if row['indicator'] in label_to_index and row['subfolder'] in label_to_index:
258
+ source_idx = label_to_index[row['indicator']]
259
+ target_idx = label_to_index[row['subfolder']]
260
+ sources.append(source_idx)
261
+ targets.append(target_idx)
262
+ values.append(row['count'])
263
+ link_colors.append(hex_to_rgba(node_colors[source_idx], link_opacity))
264
+
265
+ fig = go.Figure(data=[go.Sankey(
266
+ node=dict(
267
+ pad=15,
268
+ thickness=20,
269
+ line=dict(color="black", width=0.5),
270
+ label=all_labels_cleaned,
271
+ color=node_colors
272
+ ),
273
+ link=dict(
274
+ source=sources,
275
+ target=targets,
276
+ value=values,
277
+ color=link_colors
278
+ )
279
+ )])
280
+
281
+ fig.update_layout(
282
+ autosize=False,
283
+ width=800,
284
+ height=600,
285
+ font=dict(size=10)
286
+ )
287
 
288
+ return fig