Added plot.py
Browse files- app.py +65 -29
- data/feature_matrix.tsv +0 -0
- plot.py +265 -238
app.py
CHANGED
@@ -2,14 +2,17 @@ import streamlit as st
|
|
2 |
import torch
|
3 |
from transformers import AutoTokenizer, AutoModelForTokenClassification
|
4 |
from annotated_text import annotated_text
|
5 |
-
import pandas as pd
|
6 |
-
import plotly.express as px
|
7 |
-
from plot import indicator_chart, causes_chart, scatter, sankey
|
8 |
import os
|
|
|
9 |
|
10 |
# Define initial threshold values at the top of the script
|
11 |
default_cause_threshold = 20
|
12 |
default_indicator_threshold = 15
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
# Load the trained model and tokenizer
|
15 |
model_directory = "norygano/causalBERT"
|
@@ -30,16 +33,16 @@ st.markdown(
|
|
30 |
""",
|
31 |
unsafe_allow_html=True
|
32 |
)
|
33 |
-
st.markdown("[
|
34 |
-
st.write("
|
35 |
|
36 |
# Create tabs
|
37 |
-
tab1, tab2, tab3, tab4, tab5 = st.tabs(["
|
38 |
|
39 |
# Prompt Tab
|
40 |
with tab1:
|
41 |
sentences_input = st.text_area("*Sentences (one per line)*", "\n".join([
|
42 |
-
"Autos stehen im Verdacht, Waldsterben zu
|
43 |
"Fußball führt zu Waldschäden.",
|
44 |
"Haustüren tragen zum Betonsterben bei.",
|
45 |
]), placeholder="German only (currently)")
|
@@ -82,36 +85,69 @@ with tab1:
|
|
82 |
annotated_text(*annotations)
|
83 |
st.write("---")
|
84 |
|
85 |
-
|
|
|
86 |
with tab2:
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
with tab3:
|
98 |
-
|
99 |
-
st.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
|
|
|
101 |
with tab4:
|
102 |
-
fig_scatter = scatter()
|
103 |
st.plotly_chart(fig_scatter, use_container_width=True)
|
104 |
|
|
|
105 |
with tab5:
|
106 |
-
# Fixed height for the Sankey chart container
|
107 |
with st.container():
|
108 |
-
#
|
109 |
-
|
110 |
-
|
111 |
-
|
|
|
|
|
112 |
st.plotly_chart(fig_sankey, use_container_width=True)
|
113 |
-
|
114 |
-
# Place sliders below the chart container
|
115 |
with st.container():
|
116 |
-
|
117 |
-
|
|
|
|
|
|
|
|
|
|
2 |
import torch
|
3 |
from transformers import AutoTokenizer, AutoModelForTokenClassification
|
4 |
from annotated_text import annotated_text
|
|
|
|
|
|
|
5 |
import os
|
6 |
+
from plot import Plot  # Plot class lives in plot.py
|
7 |
|
8 |
# Define initial threshold values at the top of the script
|
9 |
default_cause_threshold = 20
|
10 |
default_indicator_threshold = 15
|
11 |
+
default_cause_threshold_sankey = 20
|
12 |
+
default_indicator_threshold_sankey = 15
|
13 |
+
|
14 |
+
# Initialize Plots
|
15 |
+
plot = Plot()
|
16 |
|
17 |
# Load the trained model and tokenizer
|
18 |
model_directory = "norygano/causalBERT"
|
|
|
33 |
""",
|
34 |
unsafe_allow_html=True
|
35 |
)
|
36 |
+
st.markdown("[Weights](https://huggingface.co/norygano/causalBERT) | [Data](https://huggingface.co/datasets/norygano/causenv) | [Project](https://www.uni-trier.de/universitaet/fachbereiche-faecher/fachbereich-ii/faecher/germanistik/professurenfachteile/germanistische-linguistik/professoren/prof-dr-martin-wengeler/kontroverse-diskurse/individium-gesellschaft)")
|
37 |
+
st.write("Indicators and causes in explicit attributions of causality.")
|
38 |
|
39 |
# Create tabs
|
40 |
+
tab1, tab2, tab3, tab4, tab5 = st.tabs(["Model", "Indicators", "Causes", "Scatter", "Sankey"])
|
41 |
|
42 |
# Prompt Tab
|
43 |
with tab1:
|
44 |
sentences_input = st.text_area("*Sentences (one per line)*", "\n".join([
|
45 |
+
"Autos stehen im Verdacht, Waldsterben verursacht zu haben.",
|
46 |
"Fußball führt zu Waldschäden.",
|
47 |
"Haustüren tragen zum Betonsterben bei.",
|
48 |
]), placeholder="German only (currently)")
|
|
|
85 |
annotated_text(*annotations)
|
86 |
st.write("---")
|
87 |
|
88 |
+
|
89 |
+
# Indicator Tab
with tab2:
    selected_chart_type = st.radio(
        "Type",
        options=['Total', 'Year', 'Individual'],
        horizontal=True,
    )
    wants_threshold = selected_chart_type == 'Individual'

    # The chart container is created first so that the slider, rendered
    # afterwards, ends up underneath it on the page.
    with st.container():
        chart_kwargs = {"chart_type": selected_chart_type.lower()}
        if wants_threshold:
            # The slider below stores its value under this session-state key;
            # until it has been rendered once, fall back to the default.
            individual_threshold = st.session_state.get("individual_threshold", default_indicator_threshold)
            chart_kwargs["individual_threshold"] = individual_threshold
        fig = plot.get_indicator_chart(**chart_kwargs)
        st.plotly_chart(fig, use_container_width=True)

    # Threshold slider is only offered for the 'Individual' chart type
    if wants_threshold:
        with st.container():
            individual_threshold = st.slider(
                "Indicator >=",
                min_value=1,
                max_value=95,
                value=default_indicator_threshold,
                key="individual_threshold",
            )

