norygano commited on
Commit
45d0933
·
1 Parent(s): 60e75a3
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__/plot.cpython-311.pyc
app.py CHANGED
@@ -2,97 +2,102 @@ import streamlit as st
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForTokenClassification
4
  from annotated_text import annotated_text
 
 
 
 
5
 
6
  # Load the trained model and tokenizer
7
  model_directory = "norygano/causalBERT"
8
  tokenizer = AutoTokenizer.from_pretrained(model_directory, add_prefix_space=True)
9
  model = AutoModelForTokenClassification.from_pretrained(model_directory)
10
-
11
- # Set model to evaluation mode
12
  model.eval()
13
 
14
  # Define the label map
15
  label_map = {0: "O", 1: "B-INDICATOR", 2: "I-INDICATOR", 3: "B-CAUSE", 4: "I-CAUSE", 5: "B-EFFECT", 6: "I-EFFECT"}
16
 
17
- # Streamlit App
18
- st.markdown(
19
  """
20
  <div style="display: flex; align-items: center; justify-content: left; font-size: 60px; font-weight: bold;">
21
  <span>CAUSEN</span>
22
  <span style="transform: rotate(270deg); display: inline-block; margin-left: 5px;">V</span>
23
  </div>
24
  """,
25
- unsafe_allow_html=True
26
  )
27
- st.markdown("[Model](https://huggingface.co/norygano/causalBERT)")
 
28
 
29
- # Add a description with a link to the model
30
- st.write("Tags indicators and causes of explicit attributions of causality. GER only (atm)")
31
 
32
- # Text input for sentences with italic placeholder text
33
- sentences_input = st.text_area("*Sentences (one per line)*", "\n".join([
34
- "Autos stehen im Verdacht, Waldsterben zu verursachen.",
35
- "Fußball führt zu Waldschäden.",
36
- "Haustüren tragen zum Betonsterben bei.",
37
- ])
38
- , placeholder="Your Sentences here.")
39
 
40
- # Split the input text into individual sentences
41
- sentences = [sentence.strip() for sentence in sentences_input.splitlines() if sentence.strip()]
42
 
43
- # Button to run the model
44
- if st.button("Analyze"):
45
- for sentence in sentences:
46
- # Tokenize the sentence
47
- inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)
 
 
 
 
48
 
