ayushnoori commited on
Commit
11b8c2d
·
1 Parent(s): 9a01b6b

Add statistics and recall to validation, but for ORIGINAL EDGE TYPE

Browse files
Files changed (1) hide show
  1. pages/validate.py +184 -77
pages/validate.py CHANGED
@@ -13,6 +13,9 @@ import matplotlib.pyplot as plt
13
  plt.rcParams['font.sans-serif'] = 'Arial'
14
  import matplotlib.colors as mcolors
15
 
 
 
 
16
  # Custom and other imports
17
  import project_config
18
  from utils import load_kg, load_kg_edges
@@ -45,85 +48,189 @@ relation = st.session_state.query['relation']
45
  target_node_type = st.session_state.query['target_node_type']
46
  predictions = st.session_state.predictions
47
 
48
- kg_nodes = load_kg()
49
- kg_edges = load_kg_edges()
50
-
51
- # Convert tuple to hex
52
- def rgba_to_hex(rgba):
53
- return mcolors.to_hex(rgba[:3])
54
-
55
- with st.spinner('Searching known relationships...'):
56
-
57
- # Subset existing edges
58
- edge_subset = kg_edges[(kg_edges.x_type == source_node_type) & (kg_edges.x_name == source_node)]
59
- edge_subset = edge_subset[edge_subset.y_type == target_node_type]
60
-
61
- # Merge edge subset with predictions
62
- edges_in_kg = pd.merge(predictions, edge_subset[['relation', 'y_id']], left_on = 'ID', right_on = 'y_id', how = 'right')
63
- edges_in_kg = edges_in_kg.sort_values(by = 'Score', ascending = False)
64
- edges_in_kg = edges_in_kg.drop(columns = 'y_id')
65
-
66
- # Rename relation to ground-truth
67
- edges_in_kg = edges_in_kg[['relation'] + [col for col in edges_in_kg.columns if col != 'relation']]
68
- edges_in_kg = edges_in_kg.rename(columns = {'relation': 'Known Relation'})
69
-
70
- # If there exist edges in KG
71
- if len(edges_in_kg) > 0:
72
-
73
- with st.spinner('Saving validation results...'):
74
-
75
- # Cast long to wide
76
- val_results = edge_subset[['relation', 'y_id']].pivot_table(index='y_id', columns='relation', aggfunc='size', fill_value=0)
77
- val_results = (val_results > 0).astype(int).reset_index()
78
- val_results.columns = [val_results.columns[0]] + [x.replace('_', ' ').title() for x in val_results.columns[1:]]
79
-
80
- # Save validation results to session state
81
- st.session_state.validation = val_results
82
-
83
- with st.spinner('Plotting known relationships...'):
84
-
85
- # Define a color map for different relations
86
- color_map = plt.get_cmap('tab10')
87
-
88
- # Group by relation and create separate plots
89
- relations = edges_in_kg['Known Relation'].unique()
90
- for idx, relation in enumerate(relations):
91
-
92
- relation_data = edges_in_kg[edges_in_kg['Known Relation'] == relation]
93
-
94
- # Get a color from the color map
95
- color = color_map(idx % color_map.N)
96
-
97
- fig, ax = plt.subplots(figsize=(10, 3))
98
- ax.plot(predictions['Rank'], predictions['Score'])
99
- ax.set_xlabel('Rank', fontsize=12)
100
- ax.set_ylabel('Score', fontsize=12)
101
- ax.set_xlim(1, predictions['Rank'].max())
102
-
103
- for i, node in relation_data.iterrows():
104
- ax.axvline(node['Rank'], color=color, linestyle='--', label=node['Name'])
105
- # ax.text(node['Rank'] + 100, node['Score'], node['Name'], fontsize=10, color=color)
106
-
107
- # ax.set_title(f'{relation.replace("_", "-")}')
108
- # ax.legend()
109
- color_hex = rgba_to_hex(color)
110
 
111
- # Write header in color of relation
112
- st.markdown(f"<h3 style='color:{color_hex}'>{relation.replace('_', ' ').title()}</h2>", unsafe_allow_html=True)
113
 
114
- # Show plot
115
- st.pyplot(fig)
116
 
117
- # Drop known relation column
118
- relation_data = relation_data.drop(columns = 'Known Relation')
119
- if target_node_type not in ['disease', 'anatomy']:
120
- st.dataframe(relation_data, use_container_width=True,
121
- column_config={"Database": st.column_config.LinkColumn(width = "small",
122
- help = "Click to visit external database.",
123
- display_text = st.session_state.display_database)})
124
- else:
125
- st.dataframe(relation_data, use_container_width=True)
126
 
127
- else:
 
 
 
 
 
 
 
128
 
