feiyang-cai committed on
Commit
d84b0a6
·
1 Parent(s): 8b9fe11

revise the descriptions

Browse files
Files changed (3) hide show
  1. app.py +20 -12
  2. dataset_descriptions.json +44 -0
  3. utils.py +21 -9
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  from huggingface_hub import HfApi, get_collection, list_collections
3
- from utils import MolecularPropertyPredictionModel, task_types, dataset_descriptions
4
  import pandas as pd
5
  import os
6
 
@@ -12,22 +12,26 @@ def get_models():
12
  if item.item_type == "model":
13
  item_name = item.item_id.split("/")[-1]
14
  models[item_name] = item.item_id
15
- assert item_name in task_types, f"{item_name} is not in the task_types"
16
  assert item_name in dataset_descriptions, f"{item_name} is not in the dataset_descriptions"
17
 
18
  return models
19
 
20
  candidate_models = get_models()
21
- properties = list(candidate_models.keys())
 
22
  model = MolecularPropertyPredictionModel(candidate_models)
23
 
24
  def get_description(property_name):
25
- return dataset_descriptions[property_name]
 
26
 
27
  def predict_single_label(smiles, property_name):
 
 
28
  try:
29
- adapter_id = candidate_models[property_name]
30
- info = model.swith_adapter(property_name, adapter_id)
31
 
32
  running_status = None
33
  if info == "keep":
@@ -45,7 +49,8 @@ def predict_single_label(smiles, property_name):
45
  return "NA", running_status
46
 
47
  #prediction = model.predict(smiles, property_name, adapter_id)
48
- prediction = model.predict_single_smiles(smiles, task_types[property_name])
 
49
  if prediction is None:
50
  return "NA", "Invalid SMILES string"
51
 
@@ -60,9 +65,10 @@ def predict_single_label(smiles, property_name):
60
  return prediction, "Prediction is done"
61
 
62
  def predict_file(file, property_name):
 
63
  try:
64
- adapter_id = candidate_models[property_name]
65
- info = model.swith_adapter(property_name, adapter_id)
66
 
67
  running_status = None
68
  if info == "keep":
@@ -81,7 +87,7 @@ def predict_file(file, property_name):
81
 
82
  df = pd.read_csv(file)
83
  # we have already checked the file contains the "smiles" column
84
- df = model.predict_file(df, task_types[property_name])
85
  # we should save this file to the disk to be downloaded
86
  # rename the file to have "_prediction" suffix
87
  prediction_file = file.replace(".csv", "_prediction.csv") if file.endswith(".csv") else file.replace(".smi", "_prediction.csv")
@@ -157,10 +163,12 @@ def build_inference():
157
  with gr.Blocks() as demo:
158
  # first row - Dropdown input
159
  #with gr.Row():
160
- dropdown = gr.Dropdown(properties, label="Property", value=properties[0])
 
 
161
  description_box = gr.Textbox(label="Property description", lines=5,
162
  interactive=False,
163
- value=dataset_descriptions[properties[0]])
164
  # third row - Textbox input and prediction label
165
  with gr.Row(equal_height=True):
166
  with gr.Column():
 
1
  import gradio as gr
2
  from huggingface_hub import HfApi, get_collection, list_collections
3
+ from utils import MolecularPropertyPredictionModel, dataset_task_types, dataset_descriptions, dataset_property_names, dataset_property_names_to_dataset
4
  import pandas as pd
5
  import os
6
 
 
12
  if item.item_type == "model":
13
  item_name = item.item_id.split("/")[-1]
14
  models[item_name] = item.item_id
15
+ assert item_name in dataset_task_types, f"{item_name} is not in the task_types"
16
  assert item_name in dataset_descriptions, f"{item_name} is not in the dataset_descriptions"
17
 
18
  return models
19
 
20
  candidate_models = get_models()
21
+ properties = [dataset_property_names[item] for item in candidate_models.keys()]
22
+ property_names = list(candidate_models.keys())
23
  model = MolecularPropertyPredictionModel(candidate_models)
24
 
