ayushnoori committed on
Commit
48e3a32
·
1 Parent(s): d4ca2d2

Update input

Browse files
Files changed (5) hide show
  1. pages/input.py +26 -35
  2. pages/predict.py +16 -0
  3. pages/split.py +2 -2
  4. pages/validate.py +8 -3
  5. project_config.py +1 -0
pages/input.py CHANGED
@@ -31,42 +31,32 @@ with st.spinner('Loading disease splits...'):
31
  # Load from Kempner using sync_data.sh
32
  disease_splits = pd.read_csv(project_config.DATA_DIR / 'disease_splits' / 'disease_splits.csv',
33
  dtype = {'node_index': str, 'disease_split_index': str})
34
-
 
 
 
35
  else:
36
 
37
  # Read disease splits from HF
38
  disease_splits = hf_hub_download(repo_id=project_config.HF_REPO,
39
- filename='disease_split/disease_splits.csv',
40
  token=st.secrets["HF_TOKEN"], repo_type="dataset")
41
-
42
- # Group disease splits by disease_split_index column
43
- disease_splits_grouped = disease_splits.groupby('disease_split_index').size().reset_index(name='node_count')
44
-
45
- # Subset to unique disease splits
46
- splits_df =disease_splits[disease_splits['node_index'] == disease_splits['disease_split_index']]
47
- splits_df = splits_df.drop_duplicates(subset='disease_split_index').reset_index(drop=True)
48
- splits_df = splits_df[['node_index', 'node_name', 'node_id']]
49
-
50
- # Merge with counts
51
- splits_df = splits_df.merge(disease_splits_grouped, left_on='node_index', right_on='disease_split_index', how='left')
52
- splits_df = splits_df.drop(columns='disease_split_index')
53
- splits_df['node_name'] = splits_df['node_name'].str.replace(' \\(disease\\)', '', regex=True)
54
-
55
- # Add row for all to beginning
56
- splits_df['node_index'] = splits_df['node_index'].astype(str)
57
- splits_df = pd.concat([pd.DataFrame([['all', 'all diseases', None, disease_splits.shape[0]]], columns=splits_df.columns), splits_df], ignore_index=True)
58
-
59
- # For each disease split, count number of edges (number of rows in CSV file in disease_splits directory)
60
- # Do not read file in
61
- edge_counts = []
62
- for index, row in splits_df.iterrows():
63
- # Count lines
64
- file_name = project_config.DATA_DIR / 'disease_splits' / 'split_edges' / f'{row["node_index"]}.csv'
65
- edge_count = int(subprocess.check_output(['wc', '-l', file_name]).split()[0]) - 1
66
- edge_counts.append(edge_count)
67
-
68
- # Add edge counts to splits_df
69
- splits_df['edge_count'] = edge_counts
70
 
71
  # Get list of available models
72
  model_files = os.listdir(project_config.MODEL_DIR / 'embeddings')
@@ -92,7 +82,6 @@ with st.spinner('Loading disease splits...'):
92
  # Get available models, only keep latest version per split
93
  avail_models = pd.DataFrame([get_model_metadata(f) for f in model_files])
94
  avail_models = avail_models.sort_values('date', ascending=False).drop_duplicates('test').reset_index(drop=True)
95
- # avail_models.loc[avail_models['test'] == 'all', 'test'] = 'all diseases'
96
 
97
  # Add column to indicate if model is available
98
  splits_df['available'] = splits_df['node_index'].isin(avail_models['test'])
@@ -104,9 +93,12 @@ with st.spinner('Loading disease splits...'):
104
  ####################################################################################################
105
 
106
  # Select disease split from splits with available models
 
 
 
107
  # Make dictionary with node_index: node_name, where name is value shown but index is used for query
108
- # split_options = splits_df[splits_df['available']].copy()
109
- split_options = splits_df.copy()
110
  split_options = split_options.set_index('node_index')['node_name'].to_dict()
111
 
112
  # Check if split is in session state
@@ -119,7 +111,6 @@ with st.spinner('Loading disease splits...'):
119
  index = split_index)
120
 
121
  # Show all splits dataframe
122
- splits_display = splits_df[['node_index', 'node_name', 'node_count', 'edge_count', 'available']].copy()
123
  splits_display = splits_display.rename(columns = {'node_index': 'Split ID', 'node_name': 'Disease', 'node_count': 'Node Count', 'edge_count': 'Edge Count', 'available': 'Model Available'})
124
  st.dataframe(splits_display, use_container_width = True, hide_index = True)
125
 
 
31
  # Load from Kempner using sync_data.sh