# Causes Tab
with tab3:
    with st.container():
        # Read the threshold from session state (written by the slider below)
        causes_min_value = st.session_state.get("cause_threshold_causes", default_cause_threshold)
        fig_causes = plot.get_causes_chart(min_value=causes_min_value)
        st.plotly_chart(fig_causes, use_container_width=True)

    # Slider rendered after the chart, under its own session-state key
    cause_threshold_causes = st.slider(
        "Cause >=",
        min_value=1,
        max_value=75,
        value=default_cause_threshold,
        key="cause_threshold_causes",
    )

# Scatter Tab
with tab4:
    fig_scatter = plot.scatter()
    st.plotly_chart(fig_scatter, use_container_width=True)

# Sankey Tab
with tab5:
    with st.container():
        # Sankey-specific thresholds, taken from session state when the
        # sliders below have already been rendered, otherwise the defaults.
        cause_threshold_sankey = st.session_state.get("cause_threshold_sankey", default_cause_threshold_sankey)
        indicator_threshold_sankey = st.session_state.get("indicator_threshold_sankey", default_indicator_threshold_sankey)

        fig_sankey = plot.sankey(
            cause_threshold=cause_threshold_sankey,
            indicator_threshold=indicator_threshold_sankey,
        )
        st.plotly_chart(fig_sankey, use_container_width=True)

    # Both sliders sit below the chart container and use keys unique to this tab
    with st.container():
        cause_threshold_sankey = st.slider(
            "Cause >=",
            min_value=1,
            max_value=100,
            value=default_cause_threshold_sankey,
            key="cause_threshold_sankey",
        )
        indicator_threshold_sankey = st.slider(
            "Indicator >=",
            min_value=1,
            max_value=100,
            value=default_indicator_threshold_sankey,
            key="indicator_threshold_sankey",
        )