25
  def get_description(property_name):
26
+ property_id = dataset_property_names_to_dataset[property_name]
27
+ return dataset_descriptions[property_id]
28
 
29
  def predict_single_label(smiles, property_name):
30
+ property_id = dataset_property_names_to_dataset[property_name]
31
+
32
  try:
33
+ adapter_id = candidate_models[property_id]
34
+ info = model.swith_adapter(property_id, adapter_id)
35
 
36
  running_status = None
37
  if info == "keep":
 
49
  return "NA", running_status
50
 
51
  #prediction = model.predict(smiles, property_name, adapter_id)
52
+ print("hello4")
53
+ prediction = model.predict_single_smiles(smiles, dataset_task_types[property_id])
54
  if prediction is None:
55
  return "NA", "Invalid SMILES string"
56
 
 
65
  return prediction, "Prediction is done"
66
 
67
  def predict_file(file, property_name):
68
+ property_id = dataset_property_names_to_dataset[property_name]
69
  try:
70
+ adapter_id = candidate_models[property_id]
71
+ info = model.swith_adapter(property_id, adapter_id)
72
 
73
  running_status = None
74
  if info == "keep":
 
87
 
88
  df = pd.read_csv(file)
89
  # we have already checked the file contains the "smiles" column
90
+ df = model.predict_file(df, dataset_task_types[property_id])
91
  # we should save this file to the disk to be downloaded
92
  # rename the file to have "_prediction" suffix
93
  prediction_file = file.replace(".csv", "_prediction.csv") if file.endswith(".csv") else file.replace(".smi", "_prediction.csv")
 
163
  with gr.Blocks() as demo:
164
  # first row - Dropdown input
165
  #with gr.Row():
166
+ print(property_names[0].lower())
167
+ print(properties)
168
+ dropdown = gr.Dropdown(properties, label="Property", value=dataset_property_names[property_names[0].lower()])
169
  description_box = gr.Textbox(label="Property description", lines=5,
170
  interactive=False,
171
+ value=dataset_descriptions[property_names[0].lower()])
172
  # third row - Textbox input and prediction label
173
  with gr.Row(equal_height=True):
174
  with gr.Column():