49
- # Run inference
50
- with torch.no_grad():
51
- outputs = model(**inputs)
52
-
53
- # Get the logits and predicted label IDs
54
- logits = outputs.logits
55
- predicted_label_ids = torch.argmax(logits, dim=2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
- # Convert token IDs back to tokens
58
- tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
 
59
 
60
- # Map label IDs to human-readable labels
61
- predicted_labels = [label_map[label_id.item()] for label_id in predicted_label_ids[0]]
 
 
 
 
 
 
 
62
 
63
- # Reconstruct words from subwords and prepare for annotated_text
64
- annotations = []
65
- current_word = ""
66
- current_label = "O"
67
-
68
- for token, label in zip(tokens, predicted_labels):
69
- if token in ['[CLS]', '[SEP]']: # Exclude special tokens
70
- continue
71
-
72
- if token.startswith("##"):
73
- # Append subword without "##" prefix to the current word
74
- current_word += token[2:]
75
- else:
76
- # If we have accumulated a word, add it to annotations with a space
77
- if current_word:
78
- if current_label != "O":
79
- annotations.append((current_word, current_label))
80
- else:
81
- annotations.append(current_word)
82
- annotations.append(" ") # Add a space between words
83
-
84
- # Start a new word
85
- current_word = token
86
- current_label = label
87
-
88
- # Add the last accumulated word
89
- if current_word:
90
- if current_label != "O":
91
- annotations.append((current_word, current_label))
92
- else:
93
- annotations.append(current_word)
94
 
95
- # Display annotated text
96
- st.write(f"**Sentence:** {sentence}")
97
- annotated_text(*annotations)
98
- st.write("---")
 
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForTokenClassification
4
  from annotated_text import annotated_text
5
+ import pandas as pd
6
+ import plotly.express as px
7
+ from plot import indicator_chart, causes_chart, scatter_plot
8
+ import os
9
 
10
  # Load the trained model and tokenizer
11
  model_directory = "norygano/causalBERT"
12
  tokenizer = AutoTokenizer.from_pretrained(model_directory, add_prefix_space=True)
13
  model = AutoModelForTokenClassification.from_pretrained(model_directory)
 
 
14
  model.eval()
15
 
16
  # Define the label map
17
  label_map = {0: "O", 1: "B-INDICATOR", 2: "I-INDICATOR", 3: "B-CAUSE", 4: "I-CAUSE", 5: "B-EFFECT", 6: "I-EFFECT"}
18
 
19
+ # Main application
20
+ st.markdown(
21
  """
22
  <div style="display: flex; align-items: center; justify-content: left; font-size: 60px; font-weight: bold;">
23
  <span>CAUSEN</span>
24
  <span style="transform: rotate(270deg); display: inline-block; margin-left: 5px;">V</span>
25
  </div>
26
  """,
27
+ unsafe_allow_html=True
28
  )
29
+ st.markdown("[Model](https://huggingface.co/norygano/causalBERT) | [Data](https://huggingface.co/datasets/norygano/causenv) | [Project](https://www.uni-trier.de/universitaet/fachbereiche-faecher/fachbereich-ii/faecher/germanistik/professurenfachteile/germanistische-linguistik/professoren/prof-dr-martin-wengeler/kontroverse-diskurse/individium-gesellschaft)")
30
+ st.write("Tags indicators and causes in explicit attributions of causality. GER only (atm)")
31
 
32
+ # Create tabs
33
+ tab1, tab2, tab3, tab4 = st.tabs(["Prompt", "Indicators", "Causes", "Scatter"])
34
 
35
+ # Prompt Tab
36
+ with tab1:
37
+ sentences_input = st.text_area("*Sentences (one per line)*", "\n".join([
38
+ "Autos stehen im Verdacht, Waldsterben zu verursachen.",
39
+ "Fußball führt zu Waldschäden.",
40
+ "Haustüren tragen zum Betonsterben bei.",
41
+ ]), placeholder="Your Sentences here.")
42
 
43
+ sentences = [sentence.strip() for sentence in sentences_input.splitlines() if sentence.strip()]
 
44
 
45
+ if st.button("Analyze"):
46
+ for sentence in sentences:
47
+ inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)
48
+ with torch.no_grad():
49
+ outputs = model(**inputs)
50
+ logits = outputs.logits
51
+ predicted_label_ids = torch.argmax(logits, dim=2)
52
+ tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
53
+ predicted_labels = [label_map[label_id.item()] for label_id in predicted_label_ids[0]]
54
 
55
+ annotations = []
56
+ current_word = ""
57
+ current_label = "O"
58
+ for token, label in zip(tokens, predicted_labels):
59
+ if token in ['[CLS]', '[SEP]']: # Exclude special tokens
60
+ continue
61
+ if token.startswith("##"):
62
+ current_word += token[2:]
63
+ else:
64
+ if current_word:
65
+ if current_label != "O":
66
+ annotations.append((current_word, current_label))
67
+ else:
68
+ annotations.append(current_word)
69
+ annotations.append(" ") # Add a space between words
70
+ current_word = token
71
+ current_label = label
72
+ if current_word:
73
+ if current_label != "O":
74
+ annotations.append((current_word, current_label))
75
+ else:
76
+ annotations.append(current_word)
77
+ st.write(f"**Sentence:** {sentence}")
78
+ annotated_text(*annotations)
79
+ st.write("---")
80
 
81
+ # Research Insights Tab
82
+ with tab2:
83
+ st.write("## Indicators")
84
 
85
+ # Overall
86
+ st.subheader("Overall")
87
+ fig_overall = indicator_chart(chart_type='overall')
88
+ st.plotly_chart(fig_overall, use_container_width=True)
89
+
90
+ # Individual Indicators Chart
91
+ st.subheader("Individual")
92
+ fig_individual = indicator_chart(chart_type='individual')
93
+ st.plotly_chart(fig_individual, use_container_width=True)
94
 
