Erva Ulusoy committed on
Commit
51641fb
·
1 Parent(s): 8f4b741

Added coloring to predicted/ground-truth GO term edges

Browse files
Files changed (1) hide show
  1. visualize_kg.py +37 -29
visualize_kg.py CHANGED
@@ -63,7 +63,6 @@ def _gather_protein_edges(data, protein_id):
63
  return protein_edges
64
 
65
  def _filter_edges(protein_id, protein_edges, prediction_df, limit=10):
66
-
67
  filtered_edges = {}
68
 
69
  prediction_categories = prediction_df['GO_category'].unique()
@@ -75,32 +74,35 @@ def _filter_edges(protein_id, protein_edges, prediction_df, limit=10):
75
  if edges is None or len(edges) == 0:
76
  continue
77
 
78
- if edge_type[2] in prediction_categories:
79
- category_mask = (prediction_df['GO_category'] == go_category_reverse_mapping[edge_type[2]]) & (prediction_df['UniProt_ID'] == protein_id)
80
- category_predictions = prediction_df[category_mask]
81
-
82
- if len(category_predictions) > 0:
83
- category_predictions = category_predictions.sort_values(by='Probability', ascending=False)
84
-
85
- # Convert set to list for easier filtering
86
- edges_list = list(edges)
87
-
88
- # Filter valid edges and store with probabilities
89
- valid_edges = []
90
- for _, row in category_predictions.iterrows():
91
- term = row['GO_ID']
92
- prob = row['Probability']
93
- matching_edges = [(edge, prob) for edge in edges_list if edge[1] == term]
94
- valid_edges.extend(matching_edges)
95
- if len(valid_edges) >= limit:
96
- break
97
- filtered_edges[edge_type] = valid_edges # Remove set conversion to preserve probabilities
 
 
 
98
  else:
99
- # If no predictions, include all edges up to limit without probabilities
100
- filtered_edges[edge_type] = [(edge, None) for edge in list(edges)[:limit]]
101
  else:
102
- # For non-GO edges, include all edges up to limit without probabilities
103
- filtered_edges[edge_type] = [(edge, None) for edge in list(edges)[:limit]]
104
 
105
  return filtered_edges
106
 
@@ -186,7 +188,7 @@ def visualize_protein_subgraph(data, protein_id, prediction_df, limit=10):
186
  source_type, relation_type, target_type = edge_type
187
 
188
  for edge_info in edges:
189
- edge, probability = edge_info
190
  source, target = edge[0], edge[1]
191
  source_str = str(source)
192
  target_str = str(target)
@@ -218,10 +220,16 @@ def visualize_protein_subgraph(data, protein_id, prediction_df, limit=10):
218
  # Add edge with relationship type and probability as label
219
  edge_label = f"{relation_type}"
220
  if probability is not None:
221
- edge_label += f"(P={probability:.2f})"
 
 
 
 
 
222
  net.add_edge(source_str, target_str,
223
  label=edge_label,
224
- color='#666666',
 
225
  title=edge_label,
226
  length=200,
227
  smooth={'type': 'curvedCW', 'roundness': 0.1})
@@ -229,7 +237,7 @@ def visualize_protein_subgraph(data, protein_id, prediction_df, limit=10):
229
  net.add_edge(source_str, target_str,
230
  label=edge_label,
231
  font={'size': 0},
232
- color='#666666',
233
  title=edge_label,
234
  length=200,
235
  smooth={'type': 'curvedCW', 'roundness': 0.1})
 
63
  return protein_edges
64
 
65
  def _filter_edges(protein_id, protein_edges, prediction_df, limit=10):
 
66
  filtered_edges = {}
67
 
68
  prediction_categories = prediction_df['GO_category'].unique()
 
74
  if edges is None or len(edges) == 0:
75
  continue
76
 
77
+ if edge_type[2].startswith('GO_term'): # Check if it's any GO term edge
78
+ if edge_type[2] in prediction_categories:
79
+ # Handle edges for GO terms that are in prediction_df
80
+ category_mask = (prediction_df['GO_category'] == go_category_reverse_mapping[edge_type[2]]) & (prediction_df['UniProt_ID'] == protein_id)
81
+ category_predictions = prediction_df[category_mask]
82
+
83
+ if len(category_predictions) > 0:
84
+ category_predictions = category_predictions.sort_values(by='Probability', ascending=False)
85
+ edges_set = set(edges) # Convert to set for O(1) lookup
86
+
87
+ valid_edges = []
88
+ for _, row in category_predictions.iterrows():
89
+ term = row['GO_ID']
90
+ prob = row['Probability']
91
+ edge = (protein_id, term)
92
+ is_ground_truth = edge in edges_set
93
+ valid_edges.append((edge, prob, is_ground_truth))
94
+ if len(valid_edges) >= limit:
95
+ break
96
+ filtered_edges[edge_type] = valid_edges
97
+ else:
98
+ # If no predictions but it's a GO category in prediction_df
99
+ filtered_edges[edge_type] = [(edge, 'no_pred', True) for edge in list(edges)[:limit]]
100
  else:
101
+ # For GO terms not in prediction_df, mark them as ground truth with blue color
102
+ filtered_edges[edge_type] = [(edge, 'no_pred', True) for edge in list(edges)[:limit]]
103
  else:
104
+ # For non-GO edges, include all edges up to limit
105
+ filtered_edges[edge_type] = [(edge, None, True) for edge in list(edges)[:limit]]
106
 
107
  return filtered_edges
108
 
 
188
  source_type, relation_type, target_type = edge_type
189
 
190
  for edge_info in edges:
191
+ edge, probability, is_ground_truth = edge_info
192
  source, target = edge[0], edge[1]
193
  source_str = str(source)
194
  target_str = str(target)
 
220
  # Add edge with relationship type and probability as label
221
  edge_label = f"{relation_type}"
222
  if probability is not None:
223
+ if probability == 'no_pred':
224
+ edge_color = '#219ebc'
225
+ edge_label += '(P=Not generated)'
226
+ else:
227
+ edge_label += f"(P={probability:.2f})"
228
+ edge_color = '#c1121f' if is_ground_truth else '#219ebc'
229
  net.add_edge(source_str, target_str,
230
  label=edge_label,
231
+ font={'size': 0},
232
+ color=edge_color,
233
  title=edge_label,
234
  length=200,
235
  smooth={'type': 'curvedCW', 'roundness': 0.1})
 
237
  net.add_edge(source_str, target_str,
238
  label=edge_label,
239
  font={'size': 0},
240
+ color='#666666', # Keep default gray for non-GO edges
241
  title=edge_label,
242
  length=200,
243
  smooth={'type': 'curvedCW', 'roundness': 0.1})