dataset_descriptions.json CHANGED
@@ -1,112 +1,156 @@
1
  {
2
  "ADMET_Caco2_Wang": {
3
  "task_type": "regression",
 
4
  "description": "predict drug permeability, measured in cm/s, using the Caco-2 cell line as an in vitro model to simulate human intestinal tissue permeability",
 
5
  "num_molecules": 906
6
  },
7
  "ADMET_Bioavailability_Ma": {
8
  "task_type": "classification",
 
9
  "description": "predict oral bioavailability with binary labels, indicating the rate and extent a drug becomes available at its site of action",
 
10
  "num_molecules": 640
11
  },
12
  "ADMET_Lipophilicity_AstraZeneca": {
13
  "task_type": "regression",
 
14
  "description": "predict lipophilicity with continuous labels, measured as a log-ratio, indicating a drug's ability to dissolve in lipid environments",
 
15
  "num_molecules": 4200
16
  },
17
  "ADMET_Solubility_AqSolDB": {
18
  "task_type": "regression",
 
19
  "description": "predict aqueous solubility with continuous labels, measured in log mol/L, indicating a drug's ability to dissolve in water",
 
20
  "num_molecules": 9982
21
  },
22
  "ADMET_HIA_Hou": {
23
  "task_type": "classification",
 
24
  "description": "predict human intestinal absorption (HIA) with binary labels, indicating a drug's ability to be absorbed into the bloodstream",
 
25
  "num_molecules": 578
26
  },
27
  "ADMET_Pgp_Broccatelli": {
28
  "task_type": "classification",
 
29
  "description": "predict P-glycoprotein (Pgp) inhibition with binary labels, indicating a drug's potential to alter bioavailability and overcome multidrug resistance",
 
30
  "num_molecules": 1212
31
  },
32
  "ADMET_BBB_Martins": {
33
  "task_type": "classification",
 
34
  "description": "predict blood-brain barrier permeability with binary labels, indicating a drug's ability to penetrate the barrier to reach the brain",
 
35
  "num_molecules": 1915
36
  },
37
  "ADMET_PPBR_AZ": {
38
  "task_type": "regression",
 
39
  "description": "predict plasma protein binding rate with continuous labels, indicating the percentage of a drug bound to plasma proteins in the blood",
 
40
  "num_molecules": 1797
41
  },
42
  "ADMET_VDss_Lombardo": {
43
  "task_type": "regression",
 
44
  "description": "predict the volume of distribution at steady state (VDss), indicating drug concentration in tissues versus blood",
 
45
  "num_molecules": 1130
46
  },
47
  "ADMET_CYP2C9_Veith": {
48
  "task_type": "classification",
 
49
  "description": "predict CYP2C9 inhibition with binary labels, indicating the drug's ability to inhibit the CYP2C9 enzyme involved in metabolism",
 
50
  "num_molecules": 12092
51
  },
52
  "ADMET_CYP2D6_Veith": {
53
  "task_type": "classification",
 
54
  "description": "predict CYP2D6 inhibition with binary labels, indicating the drug's potential to inhibit the CYP2D6 enzyme involved in metabolism",
 
55
  "num_molecules": 13130
56
  },
57
  "ADMET_CYP3A4_Veith": {
58
  "task_type": "classification",
 
59
  "description": "predict CPY3A4 inhibition with binary labels, indicating the drug's ability to inhibit the CPY3A4 enzyme involved in metabolism",
 
60
  "num_molecules": 12328
61
  },
62
  "ADMET_CYP2C9_Substrate_CarbonMangels": {
63
  "task_type": "classification",
 
64
  "description": "predict whether a drug is a substrate of the CYP2C9 enzyme with binary labels, indicating its potential to be metabolized",
 
65
  "num_molecules": 666
66
  },
67
  "ADMET_CYP2D6_Substrate_CarbonMangels": {
68
  "task_type": "classification",
 
69
  "description": "predict whether a drug is a substrate of the CYP2D6 enzyme with binary labels, indicating its potential to be metabolized",
 
70
  "num_molecules": 664
71
  },
72
  "ADMET_CYP3A4_Substrate_CarbonMangels": {
73
  "task_type": "classification",
 
74
  "description": "predict whether a drug is a substrate of the CYP3A4 enzyme with binary labels, indicating its potential to be metabolized",
 
75
  "num_molecules": 667
76
  },
77
  "ADMET_Half_Life_Obach": {
78
  "task_type": "regression",
 
79
  "description": "predict the half-life duration of a drug, measured in hours, indicating the time for its concentration to reduce by half",
 
80
  "num_molecules": 667
81
  },
82
  "ADMET_Clearance_Hepatocyte_AZ": {
83
  "task_type": "regression",
 
84
  "description": "predict drug clearance, measured in \u03bcL/min/10^6 cells, from hepatocyte experiments, indicating the rate at which the drug is removed from the body",
 
85
  "num_molecules": 1020
86
  },
87
  "ADMET_Clearance_Microsome_AZ": {
88
  "task_type": "regression",
 
89
  "description": "predict drug clearance, measured in mL/min/g, from microsome experiments, indicating the rate at which the drug is removed from the body",
 
90
  "num_molecules": 1102
91
  },
92
  "ADMET_LD50_Zhu": {
93
  "task_type": "regression",
 
94
  "description": "predict the acute toxicity of a drug, measured as the dose leading to lethal effects in log(kg/mol)",
 
95
  "num_molecules": 7385
96
  },
97
  "ADMET_hERG": {
98
  "task_type": "classification",
 
99
  "description": "predict whether a drug blocks the hERG channel, which is crucial for heart rhythm, potentially leading to adverse effects",
 
100
  "num_molecules": 648
101
  },
102
  "ADMET_AMES": {
103
  "task_type": "classification",
 
104
  "description": "predict whether a drug is mutagenic with binary labels, indicating its ability to induce genetic alterations",
 
105
  "num_molecules": 7255
106
  },
107
  "ADMET_DILI": {
108
  "task_type": "classification",
 
109
  "description": "predict whether a drug can cause liver injury with binary labels, indicating its potential for hepatotoxicity",
 
110
  "num_molecules": 475
111
  }
112
  }
 