32
  disease_splits = pd.read_csv(project_config.DATA_DIR / 'disease_splits' / 'disease_splits.csv',
33
  dtype = {'node_index': str, 'disease_split_index': str})
34
+
35
+ splits_df = pd.read_csv(project_config.DATA_DIR / 'disease_splits' / 'disease_splits_summary.csv',
36
+ dtype = {'node_index': str, 'disease_split_index': str})
37
+
38
  else:
39
 
40
  # Read disease splits from HF
41
  disease_splits = hf_hub_download(repo_id=project_config.HF_REPO,
42
+ filename='data/disease_splits/disease_splits.csv',
43
  token=st.secrets["HF_TOKEN"], repo_type="dataset")
44
+
45
+ disease_splits = hf_hub_download(repo_id=project_config.HF_REPO,
46
+ filename='data/disease_splits/disease_splits_summary.csv',
47
+ token=st.secrets["HF_TOKEN"], repo_type="dataset")
48
+
49
+ # # For each disease split, count number of edges (number of rows in CSV file in disease_splits directory)
50
+ # # Do not read file in
51
+ # edge_counts = []
52
+ # for index, row in splits_df.iterrows():
53
+ # # Count lines
54
+ # file_name = project_config.DATA_DIR / 'disease_splits' / 'split_edges' / f'{row["node_index"]}.csv'
55
+ # edge_count = int(subprocess.check_output(['wc', '-l', file_name]).split()[0]) - 1
56
+ # edge_counts.append(edge_count)
57
+
58
+ # # Add edge counts to splits_df
59
+ # splits_df['edge_count'] = edge_counts
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
  # Get list of available models
62
  model_files = os.listdir(project_config.MODEL_DIR / 'embeddings')
 
82
  # Get available models, only keep latest version per split
83
  avail_models = pd.DataFrame([get_model_metadata(f) for f in model_files])
84
  avail_models = avail_models.sort_values('date', ascending=False).drop_duplicates('test').reset_index(drop=True)
 
85
 
86
  # Add column to indicate if model is available
87
  splits_df['available'] = splits_df['node_index'].isin(avail_models['test'])
 
93
  ####################################################################################################
94
 
95
  # Select disease split from splits with available models
96
+ splits_display = splits_df[['node_index', 'node_name', 'node_count', 'edge_count', 'available']].copy()
97
+ splits_display['node_name'] = splits_display['node_name'].str.replace(' \\(disease\\)', '', regex=True)
98
+
99
  # Make dictionary with node_index: node_name, where name is value shown but index is used for query
100
+ # split_options = splits_display[splits_display['available']].copy()
101
+ split_options = splits_display.copy()
102
  split_options = split_options.set_index('node_index')['node_name'].to_dict()
103
 
104
  # Check if split is in session state
 
111
  index = split_index)
112
 
113
  # Show all splits dataframe
 
114
  splits_display = splits_display.rename(columns = {'node_index': 'Split ID', 'node_name': 'Disease', 'node_count': 'Node Count', 'edge_count': 'Edge Count', 'available': 'Model Available'})
115
  st.dataframe(splits_display, use_container_width = True, hide_index = True)
116
 