129
- st.error('No ground truth relationships found for the given query in the knowledge graph.', icon="✖️")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  plt.rcParams['font.sans-serif'] = 'Arial'
14
  import matplotlib.colors as mcolors
15
 
16
+ # Import metrics
17
+ from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score, f1_score
18
+
19
  # Custom and other imports
20
  import project_config
21
  from utils import load_kg, load_kg_edges
 
48
  target_node_type = st.session_state.query['target_node_type']
49
  predictions = st.session_state.predictions
50
 
51
+ @st.experimental_fragment()
52
+ def plot_options():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
+ st.markdown("<h5 style='margin-top: 10px;'>Plotting Options</h5>", unsafe_allow_html=True)
 
55
 
56
+ # Checkbox to show lines
57
+ show_lines = st.checkbox('Show rug plot of existing edges?', value = False)
58
 
59
+ # Slider for x-axis limit
60
+ axis_limits = st.slider('Define the range of ranks to visualize.',
61
+ min_value=0, max_value=predictions['Rank'].max(),
62
+ value=(0, predictions['Rank'].max()), step=1000)
63
+
64
+ # Update session state
65
+ st.session_state.show_lines = show_lines
66
+ st.session_state.axis_limits = axis_limits
 
67
 
68
+ # Get plot options
69
+ plot_options()
70
+
71
+ # Set default options
72
+ if 'show_lines' not in st.session_state:
73
+ st.session_state.show_lines = False
74
+ if 'axis_limits' not in st.session_state:
75
+ st.session_state.axis_limits = (0, predictions['Rank'].max())
76
 