1
  {
2
  "ADMET_Caco2_Wang": {
3
  "task_type": "regression",
4
+ "task_name": "Drug Permeability",
5
  "description": "predict drug permeability, measured in cm/s, using the Caco-2 cell line as an in vitro model to simulate human intestinal tissue permeability",
6
+ "url": "https://tdcommons.ai/single_pred_tasks/adme#caco-2-cell-effective-permeability-wang-et-al",
7
  "num_molecules": 906
8
  },
9
  "ADMET_Bioavailability_Ma": {
10
  "task_type": "classification",
11
+ "task_name": "Drug Oral Bioavailability",
12
  "description": "predict oral bioavailability with binary labels, indicating the rate and extent a drug becomes available at its site of action",
13
+ "url": "https://tdcommons.ai/single_pred_tasks/adme#bioavailability-ma-et-al",
14
  "num_molecules": 640
15
  },
16
  "ADMET_Lipophilicity_AstraZeneca": {
17
  "task_type": "regression",
18
+ "task_name": "Drug Lipophilicity",
19
  "description": "predict lipophilicity with continuous labels, measured as a log-ratio, indicating a drug's ability to dissolve in lipid environments",
20
+ "url": "https://tdcommons.ai/single_pred_tasks/adme#lipophilicity-astrazeneca",
21
  "num_molecules": 4200
22
  },
23
  "ADMET_Solubility_AqSolDB": {
24
  "task_type": "regression",
25
+ "task_name": "Drug Aqueous Solubility",
26
  "description": "predict aqueous solubility with continuous labels, measured in log mol/L, indicating a drug's ability to dissolve in water",
27
+ "url": "https://tdcommons.ai/single_pred_tasks/adme#solubility-aqsoldb",
28
  "num_molecules": 9982
29
  },
30
  "ADMET_HIA_Hou": {
31
  "task_type": "classification",
32
+ "task_name": "Drug Human Intestinal Absorption",
33
  "description": "predict human intestinal absorption (HIA) with binary labels, indicating a drug's ability to be absorbed into the bloodstream",
34
+ "url": "https://tdcommons.ai/single_pred_tasks/adme#hia-human-intestinal-absorption-hou-et-al",
35
  "num_molecules": 578
36
  },
37
  "ADMET_Pgp_Broccatelli": {
38
  "task_type": "classification",
39
+ "task_name": "P-glycoprotein Inhibition",
40
  "description": "predict P-glycoprotein (Pgp) inhibition with binary labels, indicating a drug's potential to alter bioavailability and overcome multidrug resistance",
41
+ "url": "https://tdcommons.ai/single_pred_tasks/adme#pgp-p-glycoprotein-inhibition-broccatelli-et-al",
42
  "num_molecules": 1212
43
  },
44
  "ADMET_BBB_Martins": {
45
  "task_type": "classification",
46
+ "task_name": "Blood-Brain Barrier Permeability",
47
  "description": "predict blood-brain barrier permeability with binary labels, indicating a drug's ability to penetrate the barrier to reach the brain",
48
+ "url": "https://tdcommons.ai/single_pred_tasks/adme#bbb-blood-brain-barrier-martins-et-al",
49
  "num_molecules": 1915
50
  },
51
  "ADMET_PPBR_AZ": {
52
  "task_type": "regression",
53
+ "task_name": "Plasma Protein Binding Rate",
54
  "description": "predict plasma protein binding rate with continuous labels, indicating the percentage of a drug bound to plasma proteins in the blood",
55
+ "url": "https://tdcommons.ai/single_pred_tasks/adme#ppbr-plasma-protein-binding-rate-astrazeneca",
56
  "num_molecules": 1797
57
  },
58
  "ADMET_VDss_Lombardo": {
59
  "task_type": "regression",
60
+ "task_name": "Volume of Distribution at Steady State",
61
  "description": "predict the volume of distribution at steady state (VDss), indicating drug concentration in tissues versus blood",
62
+ "url": "https://tdcommons.ai/single_pred_tasks/adme#vdss-volumn-of-distribution-at-steady-state-lombardo-et-al",
63
  "num_molecules": 1130
64
  },
65
  "ADMET_CYP2C9_Veith": {
66
  "task_type": "classification",
67
+ "task_name": "CYP2C9 Inhibition",
68
  "description": "predict CYP2C9 inhibition with binary labels, indicating the drug's ability to inhibit the CYP2C9 enzyme involved in metabolism",
69
+ "url": "https://tdcommons.ai/single_pred_tasks/adme#cyp-p450-2c9-inhibition-veith-et-al",
70
  "num_molecules": 12092
71
  },
72
  "ADMET_CYP2D6_Veith": {
73
  "task_type": "classification",
74
+ "task_name": "CYP2D6 Inhibition",
75
  "description": "predict CYP2D6 inhibition with binary labels, indicating the drug's potential to inhibit the CYP2D6 enzyme involved in metabolism",
76
+ "url": "https://tdcommons.ai/single_pred_tasks/adme#cyp-p450-2d6-inhibition-veith-et-al",
77
  "num_molecules": 13130
78
  },
79
  "ADMET_CYP3A4_Veith": {
80
  "task_type": "classification",
81
+ "task_name": "CYP3A4 Inhibition",
82
  "description": "predict CYP3A4 inhibition with binary labels, indicating the drug's ability to inhibit the CYP3A4 enzyme involved in metabolism",
83
+ "url": "https://tdcommons.ai/single_pred_tasks/adme#cyp-p450-3a4-inhibition-veith-et-al",
84
  "num_molecules": 12328
85
  },
86
  "ADMET_CYP2C9_Substrate_CarbonMangels": {
87
  "task_type": "classification",
88
+ "task_name": "CYP2C9 Substrate",
89
  "description": "predict whether a drug is a substrate of the CYP2C9 enzyme with binary labels, indicating its potential to be metabolized",
90
+ "url": "https://tdcommons.ai/single_pred_tasks/adme#cyp2c9-substrate-carbon-mangels-et-al",
91
  "num_molecules": 666
92
  },
93
  "ADMET_CYP2D6_Substrate_CarbonMangels": {
94
  "task_type": "classification",
95
+ "task_name": "CYP2D6 Substrate",
96
  "description": "predict whether a drug is a substrate of the CYP2D6 enzyme with binary labels, indicating its potential to be metabolized",
97
+ "url": "https://tdcommons.ai/single_pred_tasks/adme#cyp2d6-substrate-carbon-mangels-et-al",
98
  "num_molecules": 664
99
  },
100
  "ADMET_CYP3A4_Substrate_CarbonMangels": {
101
  "task_type": "classification",
102
+ "task_name": "CYP3A4 Substrate",
103
  "description": "predict whether a drug is a substrate of the CYP3A4 enzyme with binary labels, indicating its potential to be metabolized",
104
+ "url": "https://tdcommons.ai/single_pred_tasks/adme#cyp3a4-substrate-carbon-mangels-et-al",
105
  "num_molecules": 667
106
  },
107
  "ADMET_Half_Life_Obach": {
108
  "task_type": "regression",
109
+ "task_name": "Drug Half-Life Duration",
110
  "description": "predict the half-life duration of a drug, measured in hours, indicating the time for its concentration to reduce by half",
111
+ "url": "https://tdcommons.ai/single_pred_tasks/adme#half-life-obach-et-al",
112
  "num_molecules": 667
113
  },
114
  "ADMET_Clearance_Hepatocyte_AZ": {
115
  "task_type": "regression",
116
+ "task_name": "Drug Clearance from Hepatocyte Experiments",
117
  "description": "predict drug clearance, measured in \u03bcL/min/10^6 cells, from hepatocyte experiments, indicating the rate at which the drug is removed from the body",
118
+ "url": "https://tdcommons.ai/single_pred_tasks/adme#clearance-astrazeneca",
119
  "num_molecules": 1020
120
  },
121
  "ADMET_Clearance_Microsome_AZ": {
122
  "task_type": "regression",
123
+ "task_name": "Drug Clearance from Microsome Experiments",
124
  "description": "predict drug clearance, measured in mL/min/g, from microsome experiments, indicating the rate at which the drug is removed from the body",
125
+ "url": "https://tdcommons.ai/single_pred_tasks/adme#clearance-astrazeneca",
126
  "num_molecules": 1102
127
  },
128
  "ADMET_LD50_Zhu": {
129
  "task_type": "regression",
130
+ "task_name": "Drug Acute Toxicity",
131
  "description": "predict the acute toxicity of a drug, measured as the dose leading to lethal effects in log(kg/mol)",
132
+ "url": "https://tdcommons.ai/single_pred_tasks/tox#acute-toxicity-ld50",
133
  "num_molecules": 7385
134
  },
135
  "ADMET_hERG": {
136
  "task_type": "classification",
137
+ "task_name": "hERG Channel Blockage",
138
  "description": "predict whether a drug blocks the hERG channel, which is crucial for heart rhythm, potentially leading to adverse effects",
139
+ "url": "https://tdcommons.ai/single_pred_tasks/tox#herg-blockers",
140
  "num_molecules": 648
141
  },
142
  "ADMET_AMES": {
143
  "task_type": "classification",
144
+ "task_name": "Drug Mutagenicity",
145
  "description": "predict whether a drug is mutagenic with binary labels, indicating its ability to induce genetic alterations",
146
+ "url": "https://tdcommons.ai/single_pred_tasks/tox#ames-mutagenicity",
147
  "num_molecules": 7255
148
  },
149
  "ADMET_DILI": {
150
  "task_type": "classification",
151
+ "task_name": "Drug-Induced Liver Injury",
152
  "description": "predict whether a drug can cause liver injury with binary labels, indicating its potential for hepatotoxicity",
153
+ "url": "https://tdcommons.ai/single_pred_tasks/tox#dili-drug-induced-liver-injury",
154
  "num_molecules": 475
155
  }
156
  }