|
data/feature_matrix.tsv
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
plot.py
CHANGED
@@ -3,259 +3,286 @@ import plotly.express as px
|
|
3 |
import plotly.graph_objects as go
|
4 |
import os
|
5 |
import umap
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
|
|
17 |
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
y='Indicator_Share',
|
22 |
-
labels={'Indicator_Share': 'Share of Sentences with Indicators', 'Subfolder': ''},
|
23 |
-
color='Subfolder',
|
24 |
-
text='Indicator_Share_Text',
|
25 |
-
color_discrete_sequence=px.colors.qualitative.D3,
|
26 |
-
custom_data=['Total Sentences', 'Count']
|
27 |
-
)
|
28 |
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
)
|
|
|
41 |
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
df_filtered =
|
46 |
-
|
47 |
-
df_filtered = df_filtered[df_filtered['Indicator'].isin(indicators_meeting_threshold)]
|
48 |
-
df_filtered['Indicator'] = df_filtered['Indicator'].str.capitalize()
|
49 |
|
50 |
fig = px.bar(
|
51 |
-
df_filtered,
|
52 |
-
x='
|
53 |
y='Count',
|
54 |
-
color='
|
55 |
barmode='group',
|
56 |
-
labels={'Count': 'Occurrences', '
|
57 |
color_discrete_sequence=px.colors.qualitative.D3
|
58 |
)
|
59 |
-
|
60 |
fig.update_traces(
|
61 |
texttemplate='%{y}',
|
62 |
textposition='inside',
|
63 |
-
textfont=dict(color='rgb(255, 255, 255)')
|
64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
)
|
66 |
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
)
|
73 |
-
|
74 |
-
return fig
|
75 |
-
|
76 |
-
def causes_chart():
|
77 |
-
data_file = os.path.join('data', 'indicator_cause_sentence_metadata.tsv')
|
78 |
-
df = pd.read_csv(data_file, sep='\t')
|
79 |
-
|
80 |
-
# Threshold
|
81 |
-
min_value = 30
|
82 |
-
df_filtered = df[df['cause'] != 'N/A'].copy()
|
83 |
-
causes_meeting_threshold = df_filtered.groupby('cause')['cause'].count()[lambda x: x >= min_value].index
|
84 |
-
df_filtered = df_filtered[df_filtered['cause'].isin(causes_meeting_threshold)]
|
85 |
-
df_filtered['cause'] = df_filtered['cause'].str.capitalize()
|
86 |
-
|
87 |
-
fig = px.bar(
|
88 |
-
df_filtered.groupby(['subfolder', 'cause']).size().reset_index(name='Count'),
|
89 |
-
x='subfolder',
|
90 |
-
y='Count',
|
91 |
-
color='cause',
|
92 |
-
barmode='group',
|
93 |
-
labels={'Count': 'Occurrences', 'subfolder': '', 'cause': '<b>CAUSE</b>'},
|
94 |
-
color_discrete_sequence=px.colors.qualitative.D3,
|
95 |
-
)
|
96 |
-
|
97 |
-
fig.update_layout(
|
98 |
-
xaxis=dict(showline=True),
|
99 |
-
yaxis=dict(showticklabels=True, title=''),
|
100 |
-
|
101 |
-
)
|
102 |
-
|
103 |
-
fig.update_traces(
|
104 |
-
texttemplate='%{y}',
|
105 |
-
textposition='inside',
|
106 |
-
textfont=dict(color='rgb(255, 255, 255)'),
|
107 |
-
insidetextanchor='middle',
|
108 |
-
)
|
109 |
-
|
110 |
-
return fig
|
111 |
-
|
112 |
-
def scatter(include_modality=False):
|
113 |
-
data_file = os.path.join('data', 'feature_matrix.tsv')
|
114 |
-
df = pd.read_csv(data_file, sep='\t')
|
115 |
-
|
116 |
-
# Exclude sentences without any indicators, causes, or modalities (if included)
|
117 |
-
indicator_columns = [col for col in df.columns if col.startswith('indicator_')]
|
118 |
-
cause_columns = [col for col in df.columns if col.startswith('cause_')]
|
119 |
-
modality_columns = [col for col in df.columns if col.startswith('modality_')]
|
120 |
-
|
121 |
-
df_filtered = df[(df[indicator_columns].sum(axis=1) > 0) |
|
122 |
-
(df[cause_columns].sum(axis=1) > 0)]
|
123 |
-
|
124 |
-
# Exclude indicator '!besprechen'
|
125 |
-
indicator_columns = [col for col in indicator_columns if 'indicator_!besprechen' not in col]
|
126 |
-
|
127 |
-
# Limit indicators to those that occur at least 10 times
|
128 |
-
indicator_counts = df_filtered[indicator_columns].sum()
|
129 |
-
indicators_to_keep = indicator_counts[indicator_counts >= 10].index.tolist()
|
130 |
-
|
131 |
-
# Further filter to exclude entries without any valid indicators
|
132 |
-
df_filtered = df_filtered[df_filtered[indicators_to_keep].sum(axis=1) > 0]
|
133 |
-
|
134 |
-
# Exclude non-feature columns for dimensionality reduction
|
135 |
-
columns_to_drop = ['subfolder']
|
136 |
-
if not include_modality:
|
137 |
-
columns_to_drop += modality_columns # Drop modality columns if not included
|
138 |
-
|
139 |
-
features = df_filtered.drop(columns=columns_to_drop)
|
140 |
-
features_clean = features.fillna(0)
|
141 |
-
|
142 |
-
# Prepare metadata
|
143 |
-
metadata = df_filtered[['subfolder']].copy()
|
144 |
-
metadata['indicator'] = df_filtered[indicators_to_keep].apply(lambda row: ', '.join([indicator.replace('indicator_', '') for indicator in indicators_to_keep if row[indicator] > 0]), axis=1)
|
145 |
-
metadata['cause'] = df_filtered[cause_columns].apply(lambda row: ', '.join([cause.replace('cause_', '') for cause in cause_columns if row[cause] > 0]), axis=1)
|
146 |
-
|
147 |
-
# UMAP dimensionality reduction
|
148 |
-
reducer = umap.UMAP(n_components=2, random_state=42, n_neighbors=50, n_jobs=1, metric='cosine')
|
149 |
-
reduced_features = reducer.fit_transform(features_clean)
|
150 |
-
df_reduced = pd.DataFrame(reduced_features, columns=['Component 1', 'Component 2'])
|
151 |
-
df_reduced = pd.concat([df_reduced, metadata.reset_index(drop=True)], axis=1)
|
152 |
-
|
153 |
-
# Plotting the scatter plot
|
154 |
-
hover_data = {'cause': True, 'Component 1': False, 'Component 2': False}
|
155 |
-
if include_modality:
|
156 |
-
hover_data['Modality'] = True
|
157 |
-
|
158 |
-
custom_labels = {
|
159 |
-
'subfolder': 'Effect', # Renaming 'subfolder' to 'Category'
|
160 |
-
}
|
161 |
-
|
162 |
-
fig = px.scatter(
|
163 |
-
df_reduced,
|
164 |
-
x='Component 1',
|
165 |
-
y='Component 2',
|
166 |
-
color='subfolder', # Only subfolder colors will show in the legend
|
167 |
-
symbol='indicator', # Symbols for indicators, without showing in legend
|
168 |
-
labels=custom_labels,
|
169 |
-
hover_data=hover_data,
|
170 |
-
color_discrete_sequence=px.colors.qualitative.D3
|
171 |
-
)
|
172 |
-
|
173 |
-
fig.update_layout(
|
174 |
-
xaxis=dict(showgrid=True),
|
175 |
-
yaxis=dict(showgrid=True),
|
176 |
-
showlegend=True, # Show only the subfolder legend
|
177 |
-
legend=dict(
|
178 |
-
title="Effect, Indicator", # Adjust title to indicate the subfolder legend
|
179 |
-
yanchor="top",
|
180 |
-
xanchor="left",
|
181 |
-
borderwidth=1,
|
182 |
-
),
|
183 |
-
)
|
184 |
-
|
185 |
-
return fig
|
186 |
-
|
187 |
-
def sankey(cause_threshold=10, indicator_threshold=5):
|
188 |
-
# Load the data
|
189 |
-
data_file = os.path.join('data', 'indicator_cause_sentence_metadata.tsv')
|
190 |
-
df = pd.read_csv(data_file, sep='\t')
|
191 |
-
|
192 |
-
# Remove rows with NaN values in 'cause', 'indicator', or 'subfolder' columns
|
193 |
-
df = df.dropna(subset=['cause', 'indicator', 'subfolder'])
|
194 |
-
|
195 |
-
# Strip '_nk' from 'subfolder' values
|
196 |
-
df['subfolder'] = df['subfolder'].str.replace('_nk', '')
|
197 |
-
|
198 |
-
# Calculate overall counts for each cause and indicator
|
199 |
-
cause_counts = df['cause'].value_counts()
|
200 |
-
indicator_counts = df['indicator'].value_counts()
|
201 |
-
|
202 |
-
# Filter causes and indicators that meet their respective thresholds
|
203 |
-
valid_causes = cause_counts[cause_counts >= cause_threshold].index
|
204 |
-
valid_indicators = indicator_counts[indicator_counts >= indicator_threshold].index
|
205 |
-
|
206 |
-
# Filter the DataFrame to include only rows with causes and indicators that meet the thresholds
|
207 |
-
df_filtered = df[(df['cause'].isin(valid_causes)) & (df['indicator'].isin(valid_indicators))]
|
208 |
-
|
209 |
-
# Calculate pair counts for cause -> indicator and indicator -> subfolder
|
210 |
-
cause_indicator_counts = df_filtered.groupby(['cause', 'indicator']).size().reset_index(name='count')
|
211 |
-
indicator_subfolder_counts = df_filtered.groupby(['indicator', 'subfolder']).size().reset_index(name='count')
|
212 |
-
|
213 |
-
# Generate unique labels for Sankey nodes, including all causes, indicators, and subfolders
|
214 |
-
causes = df_filtered['cause'].unique()
|
215 |
-
indicators = df_filtered['indicator'].unique()
|
216 |
-
subfolders = df_filtered['subfolder'].unique()
|
217 |
-
all_labels = list(causes) + list(indicators) + list(subfolders)
|
218 |
-
|
219 |
-
# Mapping of each label to an index for Sankey node
|
220 |
-
label_to_index = {label: idx for idx, label in enumerate(all_labels)}
|
221 |
-
|
222 |
-
# Define sources, targets, and values for the Sankey diagram
|
223 |
-
sources = []
|
224 |
-
targets = []
|
225 |
-
values = []
|
226 |
-
|
227 |
-
# Add cause -> indicator links
|
228 |
-
for _, row in cause_indicator_counts.iterrows():
|
229 |
-
if row['cause'] in label_to_index and row['indicator'] in label_to_index:
|
230 |
-
sources.append(label_to_index[row['cause']])
|
231 |
-
targets.append(label_to_index[row['indicator']])
|
232 |
-
values.append(row['count'])
|
233 |
-
|
234 |
-
# Add indicator -> subfolder links
|
235 |
-
for _, row in indicator_subfolder_counts.iterrows():
|
236 |
-
if row['indicator'] in label_to_index and row['subfolder'] in label_to_index:
|
237 |
-
sources.append(label_to_index[row['indicator']])
|
238 |
-
targets.append(label_to_index[row['subfolder']])
|
239 |
-
values.append(row['count'])
|
240 |
-
|
241 |
-
fig = go.Figure(data=[go.Sankey(
|
242 |
-
node=dict(
|
243 |
-
pad=15,
|
244 |
-
thickness=20,
|
245 |
-
line=dict(color="black", width=0.5),
|
246 |
-
label=all_labels,
|
247 |
-
),
|
248 |
-
link=dict(
|
249 |
-
source=sources,
|
250 |
-
target=targets,
|
251 |
-
value=values
|
252 |
)
|
253 |
-
)])
|
254 |
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
260 |
|
261 |
-
|
|
|
3 |
import plotly.graph_objects as go
|
4 |
import os
|
5 |
import umap
|
6 |
+
import streamlit as st
|
7 |
+
|
8 |
+
@st.cache_data
def load_data(file_path):
    """Read the tab-separated file at ``file_path`` into a DataFrame.

    Wrapped in ``st.cache_data`` so Streamlit can reuse the parsed result
    instead of re-reading the file each time the function is called.
    """
    table = pd.read_csv(file_path, sep='\t')
    return table