77
+ # Button to update plot
78
+ col1, col2, col3 = st.columns([4, 2, 4])
79
+ with col2:
80
+ update_button = st.button('Generate Plot')
81
+
82
+ # Horizontal line
83
+ st.markdown('---')
84
+
85
+ if update_button:
86
+
87
+ kg_nodes = load_kg()
88
+ kg_edges = load_kg_edges()
89
+
90
+ # Convert tuple to hex
91
+ def rgba_to_hex(rgba):
92
+ return mcolors.to_hex(rgba[:3])
93
+
94
+ with st.spinner('Searching known relationships...'):
95
+
96
+ # Subset existing edges
97
+ edge_subset = kg_edges[(kg_edges.x_type == source_node_type) & (kg_edges.x_name == source_node)]
98
+ edge_subset = edge_subset[edge_subset.y_type == target_node_type]
99
+
100
+ # Merge edge subset with predictions
101
+ edges_in_kg = pd.merge(predictions, edge_subset[['relation', 'y_id']], left_on = 'ID', right_on = 'y_id', how = 'right')
102
+ edges_in_kg = edges_in_kg.sort_values(by = 'Score', ascending = False)
103
+ edges_in_kg = edges_in_kg.drop(columns = 'y_id')
104
+
105
+ # Rename relation to ground-truth
106
+ edges_in_kg = edges_in_kg[['relation'] + [col for col in edges_in_kg.columns if col != 'relation']]
107
+ edges_in_kg = edges_in_kg.rename(columns = {'relation': 'Known Relation'})
108
+
109
+ # If there exist edges in KG
110
+ if len(edges_in_kg) > 0:
111
+
112
+ with st.spinner('Saving validation results...'):
113
+
114
+ # Cast long to wide
115
+ val_results = edge_subset[['relation', 'y_id']].pivot_table(index='y_id', columns='relation', aggfunc='size', fill_value=0)
116
+ val_results = (val_results > 0).astype(int).reset_index()
117
+ val_results.columns = [val_results.columns[0]] + [x.replace('_', ' ').title() for x in val_results.columns[1:]]
118
+
119
+ # Save validation results to session state
120
+ st.session_state.validation = val_results
121
+
122
+ with st.spinner('Plotting known relationships...'):
123
+
124
+ # Define a color map for different relations
125
+ color_map = plt.get_cmap('tab10')
126
+
127
+ # Group by relation and create separate plots
128
+ relations = edges_in_kg['Known Relation'].unique()
129
+ for idx, relation in enumerate(relations):
130
+
131
+ relation_data = edges_in_kg[edges_in_kg['Known Relation'] == relation]
132
+
133
+ # Get a color from the color map
134
+ color = color_map(idx % color_map.N)
135
+
136
+ fig, ax = plt.subplots(figsize=(10, 5))
137
+ ax.plot(predictions['Rank'], predictions['Score'], color = 'black', linewidth = 1.5, zorder = 2)
138
+ ax.set_xlabel('Rank', fontsize=12)
139
+ ax.set_ylabel('Score', fontsize=12)
140
+ # ax.set_xlim(1, predictions['Rank'].max())
141
+ # ax.set_xlim(axis_limits)
142
+ ax.set_xlim(st.session_state.axis_limits)
143
+
144
+ for i, node in relation_data.iterrows():
145
+ if st.session_state.show_lines:
146
+ ax.axvline(node['Rank'], color=color, linestyle='--', label=node['Name'], zorder = 3)
147
+ ax.scatter(node['Rank'], node['Score'], color=color, zorder=3) # s=15
148
+ # ax.text(node['Rank'] + 100, node['Score'], node['Name'], fontsize=10, color=color)
149
+
150
+ # Also calculate and plot recall at K
151
+ ax2 = ax.twinx()
152
+
153
+ # Calculate recall at K for all Rank
154
+ recall_at_k = []
155
+ for k in range(1, predictions['Rank'].max() + 1):
156
+ recall = 100*len(relation_data[relation_data['Rank'] <= k]) / len(relation_data)
157
+ recall_at_k.append(recall)
158
+
159
+ ax2.plot(range(1, predictions['Rank'].max() + 1), recall_at_k,
160
+ color = 'red', linestyle = '--', label = 'Recall at K', zorder = 4, linewidth = 2)
161
+
162
+ # Set labels
163
+ ax2.set_ylabel('Recall at K (%)', fontsize=12, color='red')
164
+
165
+ # Add grid
166
+ ax.grid(True, linestyle=':', alpha=0.5, zorder=0)
167
+
168
+ # ax.set_title(f'{relation.replace("_", "-")}')
169
+ # ax.legend()
170
+ color_hex = rgba_to_hex(color)
171
+
172
+ # Write header in color of relation
173
+ st.markdown(f"<h3 style='color:{color_hex}'>{relation.replace('_', ' ').title()}</h3>", unsafe_allow_html=True)
174
+
175
+ # Show plot
176
+ st.pyplot(fig)
177
+
178
+ # Create recall at K table
179
+ k_vals = [10, 50, 100, 500, 1000, 5000, 10000]
180
+ recall_at_k = []
181
+ for k in k_vals:
182
+ recall = 100*len(relation_data[relation_data['Rank'] <= k]) / len(relation_data)
183
+ recall = f"{recall:.2f}%"
184
+ recall_at_k.append(recall)
185
+ recall_df = pd.DataFrame({'K': k_vals, 'Recall': recall_at_k})
186
+
187
+ # Transpose and display recall at K
188
+ recall_df = recall_df.T
189
+ recall_df.columns = [f"k = {k:.0f}" for k in recall_df.iloc[0]]
190
+ recall_df = recall_df.drop('K')
191
+ st.markdown('**Recall at $k$:**')
192
+ st.dataframe(recall_df, use_container_width=True)
193
+
194
+ # Compute other statistics
195
+ st.markdown('**Statistics:**')
196
+
197
+ # Binarize score
198
+ pred_threshold = 0.5
199
+ raw_score = predictions['Score']
200
+ binary_score = (raw_score > pred_threshold).astype(int)
201
+ true_label = np.zeros(len(predictions))
202
+
203
+ # Set true label to 1 for known relations
204
+
205
+ # Reset index
206
+ predictions_idx = predictions.copy().reset_index(drop = True)
207
+ true_label[predictions_idx[predictions_idx['ID'].isin(relation_data['ID'])].index] = 1
208
+
209
+ # Compute scores
210
+ accuracy = accuracy_score(true_label, binary_score)
211
+ ap = average_precision_score(true_label, raw_score)
212
+ f1 = f1_score(true_label, binary_score, average = 'micro')
213
+ try:
214
+ auc = roc_auc_score(true_label, raw_score)
215
+ except ValueError:
216
+ auc = 0.5
217
+
218
+ # Create dataframe
219
+ stats_df = pd.DataFrame({'Accuracy': [accuracy], 'AUC': [auc], 'AP': [ap], 'F1': [f1]})
220
+ stats_df.index = ["Value"]
221
+ st.dataframe(stats_df, use_container_width=True)
222
+
223
+ # Drop known relation column
224
+ st.markdown('**Known Relationships:**')
225
+ relation_data = relation_data.drop(columns = 'Known Relation')
226
+ if target_node_type not in ['disease', 'anatomy']:
227
+ st.dataframe(relation_data, use_container_width=True,
228
+ column_config={"Database": st.column_config.LinkColumn(width = "small",
229
+ help = "Click to visit external database.",
230
+ display_text = st.session_state.display_database)})
231
+ else:
232
+ st.dataframe(relation_data, use_container_width=True)
233
+
234
+ else:
235
+
236
+ st.error('No ground truth relationships found for the given query in the knowledge graph.', icon="✖️")