utils.py CHANGED
@@ -39,22 +39,30 @@ from rdkit import RDLogger, Chem
39
  RDLogger.DisableLog('rdApp.*')
40
 
41
  # we have a dictionary to store the task types of the models
42
- task_types = {
43
- "admet_ppbr_az": "regression",
44
- "admet_half_life_obach": "regression",
45
- }
 
46
 
47
  # read the dataset descriptions
48
  with open("dataset_descriptions.json", "r") as f:
49
  dataset_description_temp = json.load(f)
50
 
51
  dataset_descriptions = dict()
 
 
 
52
 
53
  for dataset in dataset_description_temp:
54
  dataset_name = dataset.lower()
55
  dataset_descriptions[dataset_name] = \
56
- f"{dataset_name} is a {dataset_description_temp[dataset]['task_type']} task, " + \
57
- f"where the goal is to {dataset_description_temp[dataset]['description']}."
 
 
 
 
58
 
59
  class Scaler:
60
  def __init__(self, log=False):
@@ -215,7 +223,11 @@ class MolecularPropertyPredictionModel():
215
  adapter_id = candidate_models[adapter_name]
216
  print(f"loading {adapter_name} from {adapter_id}...")
217
  self.base_model.load_adapter(adapter_id, adapter_name=adapter_name, token = os.environ.get("TOKEN"))