|
11 |
+
|
12 |
+
class Plot:
    """Build the Plotly figures used by the app from two TSV tables.

    The feature matrix (``data_file``) is read for its ``indicator_*`` and
    ``cause_*`` columns together with ``subfolder``, ``text_date`` and
    ``text_id``. The metadata table (``metadata_file``) is read for ``cause``
    and ``subfolder``. Both are loaded through the module-level ``load_data``.
    """

    # Values accepted by get_indicator_chart(chart_type=...).
    _CHART_TYPES = ('total', 'individual', 'year')

    def __init__(self, data_file='data/feature_matrix.tsv', metadata_file='data/indicator_cause_sentence_metadata.tsv'):
        """Load both tables and precompute what the chart methods share.

        Args:
            data_file: Path of the tab-separated feature matrix.
            metadata_file: Path of the tab-separated metadata table.
        """
        self.data_file = data_file
        self.metadata_file = metadata_file
        self.df = load_data(self.data_file)  # Cached data loading
        self.metadata_df = load_data(self.metadata_file)

        # Column groups are identified purely by their name prefix.
        self.indicator_columns = [col for col in self.df.columns if col.startswith('indicator_')]
        self.cause_columns = [col for col in self.df.columns if col.startswith('cause_')]

        # 'Year' is the first four characters of text_date
        # (presumably a YYYY-prefixed date string; verify against the data).
        self.df['Year'] = self.df['text_date'].astype(str).str[:4]
        self.df['Has_Indicator'] = self.df[self.indicator_columns].sum(axis=1) > 0

        # Row totals reused as denominators by the indicator charts.
        self.total_sentences_per_year = self.df.groupby(['Year', 'subfolder']).size().reset_index(name='Total Sentences')
        self.total_sentences_per_subfolder = self.df.groupby('subfolder').size().reset_index(name='Total Sentences')

    @staticmethod
    def _hex_to_rgba(hex_color, opacity):
        """Convert a ``#RRGGBB`` colour into an ``rgba(r, g, b, opacity)`` string."""
        red = int(hex_color[1:3], 16)
        green = int(hex_color[3:5], 16)
        blue = int(hex_color[5:], 16)
        return f'rgba({red}, {green}, {blue}, {opacity})'

    @staticmethod
    def _active_labels(row, columns, prefix):
        """Join the names of ``columns`` whose value in ``row`` is > 0, with ``prefix`` removed."""
        return ', '.join([col.replace(prefix, '') for col in columns if row[col] > 0])

    @staticmethod
    def _melt_active(df, value_columns, var_name):
        """Melt ``value_columns`` to long form, keeping one positive row per (text_id, variable)."""
        melted = df[['text_id', 'subfolder'] + value_columns].melt(
            id_vars=['text_id', 'subfolder'], var_name=var_name, value_name='count'
        )
        return melted.query("count > 0").drop_duplicates(['text_id', var_name])

    def get_indicator_chart(self, chart_type='total', individual_threshold=5):
        """Return a bar chart describing indicator usage.

        Args:
            chart_type: ``'total'`` (share of rows with an indicator per
                subfolder), ``'individual'`` (occurrences per indicator and
                subfolder) or ``'year'`` (row totals per year, labelled with
                the indicator share).
            individual_threshold: For ``'individual'`` only, the minimum number
                of rows an indicator must appear in, across all subfolders, to
                be shown.

        Returns:
            A Plotly figure.

        Raises:
            ValueError: If ``chart_type`` is not one of the supported values.
        """
        # Fail early with a clear message; otherwise `fig` would be unbound below.
        if chart_type not in self._CHART_TYPES:
            raise ValueError(
                f"Unknown chart_type {chart_type!r}; expected one of {self._CHART_TYPES}"
            )

        if chart_type == 'total':
            # Summarize indicator share per subfolder
            indicator_counts = self.df[self.df['Has_Indicator']].groupby('subfolder').size().reset_index(name='Indicator Count')
            total_counts = indicator_counts.merge(self.total_sentences_per_subfolder, on='subfolder')
            total_counts['Indicator_Share'] = total_counts['Indicator Count'] / total_counts['Total Sentences']
            total_counts['Indicator_Share_Text'] = (total_counts['Indicator_Share'] * 100).round(2).astype(str) + '%'

            fig = px.bar(
                total_counts,
                x='subfolder',
                y='Indicator_Share',
                labels={'Indicator_Share': 'Share of Sentences with Indicators', 'subfolder': ''},
                color='subfolder',
                text='Indicator_Share_Text',
                color_discrete_sequence=px.colors.qualitative.D3
            )
            fig.update_traces(
                textposition='inside',
                texttemplate='%{text}',
                textfont=dict(color='rgb(255, 255, 255)')
            )

        elif chart_type == 'individual':
            # Melt the dataframe to long format and keep only positive cells
            df_melted = self.df.melt(id_vars=['subfolder'], value_vars=self.indicator_columns, var_name='Indicator', value_name='Count')
            df_melted = df_melted[df_melted['Count'] > 0]

            # Threshold is applied to the number of rows per indicator across all subfolders
            total_indicator_counts = df_melted.groupby('Indicator').size().reset_index(name='Total Count')
            indicators_meeting_threshold = total_indicator_counts[total_indicator_counts['Total Count'] >= individual_threshold]['Indicator'].unique()

            # .copy() so the column assignment below works on an independent
            # frame instead of a slice (avoids SettingWithCopyWarning).
            df_melted = df_melted[df_melted['Indicator'].isin(indicators_meeting_threshold)].copy()
            # regex=False: the prefix is a literal, independent of the pandas default.
            df_melted['Indicator'] = df_melted['Indicator'].str.replace('indicator_', '', regex=False).str.capitalize()

            # Re-aggregate counts by subfolder and indicator for the filtered indicators
            df_melted = df_melted.groupby(['subfolder', 'Indicator']).size().reset_index(name='Count')

            fig = px.bar(
                df_melted,
                x='subfolder',
                y='Count',
                color='Indicator',
                barmode='group',
                labels={'Count': 'Occurrences', 'subfolder': '', 'Indicator': 'Indicator'},
                color_discrete_sequence=px.colors.qualitative.D3
            )
            fig.update_traces(
                texttemplate='%{y}',
                textposition='inside',
                textfont=dict(color='rgb(255, 255, 255)')
            )

        else:  # chart_type == 'year'
            indicator_counts_per_year = self.df[self.df['Has_Indicator']].groupby(['Year', 'subfolder']).size().reset_index(name='Indicator Count')
            df_summary = pd.merge(self.total_sentences_per_year, indicator_counts_per_year, on=['Year', 'subfolder'], how='left')
            # The left merge leaves NaN for year/subfolder groups without any
            # indicator row; treat those as 0 so the label is not 'nan%'.
            df_summary['Indicator Count'] = df_summary['Indicator Count'].fillna(0)
            df_summary['Indicator_Share_Text'] = (df_summary['Indicator Count'] / df_summary['Total Sentences'] * 100).round(2).astype(str) + '%'

            fig = px.bar(
                df_summary,
                x='Year',
                y='Total Sentences',
                color='subfolder',
                labels={'Total Sentences': 'Total Number of Sentences', 'Year': 'Year'},
                text='Indicator_Share_Text',
                color_discrete_sequence=px.colors.qualitative.D3
            )
            fig.update_traces(
                textposition='inside',
                texttemplate='%{text}',
                textfont=dict(color='rgb(255, 255, 255)')
            )

        fig.update_layout(
            xaxis=dict(showline=True),
            yaxis=dict(title='Indicator Sentences' if chart_type != 'year' else 'Total Sentences'),
            bargap=0.05,
            # With one bar per subfolder the legend would only repeat the x axis.
            showlegend=(chart_type != 'total')
        )
        return fig

    def get_causes_chart(self, min_value=30):
        """Return a grouped bar chart of cause occurrences per subfolder.

        Args:
            min_value: Minimum number of metadata rows a cause must have to be
                included.

        Returns:
            A Plotly figure.
        """
        # NOTE(review): pandas.read_csv parses the literal 'N/A' as NaN by
        # default, so this comparison may never match; rows with a NaN cause
        # are removed by the isin filter below either way - confirm intent.
        df_filtered = self.metadata_df[self.metadata_df['cause'] != 'N/A']
        causes_meeting_threshold = df_filtered.groupby('cause')['cause'].count()[lambda x: x >= min_value].index
        # .copy() so the capitalisation below does not write into a slice of
        # self.metadata_df (avoids SettingWithCopyWarning).
        df_filtered = df_filtered[df_filtered['cause'].isin(causes_meeting_threshold)].copy()
        df_filtered['cause'] = df_filtered['cause'].str.capitalize()

        fig = px.bar(
            df_filtered.groupby(['subfolder', 'cause']).size().reset_index(name='Count'),
            x='subfolder',
            y='Count',
            color='cause',
            barmode='group',
            labels={'Count': 'Occurrences', 'subfolder': '', 'cause': 'Cause'},
            color_discrete_sequence=px.colors.qualitative.D3
        )
        fig.update_layout(xaxis=dict(showline=True), yaxis=dict(showticklabels=True, title=''))
        fig.update_traces(
            texttemplate='%{y}',
            textposition='inside',
            textfont=dict(color='rgb(255, 255, 255)')
        )
        return fig

    def scatter(self, include_modality=False):
        """Return a 2-D UMAP scatter plot of the feature matrix rows.

        Rows are kept only if they have at least one indicator that occurs at
        least 10 times among rows having an indicator or a cause; the
        ``indicator_!besprechen`` column is ignored for that selection.

        Args:
            include_modality: When true, ``modality_*`` columns are kept as
                features and the active modalities are shown on hover.

        Returns:
            A Plotly figure coloured by subfolder, with one symbol per
            indicator combination.
        """
        # Use self.df to avoid reloading data
        df_filtered = self.df[(self.df[self.indicator_columns].sum(axis=1) > 0) |
                              (self.df[self.cause_columns].sum(axis=1) > 0)]

        # Exclude specific indicators and filter based on count threshold
        indicator_columns = [col for col in self.indicator_columns if 'indicator_!besprechen' not in col]
        indicator_counts = df_filtered[indicator_columns].sum()
        indicators_to_keep = indicator_counts[indicator_counts >= 10].index.tolist()
        df_filtered = df_filtered[df_filtered[indicators_to_keep].sum(axis=1) > 0]

        # Exclude non-feature columns for dimensionality reduction
        modality_columns = [col for col in self.df.columns if col.startswith('modality_')]
        columns_to_drop = ['subfolder', 'text_id', 'sentence_id', 'text_date', 'text_source', 'text_text_type']
        if not include_modality:
            columns_to_drop += modality_columns

        # 'number' selects every numeric dtype; the builtin `int` maps to a
        # platform-dependent width and can miss int64 columns on some systems.
        # Bool and string columns (Has_Indicator, Year) remain excluded.
        features = df_filtered.drop(columns=columns_to_drop, errors='ignore').select_dtypes(include='number')
        features_clean = features.fillna(0)

        # Prepare metadata for plotting
        metadata = df_filtered[['subfolder']].copy()
        metadata['indicator'] = df_filtered[indicators_to_keep].apply(
            lambda row: self._active_labels(row, indicators_to_keep, 'indicator_'),
            axis=1
        )
        metadata['cause'] = df_filtered[self.cause_columns].apply(
            lambda row: self._active_labels(row, self.cause_columns, 'cause_'),
            axis=1
        )
        if include_modality:
            # The hover configuration below refers to a 'Modality' column, so
            # it has to exist; build it the same way as indicator and cause.
            # assumes modality_* columns are numeric flags/counts - TODO confirm
            if modality_columns:
                metadata['Modality'] = df_filtered[modality_columns].apply(
                    lambda row: self._active_labels(row, modality_columns, 'modality_'),
                    axis=1
                )
            else:
                metadata['Modality'] = ''

        # Perform UMAP dimensionality reduction
        reducer = umap.UMAP(n_components=2, random_state=42, n_neighbors=50, n_jobs=1, metric='cosine')
        reduced_features = reducer.fit_transform(features_clean)
        df_reduced = pd.DataFrame(reduced_features, columns=['Component 1', 'Component 2'])
        # Index is reset so metadata lines up positionally with the UMAP output.
        df_reduced = pd.concat([df_reduced, metadata.reset_index(drop=True)], axis=1)

        # Plotting the scatter plot
        hover_data = {'cause': True, 'Component 1': False, 'Component 2': False}
        if include_modality:
            hover_data['Modality'] = True

        fig = px.scatter(
            df_reduced,
            x='Component 1',
            y='Component 2',
            color='subfolder',
            symbol='indicator',
            labels={'subfolder': 'Effect'},
            hover_data=hover_data,
            color_discrete_sequence=px.colors.qualitative.D3
        )

        fig.update_layout(
            xaxis=dict(showgrid=True),
            yaxis=dict(showgrid=True),
            showlegend=True,
            legend=dict(title="Effect, Indicator", yanchor="top", xanchor="left", borderwidth=1),
        )

        return fig

    def sankey(self, cause_threshold=10, indicator_threshold=5, link_opacity=0.4):
        """Return a Sankey diagram linking causes to indicators to subfolders.

        Only rows with at least one cause and one indicator are used. A cause
        and an indicator are linked when they share ``text_id`` and
        ``subfolder``.

        Args:
            cause_threshold: Minimum number of distinct ``text_id`` values a
                cause must appear in to be shown.
            indicator_threshold: Same minimum, for indicators.
            link_opacity: Alpha value applied to the link colours.

        Returns:
            A Plotly figure.
        """
        # Use self.df to avoid reloading data
        has_cause = self.df[self.cause_columns].sum(axis=1) > 0
        has_indicator = self.df[self.indicator_columns].sum(axis=1) > 0
        df_filtered = self.df[has_cause & has_indicator]

        # Melt causes and indicators separately, one row per (text_id, variable)
        cause_data = self._melt_active(df_filtered, self.cause_columns, 'cause')
        indicator_data = self._melt_active(df_filtered, self.indicator_columns, 'indicator')

        # Apply threshold filters
        valid_causes = cause_data['cause'].value_counts()[lambda x: x >= cause_threshold].index
        valid_indicators = indicator_data['indicator'].value_counts()[lambda x: x >= indicator_threshold].index
        cause_data = cause_data[cause_data['cause'].isin(valid_causes)]
        indicator_data = indicator_data[indicator_data['indicator'].isin(valid_indicators)]

        # Cause -> indicator pairs that occur in the same text and subfolder
        cause_indicator_links = (
            cause_data.merge(indicator_data, on=['text_id', 'subfolder'])
            .groupby(['cause', 'indicator']).size().reset_index(name='count')
        )

        # Aggregate indicator-subfolder counts
        indicator_subfolder_links = (
            indicator_data.groupby(['indicator', 'subfolder']).size().reset_index(name='count')
        )

        # Node order: causes, then indicators, then every subfolder in the data.
        # NaN subfolders are dropped: they cannot be labelled and never appear
        # in the grouped links.
        subfolder_labels = self.df['subfolder'].dropna().unique().tolist()
        all_labels = list(valid_causes) + list(valid_indicators) + subfolder_labels

        # Remove prefixes for cleaner labels; indices still key on the raw names
        all_labels_cleaned = [label.replace("cause_", "").replace("indicator_", "") for label in all_labels]
        label_to_index = {label: idx for idx, label in enumerate(all_labels)}

        # Cycle through Plotly's D3 colour sequence for the nodes
        color_palette = px.colors.qualitative.D3
        node_colors = [color_palette[i % len(color_palette)] for i in range(len(all_labels))]

        # Each link takes the colour of its source node, made translucent
        sources, targets, values, link_colors = [], [], [], []
        link_tables = (
            (cause_indicator_links, 'cause', 'indicator'),
            (indicator_subfolder_links, 'indicator', 'subfolder'),
        )
        for links, source_col, target_col in link_tables:
            for source_label, target_label, count in zip(links[source_col], links[target_col], links['count']):
                if source_label not in label_to_index or target_label not in label_to_index:
                    continue
                source_idx = label_to_index[source_label]
                sources.append(source_idx)
                targets.append(label_to_index[target_label])
                values.append(count)
                link_colors.append(self._hex_to_rgba(node_colors[source_idx], link_opacity))

        fig = go.Figure(data=[go.Sankey(
            node=dict(
                pad=15,
                thickness=20,
                line=dict(color="black", width=0.5),
                label=all_labels_cleaned,
                color=node_colors
            ),
            link=dict(
                source=sources,
                target=targets,
                value=values,
                color=link_colors
            )
        )])

        fig.update_layout(
            autosize=False,
            width=800,
            height=600,
            font=dict(size=10)
        )

        return fig
|