95
+ with tab3:
96
+ st.write("## Causes")
97
+ fig_causes = causes_chart()
98
+ st.plotly_chart(fig_causes, use_container_width=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
+ with tab4:
101
+ st.write("## Scatter")
102
+ fig_scatter = scatter_plot()
103
+ st.plotly_chart(fig_scatter, use_container_width=True)
data/feature_matrix.tsv ADDED
The diff for this file is too large to render. See raw diff
 
data/indicator_cause_sentence_metadata.tsv ADDED
The diff for this file is too large to render. See raw diff
 
data/indicator_overview.tsv ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Subfolder Total Sentences Indicator Count Share
2
+ Waldsterben_nk 828 Total with Indicators 139 16.79%
3
+ Waldsterben_nk 828 None 131 15.82%
4
+ Waldsterben_nk 828 verantwortung 31 3.74%
5
+ Waldsterben_nk 828 ursache 31 3.74%
6
+ Waldsterben_nk 828 schuld 15 1.81%
7
+ Waldsterben_nk 828 beitragen 15 1.81%
8
+ Waldsterben_nk 828 führen 6 0.72%
9
+ Waldsterben_nk 828 !besprechen 6 0.72%
10
+ Waldsterben_nk 828 wirkung 4 0.48%
11
+ Waldsterben_nk 828 folgen 4 0.48%
12
+ Waldsterben_nk 828 beschleunigen 3 0.36%
13
+ Waldsterben_nk 828 zusammenhang 3 0.36%
14
+ Waldsterben_nk 828 faktor 3 0.36%
15
+ Waldsterben_nk 828 durch 2 0.24%
16
+ Waldsterben_nk 828 grund 2 0.24%
17
+ Waldsterben_nk 828 kosten 2 0.24%
18
+ Waldsterben_nk 828 stecken 1 0.12%
19
+ Waldsterben_nk 828 steuern 1 0.12%
20
+ Waldsterben_nk 828 zuständig 1 0.12%
21
+ Waldsterben_nk 828 auslöser 1 0.12%
22
+ Waldsterben_nk 828 sorgen 1 0.12%
23
+ Waldsterben_nk 828 wenn-dann 1 0.12%
24
+ Waldsterben_nk 828 schaden 1 0.12%
25
+ Waldsterben_nk 828 angesichts 1 0.12%
26
+ Waldsterben_nk 828 einfluss 1 0.12%
27
+ Waldsterben_nk 828 stellung 1 0.12%
28
+ Waldsterben_nk 828 zuTun 1 0.12%
29
+ Waldsterben_nk 828 teil 1 0.12%
30
+ Waldsterben_nk 828 rolle 1 0.12%
31
+ Waldsterben_nk 828 bedeuten 1 0.12%
32
+ Waldsterben_nk 828 sünder 1 0.12%
33
+ Bienensterben_nk 281 Total with Indicators 126 44.84%
34
+ Bienensterben_nk 281 None 123 43.77%
35
+ Bienensterben_nk 281 verantwortung 40 14.23%
36
+ Bienensterben_nk 281 ursache 19 6.76%
37
+ Bienensterben_nk 281 grund 14 4.98%
38
+ Bienensterben_nk 281 beitragen 10 3.56%
39
+ Bienensterben_nk 281 schuld 7 2.49%
40
+ Bienensterben_nk 281 teil 6 2.14%
41
+ Bienensterben_nk 281 führen 5 1.78%
42
+ Bienensterben_nk 281 verbindung 4 1.42%
43
+ Bienensterben_nk 281 kommen 3 1.07%
44
+ Bienensterben_nk 281 faktor 3 1.07%
45
+ Bienensterben_nk 281 zuTun 3 1.07%
46
+ Bienensterben_nk 281 folgen 2 0.71%
47
+ Bienensterben_nk 281 auslöser 2 0.71%
48
+ Bienensterben_nk 281 erklärung 2 0.71%
49
+ Bienensterben_nk 281 wirkung 2 0.71%
50
+ Bienensterben_nk 281 einhergehen 1 0.36%
51
+ Bienensterben_nk 281 wegen 1 0.36%
52
+ Bienensterben_nk 281 durch 1 0.36%
53
+ Bienensterben_nk 281 zusammenhang 1 0.36%
54
+ Bienensterben_nk 281 wenn-dann 1 0.36%
55
+ Bienensterben_nk 281 wundern 1 0.36%
56
+ Bienensterben_nk 281 handeln 1 0.36%
57
+ Bienensterben_nk 281 einfluss 1 0.36%
58
+ Bienensterben_nk 281 !besprechen 1 0.36%
59
+ Artensterben_nk 539 Total with Indicators 141 26.16%
60
+ Artensterben_nk 539 None 141 26.16%
61
+ Artensterben_nk 539 ursache 27 5.01%
62
+ Artensterben_nk 539 beitragen 21 3.90%
63
+ Artensterben_nk 539 verantwortung 19 3.53%
64
+ Artensterben_nk 539 grund 15 2.78%
65
+ Artensterben_nk 539 schuld 11 2.04%
66
+ Artensterben_nk 539 führen 11 2.04%
67
+ Artensterben_nk 539 kommen 4 0.74%
68
+ Artensterben_nk 539 zusammenhang 4 0.74%
69
+ Artensterben_nk 539 einfluss 4 0.74%
70
+ Artensterben_nk 539 teil 4 0.74%
71
+ Artensterben_nk 539 faktor 3 0.56%
72
+ Artensterben_nk 539 wirkung 2 0.37%
73
+ Artensterben_nk 539 erklärung 2 0.37%
74
+ Artensterben_nk 539 folgen 2 0.37%
75
+ Artensterben_nk 539 rolle 2 0.37%
76
+ Artensterben_nk 539 auslöser 1 0.19%
77
+ Artensterben_nk 539 erzeugen 1 0.19%
78
+ Artensterben_nk 539 stecken 1 0.19%
79
+ Artensterben_nk 539 sünder 1 0.19%
80
+ Artensterben_nk 539 durch 1 0.19%
81
+ Artensterben_nk 539 bedingen 1 0.19%
82
+ Artensterben_nk 539 zuTun 1 0.19%
83
+ Artensterben_nk 539 fördern 1 0.19%
84
+ Artensterben_nk 539 treiben 1 0.19%
85
+ Insektensterben_nk 253 Total with Indicators 66 26.09%
86
+ Insektensterben_nk 253 None 60 23.72%
87
+ Insektensterben_nk 253 ursache 12 4.74%
88
+ Insektensterben_nk 253 verantwortung 8 3.16%
89
+ Insektensterben_nk 253 grund 7 2.77%
90
+ Insektensterben_nk 253 beitragen 6 2.37%
91
+ Insektensterben_nk 253 !besprechen 5 1.98%
92
+ Insektensterben_nk 253 rolle 4 1.58%
93
+ Insektensterben_nk 253 schuld 3 1.19%
94
+ Insektensterben_nk 253 faktor 3 1.19%
95
+ Insektensterben_nk 253 zusammenhang 2 0.79%
96
+ Insektensterben_nk 253 zuTun 2 0.79%
97
+ Insektensterben_nk 253 teil 2 0.79%
98
+ Insektensterben_nk 253 folgen 1 0.40%
99
+ Insektensterben_nk 253 kosten 1 0.40%
100
+ Insektensterben_nk 253 durch 1 0.40%
101
+ Insektensterben_nk 253 treiben 1 0.40%
102
+ Insektensterben_nk 253 bedeuten 1 0.40%
103
+ Insektensterben_nk 253 relevant 1 0.40%
104
+ Insektensterben_nk 253 einfluss 1 0.40%
105
+ Insektensterben_nk 253 stecken 1 0.40%
plot.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import plotly.express as px
3
+ import os
4
+ import umap
5
+ from sklearn.preprocessing import StandardScaler
6
+
7
+ def indicator_chart(chart_type='overall'):
8
+ data_file = os.path.join('data', 'indicator_overview.tsv')
9
+ df = pd.read_csv(data_file, sep='\t')
10
+
11
+ if chart_type == 'overall':
12
+ df_filtered = df[df['Indicator'] == 'Total with Indicators'].copy()
13
+ total_sentences_per_subfolder = df.groupby('Subfolder')['Total Sentences'].first().to_dict()
14
+ df_filtered['Total Sentences'] = df_filtered['Subfolder'].map(total_sentences_per_subfolder)
15
+ df_filtered['Indicator_Share'] = df_filtered['Count'] / df_filtered['Total Sentences']
16
+ df_filtered['Indicator_Share_Text'] = (df_filtered['Indicator_Share'] * 100).round(2).astype(str) + '%'
17
+
18
+ fig = px.bar(
19
+ df_filtered,
20
+ x='Subfolder',
21
+ y='Indicator_Share',
22
+ labels={'Indicator_Share': 'Share of Sentences with Indicators', 'Subfolder': ''},
23
+ color='Subfolder',
24
+ text='Indicator_Share_Text',
25
+ color_discrete_sequence=px.colors.qualitative.D3,
26
+ custom_data=['Total Sentences', 'Count']
27
+ )
28
+
29
+ fig.update_traces(
30
+ hovertemplate=(
31
+ '<b>%{x}</b><br>' +
32
+ 'Share with Indicators: %{y:.1%}<br>' +
33
+ 'Total Sentences: %{customdata[0]}<br>' +
34
+ 'Sentences with Indicators: %{customdata[1]}<extra></extra>'
35
+ ),
36
+ textposition='inside',
37
+ texttemplate='%{text}',
38
+ textfont=dict(color='rgb(255, 255, 255)'),
39
+ insidetextanchor='middle',
40
+ )
41
+
42
+ elif chart_type == 'individual':
43
+ min_value = 5
44
+ exclude_indicators = ['!besprechen']
45
+ df_filtered = df[~df['Indicator'].isin(['Total with Indicators', 'None'] + exclude_indicators)].copy()
46
+ indicators_meeting_threshold = df_filtered[df_filtered['Count'] >= min_value]['Indicator'].unique()
47
+ df_filtered = df_filtered[df_filtered['Indicator'].isin(indicators_meeting_threshold)]
48
+ df_filtered['Indicator'] = df_filtered['Indicator'].str.capitalize()
49
+
50
+ fig = px.bar(
51
+ df_filtered,
52
+ x='Subfolder',
53
+ y='Count',
54
+ color='Indicator',
55
+ barmode='group',
56
+ labels={'Count': 'Occurrences', 'Subfolder': '', 'Indicator': ' <b>INDICATOR</b>'},
57
+ color_discrete_sequence=px.colors.qualitative.D3
58
+ )
59
+
60
+ fig.update_traces(
61
+ texttemplate='%{y}',
62
+ textposition='inside',
63
+ textfont=dict(color='rgb(255, 255, 255)'),
64
+ )
65
+
66
+ fig.update_layout(
67
+ xaxis=dict(showline=True),
68
+ yaxis=dict(showticklabels=True, title=''),
69
+ bargap=0.05,
70
+ showlegend=(chart_type == 'individual')
71
+ )
72
+
73
+ return fig
74
+
75
+ def causes_chart():
76
+ data_file = os.path.join('data', 'indicator_cause_sentence_metadata.tsv')
77
+ df = pd.read_csv(data_file, sep='\t')
78
+
79
+ # Threshold
80
+ min_value = 30
81
+ df_filtered = df[df['cause'] != 'N/A'].copy()
82
+ causes_meeting_threshold = df_filtered.groupby('cause')['cause'].count()[lambda x: x >= min_value].index
83
+ df_filtered = df_filtered[df_filtered['cause'].isin(causes_meeting_threshold)]
84
+ df_filtered['cause'] = df_filtered['cause'].str.capitalize()
85
+
86
+ fig = px.bar(
87
+ df_filtered.groupby(['subfolder', 'cause']).size().reset_index(name='Count'),
88
+ x='subfolder',
89
+ y='Count',
90
+ color='cause',
91
+ barmode='group',
92
+ labels={'Count': 'Occurrences', 'subfolder': '', 'cause': '<b>CAUSE</b>'},
93
+ color_discrete_sequence=px.colors.qualitative.G10,
94
+ )
95
+
96
+ fig.update_layout(
97
+ xaxis=dict(showline=True),
98
+ yaxis=dict(showticklabels=True, title=''),
99
+
100
+ )
101
+
102
+ fig.update_traces(
103
+ texttemplate='%{y}',
104
+ textposition='inside',
105
+ textfont=dict(color='rgb(255, 255, 255)'),
106
+ insidetextanchor='middle',
107
+ )
108
+
109
+ return fig
110
+
111
+ def scatter_plot(include_modality=False):
112
+ data_file = os.path.join('data', 'feature_matrix.tsv')
113
+ df = pd.read_csv(data_file, sep='\t')
114
+
115
+ # Exclude sentences without any indicators (all indicator columns are 0), causes, or modalities (if included)
116
+ indicator_columns = [col for col in df.columns if col.startswith('indicator_')]
117
+ cause_columns = [col for col in df.columns if col.startswith('cause_')]
118
+ modality_columns = [col for col in df.columns if col.startswith('modality_')]
119
+
120
+ df_filtered = df[(df[indicator_columns].sum(axis=1) > 0) |
121
+ (df[cause_columns].sum(axis=1) > 0)]
122
+
123
+ # Exclude indicator '!besprechen'
124
+ indicator_columns = [col for col in indicator_columns if 'indicator_!besprechen' not in col]
125
+
126
+ # Limit indicators to those that occur at least 10 times
127
+ indicator_counts = df_filtered[indicator_columns].sum()
128
+ indicators_to_keep = indicator_counts[indicator_counts >= 10].index.tolist()
129
+
130
+ # Further filter to exclude entries without any valid indicators
131
+ df_filtered = df_filtered[df_filtered[indicators_to_keep].sum(axis=1) > 0]
132
+
133
+ # Exclude non-feature columns (metadata and sentence text) for dimensionality reduction
134
+ columns_to_drop = ['subfolder']
135
+ if not include_modality:
136
+ columns_to_drop += modality_columns # Drop modality columns if not included
137
+
138
+ features = df_filtered.drop(columns=columns_to_drop)
139
+
140
+ # Fill NaN values with 0 for the feature matrix
141
+ features_clean = features.fillna(0)
142
+
143
+ # Store the relevant metadata separately to ensure it is aligned correctly with the dimensionality reduction results
144
+ metadata = df_filtered[['subfolder']].copy()
145
+ # Remove the 'indicator_' prefix for indicators and ensure only indicators with at least 10 occurrences are included
146
+ metadata['indicator'] = df_filtered[indicators_to_keep].apply(lambda row: ', '.join([indicator.replace('indicator_', '') for indicator in indicators_to_keep if row[indicator] > 0]), axis=1)
147
+ # Collect all non-zero causes as a string (multiple causes per sentence)
148
+ metadata['cause'] = df_filtered[cause_columns].apply(lambda row: ', '.join([cause.replace('cause_', '') for cause in cause_columns if row[cause] > 0]), axis=1)
149
+
150
+ # Perform UMAP dimensionality reduction
151
+ reducer = umap.UMAP(n_components=2, random_state=42, n_neighbors=50, metric='cosine')
152
+ reduced_features = reducer.fit_transform(features_clean)
153
+ df_reduced = pd.DataFrame(reduced_features, columns=['Component 1', 'Component 2'])
154
+ df_reduced = pd.concat([df_reduced, metadata.reset_index(drop=True)], axis=1)
155
+
156
+ # Plotting the scatter plot with Plotly Express
157
+ hover_data = {'cause'}
158
+ if include_modality:
159
+ hover_data['Modality'] = True
160
+
161
+ fig = px.scatter(
162
+ df_reduced,
163
+ x='Component 1',
164
+ y='Component 2',
165
+ color='subfolder',
166
+ hover_data=hover_data,
167
+ labels={'Component 1': 'UMAP Dim 1', 'Component 2': 'UMAP Dim 2'},
168
+ color_discrete_sequence=px.colors.qualitative.Plotly
169
+ )
170
+
171
+ fig.update_layout(
172
+ xaxis=dict(showgrid=False),
173
+ yaxis=dict(showgrid=False),
174
+ showlegend=True
175
+ )
176
+
177
+ return fig
178
+