218
- self.apapter_scaler_path[adapter_name] = hf_hub_download(adapter_id, filename="scaler.pkl", token = os.environ.get("TOKEN"))
 
 
 
 
219
 
220
  #self.base_model.to("cuda")
221
  #print(self.base_model)
@@ -242,7 +254,7 @@ class MolecularPropertyPredictionModel():
242
 
243
  #if adapter_name not in self.apapter_scaler_path:
244
  # self.apapter_scaler_path[adapter_name] = hf_hub_download(adapter_id, filename="scaler.pkl", token = os.environ.get("TOKEN"))
245
- if os.path.exists(self.apapter_scaler_path[adapter_name]):
246
  self.scaler = pickle.load(open(self.apapter_scaler_path[adapter_name], "rb"))
247
  else:
248
  self.scaler = None
@@ -276,7 +288,7 @@ class MolecularPropertyPredictionModel():
276
  if task_type == "regression": # TODO: check if the model is regression or classification
277
  y_pred.append(outputs.logits.cpu().detach().numpy())
278
  else:
279
- y_pred.append((torch.sigmoid(outputs.logits) > 0.5).cpu().detach().numpy())
280
 
281
  y_pred = np.concatenate(y_pred, axis=0)
282
  if task_type=="regression" and self.scaler is not None:
 