pages/predict.py CHANGED
@@ -34,6 +34,13 @@ st.subheader(f"{capitalize_after_slash(st.session_state.query['target_node_type'
34
  # Print current query
35
  st.markdown(f"**Query:** {st.session_state.query['source_node'].replace('_', ' ')} ➡️ {st.session_state.query['relation'].replace('_', '-')} ➡️ {st.session_state.query['target_node_type'].replace('_', ' ')}")
36
 
 
 
 
 
 
 
 
37
  @st.cache_data(show_spinner = 'Downloading AI model...')
38
  def get_embeddings():
39
 
@@ -272,3 +279,12 @@ with st.spinner('Computing predictions...'):
272
  # Save to session state
273
  st.session_state.predictions = display_data
274
  st.session_state.display_database = display_database
 
 
 
 
 
 
 
 
 
 
34
  # Print current query
35
  st.markdown(f"**Query:** {st.session_state.query['source_node'].replace('_', ' ')} ➡️ {st.session_state.query['relation'].replace('_', '-')} ➡️ {st.session_state.query['target_node_type'].replace('_', ' ')}")
36
 
37
+ # Print split
38
+ split = st.session_state.split
39
+ splits_df = st.session_state.splits_df
40
+ num_nodes = splits_df[splits_df['node_index'] == split]['node_count'].values[0]
41
+ num_edges = splits_df[splits_df['node_index'] == split]['edge_count'].values[0]
42
+ st.markdown(f"**Disease Split:** {st.session_state.split} ({num_nodes} nodes, {num_edges} edges)")
43
+
44
  @st.cache_data(show_spinner = 'Downloading AI model...')
45
  def get_embeddings():
46
 
 
279
  # Save to session state
280
  st.session_state.predictions = display_data
281
  st.session_state.display_database = display_database
282
+
283
+ # If validation not in session state
284
+ if 'validation' not in st.session_state:
285
+
286
+ col1, col2, col3 = st.columns(3)
287
+
288
+ with col2:
289
+ if st.button("Validate Predictions"):
290
+ st.switch_page("pages/validate.py")
pages/split.py CHANGED
@@ -76,7 +76,7 @@ plt.tight_layout()
76
  # Adding labels on top of each bar
77
  for bar in bars:
78
  yval = bar.get_height()
79
- plt.text(bar.get_x() + bar.get_width()/2.0, yval, int(yval), va='bottom', fontsize=12)
80
  plt.ylim(0, max(method_counts['Count'])*1.1)
81
 
82
  # Show plot
@@ -107,7 +107,7 @@ if disease_split_edges.shape[0] > 0:
107
  # Adding labels on top of each bar
108
  for bar in bars:
109
  yval = bar.get_height()
110
- plt.text(bar.get_x() + bar.get_width()/2.0, yval, int(yval), va='bottom', fontsize=12)
111
  plt.ylim(0, max(relation_counts['Count'])*1.1)
112
 
113
  # Show plot
 
76
  # Adding labels on top of each bar
77
  for bar in bars:
78
  yval = bar.get_height()
79
+ plt.text(bar.get_x() + bar.get_width()/2.0, yval, int(yval), va='bottom', ha='center', fontsize=12)
80
  plt.ylim(0, max(method_counts['Count'])*1.1)
81
 
82
  # Show plot
 
107
  # Adding labels on top of each bar
108
  for bar in bars:
109
  yval = bar.get_height()
110
+ plt.text(bar.get_x() + bar.get_width()/2.0, yval, int(yval), va='bottom', ha='center', fontsize=12)
111
  plt.ylim(0, max(relation_counts['Count'])*1.1)
112
 
113
  # Show plot
pages/validate.py CHANGED
@@ -31,9 +31,14 @@ st.subheader("Validate Predictions", divider = "green")
31
  # Print current query
32
  st.markdown(f"**Query:** {st.session_state.query['source_node'].replace('_', ' ')} ➡️ {st.session_state.query['relation'].replace('_', '-')} ➡️ {st.session_state.query['target_node_type'].replace('_', ' ')}")
33
 
34
- # Coming soon
35
- # st.write("Coming soon...")
36
-
 
 
 
 
 
37
  source_node_type = st.session_state.query['source_node_type']
38
  source_node = st.session_state.query['source_node']
39
  relation = st.session_state.query['relation']
 
31
  # Print current query
32
  st.markdown(f"**Query:** {st.session_state.query['source_node'].replace('_', ' ')} ➡️ {st.session_state.query['relation'].replace('_', '-')} ➡️ {st.session_state.query['target_node_type'].replace('_', ' ')}")
33
 
34
+ # Print split
35
+ split = st.session_state.split
36
+ splits_df = st.session_state.splits_df
37
+ num_nodes = splits_df[splits_df['node_index'] == split]['node_count'].values[0]
38
+ num_edges = splits_df[splits_df['node_index'] == split]['edge_count'].values[0]
39
+ st.markdown(f"**Disease Split:** {st.session_state.split} ({num_nodes} nodes, {num_edges} edges)")
40
+
41
+ # Get query and predictions
42
  source_node_type = st.session_state.query['source_node_type']
43
  source_node = st.session_state.query['source_node']
44
  relation = st.session_state.query['relation']
project_config.py CHANGED
@@ -31,6 +31,7 @@ print(f"VDI: {VDI}")
31
  # Define global variable to check if running locally
32
  hostname, username = check_local_machine()
33
  LOCAL = True if username == 'an583' else False
 
34
 
35
  # Define HF repo variable
36
  HF_REPO = 'ayushnoori/clinical-drug-repurposing'
 
31
  # Define global variable to check if running locally
32
  hostname, username = check_local_machine()
33
  LOCAL = True if username == 'an583' else False
34
+ print(f"LOCAL: {LOCAL}")
35
 
36
  # Define HF repo variable
37
  HF_REPO = 'ayushnoori/clinical-drug-repurposing'