39
  RDLogger.DisableLog('rdApp.*')
40
 
41
  # we have a dictionary to store the task types of the models
42
+ #task_types = {
43
+ # "admet_bioavailability_ma": "classification",
44
+ # "admet_ppbr_az": "regression",
45
+ # "admet_half_life_obach": "regression",
46
+ #}
47
 
48
  # read the dataset descriptions
49
  with open("dataset_descriptions.json", "r") as f:
50
  dataset_description_temp = json.load(f)
51
 
52
  dataset_descriptions = dict()
53
+ dataset_property_names = dict()
54
+ dataset_task_types = dict()
55
+ dataset_property_names_to_dataset = dict()
56
 
57
  for dataset in dataset_description_temp:
58
  dataset_name = dataset.lower()
59
  dataset_descriptions[dataset_name] = \
60
+ f"{dataset_description_temp[dataset]['task_name']} is a {dataset_description_temp[dataset]['task_type']} task, " + \
61
+ f"where the goal is to {dataset_description_temp[dataset]['description']}. \n" + \
62
+ f"More information can be found at {dataset_description_temp[dataset]['url']}."
63
+ dataset_property_names[dataset_name] = dataset_description_temp[dataset]['task_name']
64
+ dataset_property_names_to_dataset[dataset_description_temp[dataset]['task_name']] = dataset_name
65
+ dataset_task_types[dataset_name] = dataset_description_temp[dataset]['task_type']
66
 
67
  class Scaler:
68
  def __init__(self, log=False):
 
223
  adapter_id = candidate_models[adapter_name]
224
  print(f"loading {adapter_name} from {adapter_id}...")
225
  self.base_model.load_adapter(adapter_id, adapter_name=adapter_name, token = os.environ.get("TOKEN"))
226
+ try:
227
+ self.apapter_scaler_path[adapter_name] = hf_hub_download(adapter_id, filename="scaler.pkl", token = os.environ.get("TOKEN"))
228
+ except:
229
+ self.apapter_scaler_path[adapter_name] = None
230
+ assert dataset_task_types[adapter_name] == "classification", f"{adapter_name} is not a regression task."
231
 
232
  #self.base_model.to("cuda")
233
  #print(self.base_model)
 
254
 
255
  #if adapter_name not in self.apapter_scaler_path:
256
  # self.apapter_scaler_path[adapter_name] = hf_hub_download(adapter_id, filename="scaler.pkl", token = os.environ.get("TOKEN"))
257
+ if self.apapter_scaler_path[adapter_name] and os.path.exists(self.apapter_scaler_path[adapter_name]):
258
  self.scaler = pickle.load(open(self.apapter_scaler_path[adapter_name], "rb"))
259
  else:
260
  self.scaler = None
 
288
  if task_type == "regression": # TODO: check if the model is regression or classification
289
  y_pred.append(outputs.logits.cpu().detach().numpy())
290
  else:
291
+ y_pred.append((torch.sigmoid(outputs.logits)).cpu().detach().numpy())
292
 
293
  y_pred = np.concatenate(y_pred, axis=0)
294
  if task_type=="regression" and self.scaler is not None: