diff --git a/configs/callbacks/csv_prediction_writer.yaml b/configs/callbacks/csv_prediction_writer.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b7fc6fb93940f0910938553caec31c02a3aeb60b --- /dev/null +++ b/configs/callbacks/csv_prediction_writer.yaml @@ -0,0 +1,4 @@ +csv_prediction_writer: + _target_: deepscreen.utils.lightning.CSVPredictionWriter + output_dir: ${paths.output_dir} + write_interval: batch diff --git a/configs/callbacks/default.yaml b/configs/callbacks/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f94c639e898f22416a21af826e6453b99fbf01f1 --- /dev/null +++ b/configs/callbacks/default.yaml @@ -0,0 +1,5 @@ +defaults: + - model_checkpoint + - early_stopping + - model_summary + - rich_progress_bar diff --git a/configs/callbacks/early_stopping.yaml b/configs/callbacks/early_stopping.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1d2bb37ad3f056b1fb6cff81f162f7222a3c5a3e --- /dev/null +++ b/configs/callbacks/early_stopping.yaml @@ -0,0 +1,17 @@ +# https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.EarlyStopping.html + +# Monitor a metric and stop training when it stops improving. +# Look at the above link for more detailed information. +early_stopping: + _target_: lightning.pytorch.callbacks.EarlyStopping + monitor: ${oc.select:callbacks.model_checkpoint.monitor,"val/loss"} # quantity to be monitored, must be specified!!! + min_delta: 0. 
# minimum change in the monitored quantity to qualify as an improvement + patience: 50 # number of checks with no improvement after which training will be stopped + verbose: False # verbosity mode + mode: ${oc.select:callbacks.model_checkpoint.mode,"min"} # "max" means higher metric value is better, can be also "min" + strict: True # whether to crash the training if monitor is not found in the validation metrics + check_finite: True # when set True, stops training when the monitor becomes NaN or infinite + stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold + divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold + check_on_train_epoch_end: False # whether to run early stopping at the end of the training epoch + log_rank_zero_only: False # logs the status of the early stopping callback only for rank 0 process diff --git a/configs/callbacks/inference.yaml b/configs/callbacks/inference.yaml new file mode 100644 index 0000000000000000000000000000000000000000..22fe5db38092387f72718e3046e86f57663e8c00 --- /dev/null +++ b/configs/callbacks/inference.yaml @@ -0,0 +1,6 @@ +defaults: + - model_summary + - rich_progress_bar + +model_summary: + max_depth: 2 diff --git a/configs/callbacks/model_checkpoint.yaml b/configs/callbacks/model_checkpoint.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f8e51b2bb6036ce4a0b1ce3dc47c2869514b2eeb --- /dev/null +++ b/configs/callbacks/model_checkpoint.yaml @@ -0,0 +1,19 @@ +# https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.ModelCheckpoint.html + +# Save the model periodically by monitoring a quantity. +# Look at the above link for more detailed information. 
+model_checkpoint: + _target_: lightning.pytorch.callbacks.ModelCheckpoint + dirpath: ${paths.output_dir} # directory to save the model file + filename: "checkpoints/epoch_{epoch:03d}" # checkpoint filename + monitor: ${eval:'"val/loss" if ${data.train_val_test_split}[1] else "train/loss"'} # name of the logged metric which determines when model is improving + verbose: False # verbosity mode + save_last: True # additionally always save an exact copy of the last checkpoint to a file last.ckpt + save_top_k: 1 # save k best models (determined by above metric) + mode: "min" # "max" means higher metric value is better, can be also "min" + auto_insert_metric_name: False # when True, the checkpoints filenames will contain the metric name + save_weights_only: False # if True, then only the model’s weights will be saved + every_n_train_steps: null # number of training steps between checkpoints + train_time_interval: null # checkpoints are monitored at the specified time interval + every_n_epochs: null # number of epochs between checkpoints + save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation diff --git a/configs/callbacks/model_summary.yaml b/configs/callbacks/model_summary.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f854fd9fd17863a8c5930f482dc0d235cfc03698 --- /dev/null +++ b/configs/callbacks/model_summary.yaml @@ -0,0 +1,7 @@ +# https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.RichModelSummary.html + +# Generates a summary of all layers in a LightningModule with rich text formatting. +# Look at the above link for more detailed information. +model_summary: + _target_: lightning.pytorch.callbacks.RichModelSummary + max_depth: 2 # The maximum depth of layer nesting that the summary will include. `-1` for all modules `0` for none. 
diff --git a/configs/callbacks/none.yaml b/configs/callbacks/none.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/configs/callbacks/rich_progress_bar.yaml b/configs/callbacks/rich_progress_bar.yaml new file mode 100644 index 0000000000000000000000000000000000000000..82d2f89d28be7511f2360e1b047656afeee16141 --- /dev/null +++ b/configs/callbacks/rich_progress_bar.yaml @@ -0,0 +1,6 @@ +# https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.RichProgressBar.html + +# Create a progress bar with rich text formatting. +# Look at the above link for more detailed information. +rich_progress_bar: + _target_: lightning.pytorch.callbacks.RichProgressBar diff --git a/configs/callbacks/tqdm_progress_bar.yaml b/configs/callbacks/tqdm_progress_bar.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4f8653fcde6f4c1762dc82dbec824cd19ab097cf --- /dev/null +++ b/configs/callbacks/tqdm_progress_bar.yaml @@ -0,0 +1,2 @@ +tqdm_progress_bar: + _target_: lightning.pytorch.callbacks.TQDMProgressBar diff --git a/configs/data/collator/default.yaml b/configs/data/collator/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..514c0837c358aeb82d22ec2c34a7226d05580423 --- /dev/null +++ b/configs/data/collator/default.yaml @@ -0,0 +1,5 @@ +_target_: deepscreen.data.utils.collator.collate_fn +_partial_: true + +automatic_padding: false +padding_value: 0.0 diff --git a/configs/data/collator/none.yaml b/configs/data/collator/none.yaml new file mode 100644 index 0000000000000000000000000000000000000000..29899b9fc6b0164c149e17b31e350be9135add49 --- /dev/null +++ b/configs/data/collator/none.yaml @@ -0,0 +1,2 @@ +_target_: deepscreen.utils.passthrough +_partial_: true \ No newline at end of file diff --git a/configs/data/drug_featurizer/ecfp.yaml b/configs/data/drug_featurizer/ecfp.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..a341607f3f242e6f993bf758e7341b376f8b4ef3 --- /dev/null +++ b/configs/data/drug_featurizer/ecfp.yaml @@ -0,0 +1,6 @@ +_target_: deepscreen.data.featurizers.fingerprint.smiles_to_fingerprint +_partial_: true + +fingerprint: MorganFP +nBits: 1024 +radius: 2 \ No newline at end of file diff --git a/configs/data/drug_featurizer/fcs.yaml b/configs/data/drug_featurizer/fcs.yaml new file mode 100644 index 0000000000000000000000000000000000000000..77a0cdbec7168829ce846fa28543707a000d0c0d --- /dev/null +++ b/configs/data/drug_featurizer/fcs.yaml @@ -0,0 +1,4 @@ +_target_: deepscreen.data.featurizers.fcs.drug_to_embedding +_partial_: true + +max_sequence_length: 205 \ No newline at end of file diff --git a/configs/data/drug_featurizer/graph.yaml b/configs/data/drug_featurizer/graph.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d3d65a1fbec937c9484584d8f9f89bbc999539f9 --- /dev/null +++ b/configs/data/drug_featurizer/graph.yaml @@ -0,0 +1,2 @@ +_target_: deepscreen.data.featurizers.graph.smiles_to_graph +_partial_: true \ No newline at end of file diff --git a/configs/data/drug_featurizer/label.yaml b/configs/data/drug_featurizer/label.yaml new file mode 100644 index 0000000000000000000000000000000000000000..80ffe2fbdf7038953a8d94e4fa9def36196d66bb --- /dev/null +++ b/configs/data/drug_featurizer/label.yaml @@ -0,0 +1,15 @@ +#_target_: deepscreen.data.featurizers.categorical.smiles_to_label +#_partial_: true +# +#max_sequence_length: 100 +##in_channels: 63 + +_target_: deepscreen.data.featurizers.categorical.sequence_to_label +_partial_: true +charset: ['#', '%', ')', '(', '+', '-', '.', '1', '0', '3', '2', '5', '4', + '7', '6', '9', '8', '=', 'A', 'C', 'B', 'E', 'D', 'G', 'F', 'I', + 'H', 'K', 'M', 'L', 'O', 'N', 'P', 'S', 'R', 'U', 'T', 'W', 'V', + 'Y', '[', 'Z', ']', '_', 'a', 'c', 'b', 'e', 'd', 'g', 'f', 'i', + 'h', 'm', 'l', 'o', 'n', 's', 'r', 'u', 't', 'y'] + +max_sequence_length: 100 \ No newline at 
end of file diff --git a/configs/data/drug_featurizer/mol_features.yaml b/configs/data/drug_featurizer/mol_features.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b0652eace74a5041d7bfa166152e347b9f009c64 --- /dev/null +++ b/configs/data/drug_featurizer/mol_features.yaml @@ -0,0 +1,4 @@ +_target_: deepscreen.data.featurizers.graph.smiles_to_mol_features +_partial_: true + +num_atom_feat: 34 diff --git a/configs/data/drug_featurizer/none.yaml b/configs/data/drug_featurizer/none.yaml new file mode 100644 index 0000000000000000000000000000000000000000..29899b9fc6b0164c149e17b31e350be9135add49 --- /dev/null +++ b/configs/data/drug_featurizer/none.yaml @@ -0,0 +1,2 @@ +_target_: deepscreen.utils.passthrough +_partial_: true \ No newline at end of file diff --git a/configs/data/drug_featurizer/onehot.yaml b/configs/data/drug_featurizer/onehot.yaml new file mode 100644 index 0000000000000000000000000000000000000000..827114b5c04984b2bbd78697dda712a109db3514 --- /dev/null +++ b/configs/data/drug_featurizer/onehot.yaml @@ -0,0 +1,15 @@ +#_target_: deepscreen.data.featurizers.categorical.smiles_to_onehot +#_partial_: true +# +#max_sequence_length: 100 +##in_channels: 63 + +_target_: deepscreen.data.featurizers.categorical.sequence_to_onehot +_partial_: true +charset: ['#', '%', ')', '(', '+', '-', '.', '1', '0', '3', '2', '5', '4', + '7', '6', '9', '8', '=', 'A', 'C', 'B', 'E', 'D', 'G', 'F', 'I', + 'H', 'K', 'M', 'L', 'O', 'N', 'P', 'S', 'R', 'U', 'T', 'W', 'V', + 'Y', '[', 'Z', ']', '_', 'a', 'c', 'b', 'e', 'd', 'g', 'f', 'i', + 'h', 'm', 'l', 'o', 'n', 's', 'r', 'u', 't', 'y'] + +max_sequence_length: 100 \ No newline at end of file diff --git a/configs/data/drug_featurizer/tokenizer.yaml b/configs/data/drug_featurizer/tokenizer.yaml new file mode 100644 index 0000000000000000000000000000000000000000..369aa42da9f2a5f3990dff0097afd0fe8cf5da55 --- /dev/null +++ b/configs/data/drug_featurizer/tokenizer.yaml @@ -0,0 +1,6 @@ +_target_: 
deepscreen.data.featurizers.token.sequence_to_token_ids +_partial_: true + +tokenizer: + _target_: deepscreen.data.featurizers.token.SmilesTokenizer + vocab_file: resources/vocabs/smiles.txt diff --git a/configs/data/dti.yaml.bak b/configs/data/dti.yaml.bak new file mode 100644 index 0000000000000000000000000000000000000000..93682ed7b88a14ac5e7afecb4327aa50d51d29b8 --- /dev/null +++ b/configs/data/dti.yaml.bak @@ -0,0 +1,21 @@ +_target_: deepscreen.data.dti_datamodule.DTIdatamodule + +defaults: + - _self_ + - split: null + - drug_featurizer: null + - protein_featurizer: null + +task: ${task.task} +n_class: ${oc.select:task.task.n_class,null} + +data_dir: ${paths.data_dir} +dataset_name: null + +batch_size: 16 +train_val_test_split: [0.7, 0.1, 0.2] + +num_workers: 0 +pin_memory: false + +train: ${train} \ No newline at end of file diff --git a/configs/data/dti_data.yaml b/configs/data/dti_data.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5b65659c15f8fac1a229e16586ffb226293b4f14 --- /dev/null +++ b/configs/data/dti_data.yaml @@ -0,0 +1,20 @@ +_target_: deepscreen.data.dti.DTIDataModule + +defaults: + - split: null + - drug_featurizer: none # ??? + - protein_featurizer: none # ??? + - collator: default + +task: ${task.task} +num_classes: ${task.num_classes} + +data_dir: ${paths.data_dir} +data_file: null +train_val_test_split: null + +batch_size: ??? 
+num_workers: 0 +pin_memory: false + +#train: ${train} \ No newline at end of file diff --git a/configs/data/protein_featurizer/fcs.yaml b/configs/data/protein_featurizer/fcs.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cbf1f380d2bd951a321971676cb20f7de5e17e7c --- /dev/null +++ b/configs/data/protein_featurizer/fcs.yaml @@ -0,0 +1,4 @@ +_target_: deepscreen.data.featurizers.fcs.protein_to_embedding +_partial_: true + +max_sequence_length: 545 \ No newline at end of file diff --git a/configs/data/protein_featurizer/label.yaml b/configs/data/protein_featurizer/label.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1824297352e8d0706126486f1a00585e78c4f892 --- /dev/null +++ b/configs/data/protein_featurizer/label.yaml @@ -0,0 +1,12 @@ +#_target_: deepscreen.data.featurizers.categorical.fasta_to_label +#_partial_: true +# +#max_sequence_length: 1000 +##in_channels: 26 + +_target_: deepscreen.data.featurizers.categorical.sequence_to_label +_partial_: true +charset: ['A', 'C', 'B', 'E', 'D', 'G', 'F', 'I', 'H', 'K', 'M', 'L', 'O', + 'N', 'Q', 'P', 'S', 'R', 'U', 'T', 'W', 'V', 'Y', 'X', 'Z'] + +max_sequence_length: 1000 \ No newline at end of file diff --git a/configs/data/protein_featurizer/none.yaml b/configs/data/protein_featurizer/none.yaml new file mode 100644 index 0000000000000000000000000000000000000000..29899b9fc6b0164c149e17b31e350be9135add49 --- /dev/null +++ b/configs/data/protein_featurizer/none.yaml @@ -0,0 +1,2 @@ +_target_: deepscreen.utils.passthrough +_partial_: true \ No newline at end of file diff --git a/configs/data/protein_featurizer/onehot.yaml b/configs/data/protein_featurizer/onehot.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d44e8a5fcc0ac0cfd47e17fda8f302e73bcc359d --- /dev/null +++ b/configs/data/protein_featurizer/onehot.yaml @@ -0,0 +1,12 @@ +#_target_: deepscreen.data.featurizers.categorical.fasta_to_onehot +#_partial_: true +# +#max_sequence_length: 1000 
+##in_channels: 26 + +_target_: deepscreen.data.featurizers.categorical.sequence_to_onehot +_partial_: true +charset: ['A', 'C', 'B', 'E', 'D', 'G', 'F', 'I', 'H', 'K', 'M', 'L', 'O', + 'N', 'Q', 'P', 'S', 'R', 'U', 'T', 'W', 'V', 'Y', 'X', 'Z'] + +max_sequence_length: 1000 \ No newline at end of file diff --git a/configs/data/protein_featurizer/tokenizer.yaml b/configs/data/protein_featurizer/tokenizer.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9e264020782f84719e5370dd3cfc10f44a7a236c --- /dev/null +++ b/configs/data/protein_featurizer/tokenizer.yaml @@ -0,0 +1,6 @@ +_target_: deepscreen.data.featurizers.token.sequence_to_token_ids +_partial_: true + +tokenizer: + _target_: tape.TAPETokenizer.from_pretrained + vocab: iupac diff --git a/configs/data/protein_featurizer/word2vec.yaml b/configs/data/protein_featurizer/word2vec.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0d9f4761a374a05ce0f5b20cd80073632f0c030f --- /dev/null +++ b/configs/data/protein_featurizer/word2vec.yaml @@ -0,0 +1,6 @@ +_target_: deepscreen.data.featurizers.word.protein_to_word_embedding +_partial_: true + +model: + _target_: gensim.models.Word2Vec.load + fname: ${paths.resource_dir}/models/word2vec_30.model \ No newline at end of file diff --git a/configs/data/split/cold_drug.yaml b/configs/data/split/cold_drug.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3c0416a876682bb9759822d493e70a21ac0f28f6 --- /dev/null +++ b/configs/data/split/cold_drug.yaml @@ -0,0 +1,4 @@ +_target_: deepscreen.data.utils.split.cold_start +_partial_: true + +entity: drug \ No newline at end of file diff --git a/configs/data/split/cold_protein.yaml b/configs/data/split/cold_protein.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7a3694841cc5cb00549d3563791a872781d0d775 --- /dev/null +++ b/configs/data/split/cold_protein.yaml @@ -0,0 +1,4 @@ +_target_: deepscreen.data.utils.split.cold_start +_partial_: true + 
+entity: protein \ No newline at end of file diff --git a/configs/data/split/none.yaml b/configs/data/split/none.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/configs/data/split/random.yaml b/configs/data/split/random.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e8c00c9b5b38f57fda3b03185a1e91dbefedff8c --- /dev/null +++ b/configs/data/split/random.yaml @@ -0,0 +1,10 @@ +#_target_: torch.utils.data.random_split +#_partial_: true + +#generator: +# _target_: torch.Generator # will use global seed set by lightning.seed_everything or torch.manual_seed automatically + +_target_: deepscreen.data.utils.split.random_split +_partial_: true + +seed: ${seed} diff --git a/configs/data/transform/minmax.yaml b/configs/data/transform/minmax.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3c4e68561d5768d6b1677900484c386ce3b8df7c --- /dev/null +++ b/configs/data/transform/minmax.yaml @@ -0,0 +1,5 @@ +_target_: deepscreen.data.utils.transform +_partial_: true + +scaler: + _target_: sklearn.preprocessing.MinMaxScaler diff --git a/configs/data/transform/none.yaml b/configs/data/transform/none.yaml new file mode 100644 index 0000000000000000000000000000000000000000..29899b9fc6b0164c149e17b31e350be9135add49 --- /dev/null +++ b/configs/data/transform/none.yaml @@ -0,0 +1,2 @@ +_target_: deepscreen.utils.passthrough +_partial_: true \ No newline at end of file diff --git a/configs/debug/advanced.yaml b/configs/debug/advanced.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2c249cc45699de1d1d8f25aa7c873b127e3df6f9 --- /dev/null +++ b/configs/debug/advanced.yaml @@ -0,0 +1,25 @@ +# @package _global_ + +# advanced debug mode that enables callbacks, loggers and gpu during debugging +job_name: "debug" + +extras: + ignore_warnings: False + enforce_tags: False + +hydra: + job_logging: + root: + level: DEBUG + verbose: True + +trainer: + 
max_epochs: 1 + accelerator: gpu + devices: 1 + detect_anomaly: true + deterministic: false + +data: + num_workers: 0 + pin_memory: False diff --git a/configs/debug/default.yaml b/configs/debug/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..84e7dc1e80b355c8ab075a28e65d8fae08127eea --- /dev/null +++ b/configs/debug/default.yaml @@ -0,0 +1,35 @@ +# @package _global_ + +# default debugging setup, runs 1 full epoch +# other debugging configs can inherit from this one + +# overwrite job name so debugging logs are stored in separate folder +job_name: "debug" + +# disable callbacks and loggers during debugging +callbacks: null +logger: null + +extras: + ignore_warnings: False + enforce_tags: False + +# sets level of all command line loggers to 'DEBUG' +# https://hydra.cc/docs/tutorials/basic/running_your_app/logging/ +hydra: + job_logging: + root: + level: DEBUG + # use this to also set hydra loggers to 'DEBUG' + verbose: True + +trainer: + max_epochs: 1 + accelerator: cpu # debuggers don't like gpus + devices: 1 # debuggers don't like multiprocessing + detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor + deterministic: false + +data: + num_workers: 0 # debuggers don't like multiprocessing + pin_memory: False # disable gpu memory pin diff --git a/configs/debug/fdr.yaml b/configs/debug/fdr.yaml new file mode 100644 index 0000000000000000000000000000000000000000..111e2f35924bf77db6706867a5a381cb90e2e855 --- /dev/null +++ b/configs/debug/fdr.yaml @@ -0,0 +1,11 @@ +# @package _global_ + +# runs 1 train, 1 validation and 1 test step + +defaults: + - default + +trainer: + accelerator: gpu + fast_dev_run: true + detect_anomaly: true diff --git a/configs/debug/fdr_advanced.yaml b/configs/debug/fdr_advanced.yaml new file mode 100644 index 0000000000000000000000000000000000000000..570718adfc2148760067ac799cdc3f4ae38f6c78 --- /dev/null +++ b/configs/debug/fdr_advanced.yaml @@ -0,0 +1,11 @@ +# @package _global_ + +# 
runs 1 train, 1 validation and 1 test step + +defaults: + - advanced + +trainer: + accelerator: gpu + fast_dev_run: true + detect_anomaly: true \ No newline at end of file diff --git a/configs/debug/limit.yaml b/configs/debug/limit.yaml new file mode 100644 index 0000000000000000000000000000000000000000..514d77fbd1475b03fff0372e3da3c2fa7ea7d190 --- /dev/null +++ b/configs/debug/limit.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +# uses only 1% of the training data and 5% of validation/test data + +defaults: + - default + +trainer: + max_epochs: 3 + limit_train_batches: 0.01 + limit_val_batches: 0.05 + limit_test_batches: 0.05 diff --git a/configs/debug/overfit.yaml b/configs/debug/overfit.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9906586a67a12aa81ff69138f589a366dbe2222f --- /dev/null +++ b/configs/debug/overfit.yaml @@ -0,0 +1,13 @@ +# @package _global_ + +# overfits to 3 batches + +defaults: + - default + +trainer: + max_epochs: 20 + overfit_batches: 3 + +# model ckpt and early stopping need to be disabled during overfitting +callbacks: null diff --git a/configs/debug/profiler.yaml b/configs/debug/profiler.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2bd7da87ae23ed425ace99b09250a76a5634a3fb --- /dev/null +++ b/configs/debug/profiler.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +# runs with execution time profiling + +defaults: + - default + +trainer: + max_epochs: 1 + profiler: "simple" + # profiler: "advanced" + # profiler: "pytorch" diff --git a/configs/experiment/bindingdb.yaml b/configs/experiment/bindingdb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e3e170d30a9bbdd3d0252b43979b2861aeb39ce1 --- /dev/null +++ b/configs/experiment/bindingdb.yaml @@ -0,0 +1,9 @@ +# @package _global_ +defaults: + - dti_experiment + - override /task: binary + +data: + train_val_test_split: [dti_benchmark/random_split_update/bindingdb_train.csv, + 
dti_benchmark/random_split_update/bindingdb_valid.csv, + dti_benchmark/random_split_update/bindingdb_test.csv] diff --git a/configs/experiment/chembl_random.yaml b/configs/experiment/chembl_random.yaml new file mode 100644 index 0000000000000000000000000000000000000000..01b9e8aca42b0925a22303d9f2c3f8c4f5998340 --- /dev/null +++ b/configs/experiment/chembl_random.yaml @@ -0,0 +1,9 @@ +# @package _global_ +defaults: + - dti_experiment + - override /task: binary + +data: + train_val_test_split: [chembl_random_global_balance_1_train.csv, + chembl_random_global_balance_1_valid.csv, + chembl_random_global_balance_1_test.csv] diff --git a/configs/experiment/chembl_rmfh_random.yaml b/configs/experiment/chembl_rmfh_random.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a15e8ebfac7db334dc82d22adae2181c1473d9c8 --- /dev/null +++ b/configs/experiment/chembl_rmfh_random.yaml @@ -0,0 +1,9 @@ +# @package _global_ +defaults: + - dti_experiment + - override /task: binary + +data: + train_val_test_split: [chembl_rmFH_random_global_balance_1_train.csv, + chembl_rmFH_random_global_balance_1_valid.csv, + chembl_rmFH_random_global_balance_1_test.csv] \ No newline at end of file diff --git a/configs/experiment/davis.yaml b/configs/experiment/davis.yaml new file mode 100644 index 0000000000000000000000000000000000000000..95c628e21ba17d808363a8f89f36ad0513e6a1de --- /dev/null +++ b/configs/experiment/davis.yaml @@ -0,0 +1,9 @@ +# @package _global_ +defaults: + - dti_experiment + - override /task: binary + +data: + train_val_test_split: [dti_benchmark/random_split_update/davis_train.csv, + dti_benchmark/random_split_update/davis_valid.csv, + dti_benchmark/random_split_update/davis_test.csv] diff --git a/configs/experiment/demo_bindingdb.yaml b/configs/experiment/demo_bindingdb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0ccc1bcdfd5a12025a0ff6aa591a6989f3db2982 --- /dev/null +++ b/configs/experiment/demo_bindingdb.yaml @@ -0,0 +1,9 @@ 
+# @package _global_ +defaults: + - dti_experiment + - override /task: binary + - override /data/split: random + +data: + data_file: demo/binddb_ic50_demo.csv + train_val_test_split: [0.7, 0.1, 0.2] diff --git a/configs/experiment/dti_experiment.yaml b/configs/experiment/dti_experiment.yaml new file mode 100644 index 0000000000000000000000000000000000000000..df70106572153a3444128f46023e58cdb6001a04 --- /dev/null +++ b/configs/experiment/dti_experiment.yaml @@ -0,0 +1,19 @@ +# @package _global_ +defaults: + - override /data: dti_data + - override /model: dti_model + - override /trainer: gpu + +seed: 12345 + +trainer: + min_epochs: 1 + max_epochs: 500 + precision: 16-mixed + +callbacks: + early_stopping: + patience: 50 + +data: + num_workers: 8 diff --git a/configs/experiment/example.yaml b/configs/experiment/example.yaml new file mode 100644 index 0000000000000000000000000000000000000000..03babeb44c0eec70b050e1f92534b6cd8de770d8 --- /dev/null +++ b/configs/experiment/example.yaml @@ -0,0 +1,35 @@ +# @package _global_ + +# to execute this experiment run: +# python train.py experiment=example + +defaults: + - override /data: dti_data + - override /data/drug_featurizer: onehot + - override /data/protein_featurizer: onehot + - override /model: dti_model + - override /model/protein_encoder: cnn + - override /model/drug_encoder: cnn + - override /model/decoder: concat_mlp + - override /callbacks: default + - override /trainer: default + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters + +tags: ["dti"] + +seed: 12345 + +data: + data_file: davis.csv + batch_size: 64 + +model: + optimizer: + lr: 0.0001 + +trainer: + min_epochs: 1 + max_epochs: 100 + accelerator: gpu \ No newline at end of file diff --git a/configs/experiment/ion_channels.yaml b/configs/experiment/ion_channels.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..261d7dac49ccb908a63719c01fd04374e4d21019 --- /dev/null +++ b/configs/experiment/ion_channels.yaml @@ -0,0 +1,9 @@ +# @package _global_ +defaults: + - dti_experiment + - override /task: binary + +data: + train_val_test_split: [dti_benchmark/ChEMBL33/train/Ion_channels_train_data.csv, + dti_benchmark/ChEMBL33/valid/Ion_channels_valid_data.csv, + dti_benchmark/ChEMBL33/test/Ion_channels_both_unseen_test_data.csv] diff --git a/configs/experiment/kiba.yaml b/configs/experiment/kiba.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f3a83942f41b75c800bf467690b88b8397955615 --- /dev/null +++ b/configs/experiment/kiba.yaml @@ -0,0 +1,9 @@ +# @package _global_ +defaults: + - dti_experiment + - override /task: binary + +data: + train_val_test_split: [dti_benchmark/random_split_update/kiba_train.csv, + dti_benchmark/random_split_update/kiba_valid.csv, + dti_benchmark/random_split_update/kiba_test.csv] diff --git a/configs/experiment/kinase.yaml b/configs/experiment/kinase.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c7d129aca1c9d93934efbe61a243cf7603d30143 --- /dev/null +++ b/configs/experiment/kinase.yaml @@ -0,0 +1,13 @@ +# @package _global_ +defaults: + - dti_experiment + - override /task: binary + +data: + train_val_test_split: + - dti_benchmark/ChEMBL33/train/kinase_train_data.csv + - null + - null +# dti_benchmark/ChEMBL33/valid/kinase_valid_data.csv, +# dti_benchmark/ChEMBL33/test/kinase_both_unseen_test_data.csv + diff --git a/configs/experiment/membrane_receptors.yaml b/configs/experiment/membrane_receptors.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a06d357351dba800c885d7d92e51a7a7ccffc474 --- /dev/null +++ b/configs/experiment/membrane_receptors.yaml @@ -0,0 +1,13 @@ +# @package _global_ +defaults: + - dti_experiment + - override /task: binary + +data: + train_val_test_split: + - dti_benchmark/ChEMBL33/train/Membrane_receptor_train_data.csv + - null + 
- null +# dti_benchmark/ChEMBL33/valid/Membrane_receptor_valid_data.csv, +# dti_benchmark/ChEMBL33/test/Membrane_receptor_drug_repo_test_data.csv + diff --git a/configs/experiment/non_kinase_enzymes.yaml b/configs/experiment/non_kinase_enzymes.yaml new file mode 100644 index 0000000000000000000000000000000000000000..314eeb66a368dfcab29313a647e620ee5bcd7cd9 --- /dev/null +++ b/configs/experiment/non_kinase_enzymes.yaml @@ -0,0 +1,13 @@ +# @package _global_ +defaults: + - dti_experiment + - override /task: binary + +data: + train_val_test_split: + - dti_benchmark/ChEMBL33/train/Non_kinase_enzyme_train_data.csv + - null + - null +# dti_benchmark/ChEMBL33/valid/Non_kinase_enzyme_valid_data.csv, +# dti_benchmark/ChEMBL33/test/Non_kinase_enzyme_both_unseen_test_data.csv + diff --git a/configs/experiment/nuclear_receptors.yaml b/configs/experiment/nuclear_receptors.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2edf8a4dea610a9a7c2a9b73178f452888bb20c2 --- /dev/null +++ b/configs/experiment/nuclear_receptors.yaml @@ -0,0 +1,9 @@ +# @package _global_ +defaults: + - dti_experiment + - override /task: binary + +data: + train_val_test_split: [dti_benchmark/ChEMBL33/train/Nuclear_receptors_train_data.csv, + dti_benchmark/ChEMBL33/valid/Nuclear_receptors_valid_data.csv, + dti_benchmark/ChEMBL33/test/Nuclear_receptors_both_unseen_test_data.csv] diff --git a/configs/experiment/other_protein_targets.yaml b/configs/experiment/other_protein_targets.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ffe3f295621ee3ffc160521e88e2bdf547719190 --- /dev/null +++ b/configs/experiment/other_protein_targets.yaml @@ -0,0 +1,9 @@ +# @package _global_ +defaults: + - dti_experiment + - override /task: binary + +data: + train_val_test_split: [dti_benchmark/ChEMBL33/train/Other_protein_targets_train_data.csv, + dti_benchmark/ChEMBL33/valid/Other_protein_targets_valid_data.csv, + 
dti_benchmark/ChEMBL33/test/Other_protein_targets_both_unseen_test_data.csv] diff --git a/configs/extras/default.yaml b/configs/extras/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b9c6b622283a647fbc513166fc14f016cc3ed8a0 --- /dev/null +++ b/configs/extras/default.yaml @@ -0,0 +1,8 @@ +# disable python warnings if they annoy you +ignore_warnings: False + +# ask user for tags if none are provided in the config +enforce_tags: True + +# pretty print config tree at the start of the run using Rich library +print_config: True diff --git a/configs/hydra/callbacks/csv_experiment_summary.yaml b/configs/hydra/callbacks/csv_experiment_summary.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1b0d00b4f7c73dfdc02846e428fc6793133cc549 --- /dev/null +++ b/configs/hydra/callbacks/csv_experiment_summary.yaml @@ -0,0 +1,3 @@ +csv_experiment_summary: + _target_: deepscreen.utils.hydra.CSVExperimentSummary + prefix: ['test/', 'epoch'] \ No newline at end of file diff --git a/configs/hydra/callbacks/default.yaml b/configs/hydra/callbacks/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3e7617d233fb7f325cdd03caf82bcc0e2ba6bf6d --- /dev/null +++ b/configs/hydra/callbacks/default.yaml @@ -0,0 +1,2 @@ +defaults: + - csv_experiment_summary diff --git a/configs/hydra/default.yaml b/configs/hydra/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8d76171f4585f3c26e6c249ba9d4e1617994b421 --- /dev/null +++ b/configs/hydra/default.yaml @@ -0,0 +1,27 @@ +# https://hydra.cc/docs/configure_hydra/intro/ + +# enable color logging +defaults: + - override callbacks: default + - override hydra_logging: colorlog + - override job_logging: colorlog + +# output directory, generated dynamically on each run +run: + dir: ${paths.log_dir}/${job_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S-%f}_[${eval:'",".join(${tags})'}] +sweep: + dir: 
${paths.log_dir}/${job_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S-%f}_[${eval:'",".join(${tags})'}] + # Sanitize override_dirname by replacing unsafe characters to avoid unintended subdirectory creation + subdir: ${sanitize_path:'${hydra:job.id}-${hydra:job.override_dirname}'} + +job_logging: + handlers: + file: + filename: ${hydra:runtime.output_dir}/${hydra.job.name}.log + +job: + config: + override_dirname: + kv_sep: '=' + item_sep: ';' + exclude_keys: ['tags', 'sweep', 'data.data_file', 'data.train_val_test_split', 'ckpt_path', 'trainer'] diff --git a/configs/hydra/launcher/submitit_local_example.yaml b/configs/hydra/launcher/submitit_local_example.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2f47752804b52ea1216d750ac3229d12e463aaff --- /dev/null +++ b/configs/hydra/launcher/submitit_local_example.yaml @@ -0,0 +1,12 @@ +# @package _global_ +defaults: + - submitit_local + +submitit_folder: ${hydra.sweep.dir}/.submitit/%j +timeout_min: 60 +cpus_per_task: 1 +gpus_per_node: 1 +tasks_per_node: 8 +mem_gb: 16 +nodes: 1 +name: ${hydra.job.name} diff --git a/configs/hydra/launcher/submitit_slurm_example.yaml b/configs/hydra/launcher/submitit_slurm_example.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1dec61801ed47679ea6598ef2540965e110249c4 --- /dev/null +++ b/configs/hydra/launcher/submitit_slurm_example.yaml @@ -0,0 +1,31 @@ +# @package _global_ +defaults: + - submitit_slurm + +hydra: + mode: "MULTIRUN" + launcher: + submitit_folder: ${hydra.sweep.dir}/.submitit/%j + timeout_min: null + cpus_per_task: null + gpus_per_node: null + tasks_per_node: 1 + mem_gb: null + nodes: 1 + name: ${hydra.job.name} + partition: null + qos: null + comment: null + constraint: null + exclude: null + gres: null + cpus_per_gpu: null + gpus_per_task: null + mem_per_gpu: null + mem_per_cpu: null + account: null + signal_delay_s: 120 + max_num_timeout: 0 + additional_parameters: {} + array_parallelism: 256 + setup: null diff --git 
a/configs/hydra/sweeper/optuna_hps.yaml b/configs/hydra/sweeper/optuna_hps.yaml new file mode 100644 index 0000000000000000000000000000000000000000..316cf24ab252557a08a5e82fa8c9c462b9835246 --- /dev/null +++ b/configs/hydra/sweeper/optuna_hps.yaml @@ -0,0 +1,44 @@ +# @package _global_ + +# example batch experiment of some experiment with Optuna: +# python train.py -m sweep=optuna experiment=example + +defaults: + - optuna + +# choose metric which will be optimized by Optuna +# make sure this is the correct name of some metric logged in lightning module! +objective_metrics: ["val/auroc"] + +# here we define Optuna hyperparameter search +# it optimizes for value returned from function with @hydra.main decorator +# docs: https://hydra.cc/docs/next/plugins/optuna_sweeper +hydra: + mode: "MULTIRUN" # set hydra to multirun by default if this config is attached + sweeper: + # storage URL to persist optimization results + # for example, you can use SQLite if you set 'sqlite:///example.db' + storage: null + + # name of the study to persist optimization results + study_name: null + + # number of parallel workers + n_jobs: 1 + + # 'minimize' or 'maximize' the objective + direction: minimize + + # total number of runs that will be executed + n_trials: 20 + + # choose Optuna hyperparameter sampler + # you can choose bayesian sampler (tpe), random search (without optimization), grid sampler, and others + # docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html + sampler: + _target_: optuna.samplers.TPESampler + seed: 12345 + n_startup_trials: 10 # number of random sampling runs before optimization starts + + # define hyperparameter search space + params: ??? 
diff --git a/configs/local/.gitkeep b/configs/local/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/configs/local/slurm_example.yaml b/configs/local/slurm_example.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5600310d09dc747aaba7616f3402c6de096ce918 --- /dev/null +++ b/configs/local/slurm_example.yaml @@ -0,0 +1,21 @@ +# @package _global_ +defaults: + - override /hydra/launcher: submitit_slurm + +hydra: + launcher: + submitit_folder: ${hydra.sweep.dir}/.submitit + name: "${hydra.job.name} ${eval:\"' '.join(${hydra.overrides.task}\")}" + timeout_min: 6000 + cpus_per_task: 4 + gpus_per_task: 1 + gres: gpu:1 + partition: gpu3090 + qos: gpu3090 + additional_parameters: + error: ${hydra.launcher.submitit_folder}/%j_%t.log + output: ${hydra.launcher.submitit_folder}/%j_%t.log + array_parallelism: 256 + +trainer: + precision: bf16 diff --git a/configs/logger/comet.yaml b/configs/logger/comet.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ebe6c6f19bd172255e9f29148df10a4d2ec42c84 --- /dev/null +++ b/configs/logger/comet.yaml @@ -0,0 +1,12 @@ +# https://www.comet.ml + +comet: + _target_: lightning.pytorch.loggers.comet.CometLogger + api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable + save_dir: "${paths.output_dir}" + project_name: "deepscreen" + rest_api_key: null + # experiment_name: "" + experiment_key: null # set to resume experiment + offline: False + prefix: "" \ No newline at end of file diff --git a/configs/logger/csv.yaml b/configs/logger/csv.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bf9af6fb5e70b1382a1421397b12ea677dd8926f --- /dev/null +++ b/configs/logger/csv.yaml @@ -0,0 +1,8 @@ +# csv logger built in lightning + +csv: + _target_: lightning.pytorch.loggers.csv_logs.CSVLogger + save_dir: "${paths.output_dir}" + name: "csv/" + prefix: "" + version: "" diff --git 
a/configs/logger/default.yaml b/configs/logger/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bfdb2248689fcc8c5827c6c99c0d42bfad37b3c8 --- /dev/null +++ b/configs/logger/default.yaml @@ -0,0 +1,5 @@ +defaults: + - csv +# - mlflow +# - wandb +# - comet \ No newline at end of file diff --git a/configs/logger/mlflow.yaml b/configs/logger/mlflow.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bed42abc9cd686d89137e236a009c2c9277f2608 --- /dev/null +++ b/configs/logger/mlflow.yaml @@ -0,0 +1,12 @@ +# https://mlflow.org + +mlflow: + _target_: lightning.pytorch.loggers.mlflow.MLFlowLogger + # experiment_name: "" + # run_name: "" + tracking_uri: "file://${paths.output_dir}/mlflow/" # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI + tags: ${tags} + # save_dir: "./mlruns" + prefix: "" + artifact_location: null + # run_id: "" \ No newline at end of file diff --git a/configs/logger/multiple_loggers.yaml b/configs/logger/multiple_loggers.yaml new file mode 100644 index 0000000000000000000000000000000000000000..40d561d1b723cb314de4272210a96c4343dcace9 --- /dev/null +++ b/configs/logger/multiple_loggers.yaml @@ -0,0 +1,6 @@ +# train with multiple loggers at once +defaults: + - csv + - tensorboard +# - mlflow +# - wandb \ No newline at end of file diff --git a/configs/logger/neptune.yaml b/configs/logger/neptune.yaml new file mode 100644 index 0000000000000000000000000000000000000000..086e85fc59117f4a8035d0c1773e4c856fd7e995 --- /dev/null +++ b/configs/logger/neptune.yaml @@ -0,0 +1,9 @@ +# https://neptune.ai + +neptune: + _target_: lightning.pytorch.loggers.neptune.NeptuneLogger + api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable + project: username/deepscreen + # name: "" + log_model_checkpoints: True + prefix: "" \ No newline at end of file diff --git a/configs/logger/tensorboard.yaml b/configs/logger/tensorboard.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..f438664df015f88bd3cabffaa520da3ef98dc379 --- /dev/null +++ b/configs/logger/tensorboard.yaml @@ -0,0 +1,10 @@ +# https://www.tensorflow.org/tensorboard/ + +tensorboard: + _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger + save_dir: "${paths.output_dir}/tensorboard/" + name: null + log_graph: False + default_hp_metric: True + prefix: "" + version: "" diff --git a/configs/logger/wandb.yaml b/configs/logger/wandb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9ec9ba26d7673c20005d8e67e3bed2b647c80218 --- /dev/null +++ b/configs/logger/wandb.yaml @@ -0,0 +1,16 @@ +# https://wandb.ai + +wandb: + _target_: lightning.pytorch.loggers.wandb.WandbLogger + # name: "" # name of the run (normally generated by wandb) + save_dir: "${paths.output_dir}" + offline: True + id: null # pass correct id to resume experiment! + anonymous: null # enable anonymous logging + project: "deepscreen" + log_model: False # upload lightning ckpts + prefix: "" # a string to put at the beginning of metric keys + # entity: "" # set to name of your wandb team + group: "" + tags: ${tags} + job_type: "" \ No newline at end of file diff --git a/configs/model/dti_model.yaml b/configs/model/dti_model.yaml new file mode 100644 index 0000000000000000000000000000000000000000..72b1c9ee2b75a2f6ffce9a33b8d42cafd9cc67f0 --- /dev/null +++ b/configs/model/dti_model.yaml @@ -0,0 +1,12 @@ +_target_: deepscreen.models.dti.DTILightningModule + +defaults: + - _self_ + - optimizer: adam + - scheduler: default + - predictor: none + - metrics: dti_metrics + +out: ${task.out} +loss: ${task.loss} +activation: ${task.activation} diff --git a/configs/model/loss/multitask_loss.yaml b/configs/model/loss/multitask_loss.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6dfe6353808f3f91a42e151a34a3cae056e0fc7e --- /dev/null +++ b/configs/model/loss/multitask_loss.yaml @@ -0,0 +1,7 @@ +_target_: 
deepscreen.models.loss.multitask_loss.MultitaskLoss + +loss_fns: + - _target_: torch.nn.MSELoss + - _target_: torch.nn.CrossEntropyLoss + weight: null +reduction: sum \ No newline at end of file diff --git a/configs/model/metrics/accuracy.yaml b/configs/model/metrics/accuracy.yaml new file mode 100644 index 0000000000000000000000000000000000000000..80a9d3f8a6837571b70d0cfb70ba7c6f59ee1c3e --- /dev/null +++ b/configs/model/metrics/accuracy.yaml @@ -0,0 +1,4 @@ +accuracy: + _target_: torchmetrics.Accuracy + task: ${task.task} + num_classes: ${task.num_classes} \ No newline at end of file diff --git a/configs/model/metrics/auprc.yaml b/configs/model/metrics/auprc.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b9de03c65fdeec91e8b39f38dd773d9956e2e8ce --- /dev/null +++ b/configs/model/metrics/auprc.yaml @@ -0,0 +1,4 @@ +auprc: + _target_: torchmetrics.AveragePrecision + task: ${task.task} + num_classes: ${task.num_classes} \ No newline at end of file diff --git a/configs/model/metrics/auroc.yaml b/configs/model/metrics/auroc.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a4bcdbbd885ae5ba05120e5568ca7d3a323f213f --- /dev/null +++ b/configs/model/metrics/auroc.yaml @@ -0,0 +1,4 @@ +auroc: + _target_: torchmetrics.AUROC + task: ${task.task} + num_classes: ${task.num_classes} \ No newline at end of file diff --git a/configs/model/metrics/bedroc.yaml b/configs/model/metrics/bedroc.yaml new file mode 100644 index 0000000000000000000000000000000000000000..86c68585882faefde5ee711dfe647d406a06dd09 --- /dev/null +++ b/configs/model/metrics/bedroc.yaml @@ -0,0 +1,3 @@ +bedroc: + _target_: deepscreen.models.metrics.bedroc.BEDROC + alpha: 80.5 \ No newline at end of file diff --git a/configs/model/metrics/ci.yaml b/configs/model/metrics/ci.yaml new file mode 100644 index 0000000000000000000000000000000000000000..634ea1906b5d0307e2e0342beec04c28e66af308 --- /dev/null +++ b/configs/model/metrics/ci.yaml @@ -0,0 +1,2 @@ +# FIXME: 
implement concordance index +_target_: \ No newline at end of file diff --git a/configs/model/metrics/confusion_matrix.yaml b/configs/model/metrics/confusion_matrix.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d0be32b6b92a93e743d3926bf19fc3849c2e9260 --- /dev/null +++ b/configs/model/metrics/confusion_matrix.yaml @@ -0,0 +1 @@ +_target_: torchmetrics.ConfusionMatrix \ No newline at end of file diff --git a/configs/model/metrics/dta_metrics.yaml b/configs/model/metrics/dta_metrics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2b0f108cb9df18dfcceae6af9fd94789685a54b2 --- /dev/null +++ b/configs/model/metrics/dta_metrics.yaml @@ -0,0 +1,2 @@ +defaults: + - mean_squared_error diff --git a/configs/model/metrics/dti_metrics.yaml b/configs/model/metrics/dti_metrics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..93ca0c40a0d77ffb095ab91d52863667bf5746cc --- /dev/null +++ b/configs/model/metrics/dti_metrics.yaml @@ -0,0 +1,14 @@ +# train with many loggers at once + +defaults: + - auroc + - auprc + - specificity + - sensitivity + - precision + - recall + - f1_score +# Common virtual screening metrics: +# - ef +# - bedroc +# - hit_rate diff --git a/configs/model/metrics/ef.yaml b/configs/model/metrics/ef.yaml new file mode 100644 index 0000000000000000000000000000000000000000..82553b414da98c55508db9830d82b12af3db786d --- /dev/null +++ b/configs/model/metrics/ef.yaml @@ -0,0 +1,7 @@ +ef1: + _target_: deepscreen.models.metrics.ef.EF + alpha: 0.01 + +ef5: + _target_: deepscreen.models.metrics.ef.EF + alpha: 0.05 \ No newline at end of file diff --git a/configs/model/metrics/f1_score.yaml b/configs/model/metrics/f1_score.yaml new file mode 100644 index 0000000000000000000000000000000000000000..abfb6e4ca37a9dad399aeb3b1244d958d542e238 --- /dev/null +++ b/configs/model/metrics/f1_score.yaml @@ -0,0 +1,4 @@ +f1_score: + _target_: torchmetrics.F1Score + task: ${task.task} + num_classes: 
${task.num_classes} \ No newline at end of file diff --git a/configs/model/metrics/hit_rate.yaml b/configs/model/metrics/hit_rate.yaml new file mode 100644 index 0000000000000000000000000000000000000000..70976774eb365fdf8b8c1f97bd3ad19d6cb64cb8 --- /dev/null +++ b/configs/model/metrics/hit_rate.yaml @@ -0,0 +1,3 @@ +hit_rate: + _target_: deepscreen.models.metrics.hit_rate.HitRate + alpha: 0.05 diff --git a/configs/model/metrics/mse.yaml b/configs/model/metrics/mse.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0d9a18c60d43210b16878ff479b1cfa3168788cf --- /dev/null +++ b/configs/model/metrics/mse.yaml @@ -0,0 +1,2 @@ +mean_squared_error: + _target_: torchmetrics.MeanSquaredError \ No newline at end of file diff --git a/configs/model/metrics/pearson.yaml b/configs/model/metrics/pearson.yaml new file mode 100644 index 0000000000000000000000000000000000000000..918df0507b60ab7ce79635a3c9f5d4c210e0ff88 --- /dev/null +++ b/configs/model/metrics/pearson.yaml @@ -0,0 +1,2 @@ +Pearson: + _target_: torchmetrics.PearsonCorrCoef \ No newline at end of file diff --git a/configs/model/metrics/prc.yaml b/configs/model/metrics/prc.yaml new file mode 100644 index 0000000000000000000000000000000000000000..75e3ee320d5b9a32a9acdb55d564fbacab975088 --- /dev/null +++ b/configs/model/metrics/prc.yaml @@ -0,0 +1,4 @@ +prc: + _target_: torchmetrics.PrecisionRecallCurve + task: ${task.task} + num_classes: ${task.num_classes} \ No newline at end of file diff --git a/configs/model/metrics/precision.yaml b/configs/model/metrics/precision.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4b8212b1999c10b627de5859e034260a39022608 --- /dev/null +++ b/configs/model/metrics/precision.yaml @@ -0,0 +1,4 @@ +precision: + _target_: torchmetrics.Precision + task: ${task.task} + num_classes: ${task.num_classes} diff --git a/configs/model/metrics/recall.yaml b/configs/model/metrics/recall.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..eadad752ad1c4e1ac137580a289d4d32916ca975 --- /dev/null +++ b/configs/model/metrics/recall.yaml @@ -0,0 +1,4 @@ +recall: + _target_: torchmetrics.Recall + task: ${task.task} + num_classes: ${task.num_classes} \ No newline at end of file diff --git a/configs/model/metrics/roc.yaml b/configs/model/metrics/roc.yaml new file mode 100644 index 0000000000000000000000000000000000000000..91968a6f42e3d399f39e587785eaddabda523a23 --- /dev/null +++ b/configs/model/metrics/roc.yaml @@ -0,0 +1,4 @@ +roc: + _target_: torchmetrics.ROC + task: ${task.task} + num_classes: ${task.num_classes} \ No newline at end of file diff --git a/configs/model/metrics/sensitivity.yaml b/configs/model/metrics/sensitivity.yaml new file mode 100644 index 0000000000000000000000000000000000000000..49568b4512c2b75ebfce98d8ee03e2b1148966cc --- /dev/null +++ b/configs/model/metrics/sensitivity.yaml @@ -0,0 +1,4 @@ +sensitivity: + _target_: deepscreen.models.metrics.sensitivity.Sensitivity + task: ${task.task} + num_classes: ${task.num_classes} \ No newline at end of file diff --git a/configs/model/metrics/specificity.yaml b/configs/model/metrics/specificity.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5b161be947876081ce9b6b070ed7201874f5cbc6 --- /dev/null +++ b/configs/model/metrics/specificity.yaml @@ -0,0 +1,4 @@ +specificity: + _target_: torchmetrics.Specificity + task: ${task.task} + num_classes: ${task.num_classes} \ No newline at end of file diff --git a/configs/model/metrics/test_metrics.yaml b/configs/model/metrics/test_metrics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2ff2b80870d15e9863158e6dfbd17cf3dd9586c2 --- /dev/null +++ b/configs/model/metrics/test_metrics.yaml @@ -0,0 +1,11 @@ +# train with many loggers at once + +defaults: + - auroc + - auprc + - roc + - prc +# Common virtual screening metrics: +# - ef +# - bedroc +# - hit_rate diff --git a/configs/model/metrics/ww_dti_metrics.yaml 
b/configs/model/metrics/ww_dti_metrics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c2685e108dfa2476c8eb67ef7f41d41dd9910390 --- /dev/null +++ b/configs/model/metrics/ww_dti_metrics.yaml @@ -0,0 +1,193 @@ +defaults: + - auroc + - auprc + +Accuracy0_5: + _target_: torchmetrics.Accuracy + task: ${task.task} + num_classes: ${task.num_classes} + threshold: 0.5 + +Accuracy0_8: + _target_: torchmetrics.Accuracy + task: ${task.task} + num_classes: ${task.num_classes} + threshold: 0.8 + +Accuracy0_85: + _target_: torchmetrics.Accuracy + task: ${task.task} + num_classes: ${task.num_classes} + threshold: 0.85 + +Accuracy0_9: + _target_: torchmetrics.Accuracy + task: ${task.task} + num_classes: ${task.num_classes} + threshold: 0.9 + +Accuracy0_95: + _target_: torchmetrics.Accuracy + task: ${task.task} + num_classes: ${task.num_classes} + threshold: 0.95 + + + +Sensitivity0_5: + _target_: deepscreen.models.metrics.sensitivity.Sensitivity + task: ${task.task} + num_classes: ${task.num_classes} + threshold: 0.5 + +Sensitivity0_8: + _target_: deepscreen.models.metrics.sensitivity.Sensitivity + task: ${task.task} + num_classes: ${task.num_classes} + threshold: 0.8 + +Sensitivity0_85: + _target_: deepscreen.models.metrics.sensitivity.Sensitivity + task: ${task.task} + num_classes: ${task.num_classes} + threshold: 0.85 + +Sensitivity0_9: + _target_: deepscreen.models.metrics.sensitivity.Sensitivity + task: ${task.task} + num_classes: ${task.num_classes} + threshold: 0.9 + +Sensitivity0_95: + _target_: deepscreen.models.metrics.sensitivity.Sensitivity + task: ${task.task} + num_classes: ${task.num_classes} + threshold: 0.95 + + + +Specificity0_5: + _target_: torchmetrics.Specificity + task: ${task.task} + num_classes: ${task.num_classes} + threshold: 0.5 + +Specificity0_8: + _target_: torchmetrics.Specificity + task: ${task.task} + num_classes: ${task.num_classes} + threshold: 0.8 + +Specificity0_85: + _target_: torchmetrics.Specificity + task: 
${task.task} + num_classes: ${task.num_classes} + threshold: 0.85 + +Specificity0_9: + _target_: torchmetrics.Specificity + task: ${task.task} + num_classes: ${task.num_classes} + threshold: 0.9 + +Specificity0_95: + _target_: torchmetrics.Specificity + task: ${task.task} + num_classes: ${task.num_classes} + threshold: 0.95 + + + +Precision0_5: + _target_: torchmetrics.Precision + task: ${task.task} + num_classes: ${task.num_classes} + threshold: 0.5 + +Precision0_8: + _target_: torchmetrics.Precision + task: ${task.task} + num_classes: ${task.num_classes} + threshold: 0.8 + +Precision0_85: + _target_: torchmetrics.Precision + task: ${task.task} + num_classes: ${task.num_classes} + threshold: 0.85 + +Precision0_9: + _target_: torchmetrics.Precision + task: ${task.task} + num_classes: ${task.num_classes} + threshold: 0.9 + +Precision0_95: + _target_: torchmetrics.Precision + task: ${task.task} + num_classes: ${task.num_classes} + threshold: 0.95 + + + +Recall0_5: + _target_: torchmetrics.Recall + task: ${task.task} + num_classes: ${task.num_classes} + threshold: 0.5 + +Recall0_8: + _target_: torchmetrics.Recall + task: ${task.task} + num_classes: ${task.num_classes} + threshold: 0.8 + +Recall0_85: + _target_: torchmetrics.Recall + task: ${task.task} + num_classes: ${task.num_classes} + threshold: 0.85 + +Recall0_9: + _target_: torchmetrics.Recall + task: ${task.task} + num_classes: ${task.num_classes} + threshold: 0.9 + +Recall0_95: + _target_: torchmetrics.Recall + task: ${task.task} + num_classes: ${task.num_classes} + threshold: 0.95 + + + +F1Score0_5: + _target_: torchmetrics.F1Score + task: ${task.task} + num_classes: ${task.num_classes} + threshold: 0.5 + +F1Score0_8: + _target_: torchmetrics.F1Score + task: ${task.task} + num_classes: ${task.num_classes} + threshold: 0.8 + +F1Score0_85: + _target_: torchmetrics.F1Score + task: ${task.task} + num_classes: ${task.num_classes} + threshold: 0.85 + +F1Score0_9: + _target_: torchmetrics.F1Score + task: ${task.task} 
+ num_classes: ${task.num_classes} + threshold: 0.9 + +F1Score0_95: + _target_: torchmetrics.F1Score + task: ${task.task} + num_classes: ${task.num_classes} + threshold: 0.95 \ No newline at end of file diff --git a/configs/model/optimizer/adam.yaml b/configs/model/optimizer/adam.yaml new file mode 100644 index 0000000000000000000000000000000000000000..67ab7f2cfeeb16daef186434fd4655afab56bf37 --- /dev/null +++ b/configs/model/optimizer/adam.yaml @@ -0,0 +1,5 @@ +_target_: torch.optim.Adam +_partial_: true + +lr: 0.0001 +weight_decay: 0.0 \ No newline at end of file diff --git a/configs/model/optimizer/none.yaml b/configs/model/optimizer/none.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/configs/model/predictor/custom.yaml b/configs/model/predictor/custom.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d7a867b87888c03a7209b4b78bbdf2dadfff07a2 --- /dev/null +++ b/configs/model/predictor/custom.yaml @@ -0,0 +1,6 @@ +_target_: deepscreen.models.predictors.custom.CustomPredictor + +defaults: + - drug_encoder: cnn + - protein_encoder: cnn + - decoder: concat_mlp diff --git a/configs/model/predictor/decoder/concat_mlp.yaml b/configs/model/predictor/decoder/concat_mlp.yaml new file mode 100644 index 0000000000000000000000000000000000000000..17b92c9493a45d51b3463dc069a73295fc22eb4f --- /dev/null +++ b/configs/model/predictor/decoder/concat_mlp.yaml @@ -0,0 +1,6 @@ +_target_: deepscreen.models.components.mlp.ConcatMLP + +input_channels: ${eval:${model.drug_encoder.out_channels}+${model.protein_encoder.out_channels}} +out_channels: 512 +hidden_channels: [1024,1024] +dropout: 0.1 \ No newline at end of file diff --git a/configs/model/predictor/decoder/mlp_deepdta.yaml b/configs/model/predictor/decoder/mlp_deepdta.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b6ee28076224a373a6879ece9296491dccfd280c --- /dev/null +++ 
b/configs/model/predictor/decoder/mlp_deepdta.yaml @@ -0,0 +1,6 @@ +_target_: deepscreen.models.components.mlp.MLP2 + +input_channels: ${eval:${model.drug_encoder.out_channels}+${model.protein_encoder.out_channels}} +out_channels: 1 +hidden_channels: [1024,1024,512] +dropout: 0.1 \ No newline at end of file diff --git a/configs/model/predictor/decoder/mlp_lazy.yaml b/configs/model/predictor/decoder/mlp_lazy.yaml new file mode 100644 index 0000000000000000000000000000000000000000..832863817e37fed6e5ef54eba4de99341310c4dc --- /dev/null +++ b/configs/model/predictor/decoder/mlp_lazy.yaml @@ -0,0 +1,5 @@ +_target_: deepscreen.models.components.mlp.LazyMLP + +out_channels: 1 +hidden_channels: [1024,1024,512] +dropout: 0.1 \ No newline at end of file diff --git a/configs/model/predictor/deep_dta.yaml b/configs/model/predictor/deep_dta.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7f6cfc42a5e850554e0825d803f79575abdcdb1b --- /dev/null +++ b/configs/model/predictor/deep_dta.yaml @@ -0,0 +1,15 @@ +_target_: deepscreen.models.predictors.deep_dta.DeepDTA + +defaults: + - drug_encoder@drug_cnn: cnn + - protein_encoder@protein_cnn: cnn +# - /model/decoder@fc: concat_mlp + +num_features_drug: 63 +num_features_protein: 26 +embed_dim: 128 + +drug_cnn: + in_channels: ${data.drug_featurizer.max_sequence_length} +protein_cnn: + in_channels: ${data.protein_featurizer.max_sequence_length} \ No newline at end of file diff --git a/configs/model/predictor/drug_encoder/cnn.yaml b/configs/model/predictor/drug_encoder/cnn.yaml new file mode 100644 index 0000000000000000000000000000000000000000..453ef99dec6c2a5db821181e72eef24c9faab966 --- /dev/null +++ b/configs/model/predictor/drug_encoder/cnn.yaml @@ -0,0 +1,9 @@ +_target_: deepscreen.models.components.cnn.CNN + +max_sequence_length: ${data.drug_featurizer.max_sequence_length} +filters: [32, 64, 96] +kernels: [4, 6, 8] +in_channels: ${data.drug_featurizer.in_channels} +out_channels: 256 + +# TODO refactor the 
in_channels argument pipeline to be more reasonable \ No newline at end of file diff --git a/configs/model/predictor/drug_encoder/cnn_deepdta.yaml b/configs/model/predictor/drug_encoder/cnn_deepdta.yaml new file mode 100644 index 0000000000000000000000000000000000000000..97bf3a4870e224b2bc7a5eab946d38b57c279d26 --- /dev/null +++ b/configs/model/predictor/drug_encoder/cnn_deepdta.yaml @@ -0,0 +1,7 @@ +_target_: deepscreen.models.components.cnn_deepdta.CNN_DeepDTA + +max_sequence_length: ${data.drug_featurizer.max_sequence_length} +filters: [32, 64, 96] +kernels: [4, 6, 8] +in_channels: ${data.drug_featurizer.in_channels} +out_channels: 128 \ No newline at end of file diff --git a/configs/model/predictor/drug_encoder/gat.yaml b/configs/model/predictor/drug_encoder/gat.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3b8dddd5ab5fac2b3cefff8191c01d3ab393c8a5 --- /dev/null +++ b/configs/model/predictor/drug_encoder/gat.yaml @@ -0,0 +1,5 @@ +_target_: deepscreen.models.components.gat.GAT + +num_features: 78 +out_channels: 128 +dropout: 0.2 \ No newline at end of file diff --git a/configs/model/predictor/drug_encoder/gcn.yaml b/configs/model/predictor/drug_encoder/gcn.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5e2da337b324610297bd76b6783340d03bd681a8 --- /dev/null +++ b/configs/model/predictor/drug_encoder/gcn.yaml @@ -0,0 +1,5 @@ +_target_: deepscreen.models.components.gcn.GCN + +num_features: 78 +out_channels: 128 +dropout: 0.2 \ No newline at end of file diff --git a/configs/model/predictor/drug_encoder/gin.yaml b/configs/model/predictor/drug_encoder/gin.yaml new file mode 100644 index 0000000000000000000000000000000000000000..caf5820c158ff7956b2440d25f7b5f901936f683 --- /dev/null +++ b/configs/model/predictor/drug_encoder/gin.yaml @@ -0,0 +1,5 @@ +_target_: deepscreen.models.components.gin.GIN + +num_features: 78 +out_channels: 128 +dropout: 0.2 \ No newline at end of file diff --git 
a/configs/model/predictor/drug_encoder/lstm.yaml b/configs/model/predictor/drug_encoder/lstm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/configs/model/predictor/drug_encoder/transformer.yaml b/configs/model/predictor/drug_encoder/transformer.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5eee1571675bcb2d2d539dbaa69ea0268c7b3908 --- /dev/null +++ b/configs/model/predictor/drug_encoder/transformer.yaml @@ -0,0 +1,11 @@ +_target_: deepscreen.models.components.transformer + +input_dim: 1024 +emb_size: 128 +max_position_size: 50 +dropout: 0.1 +n_layer: 8 +intermediate_size: 512 +num_attention_heads: 8 +attention_probs_dropout: 0.1 +hidden_dropout: 0.1 \ No newline at end of file diff --git a/configs/model/predictor/drug_vqa.yaml b/configs/model/predictor/drug_vqa.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8ee31b6b3e5a62bec7998d3bdb4c4619ecd27e7f --- /dev/null +++ b/configs/model/predictor/drug_vqa.yaml @@ -0,0 +1,15 @@ +_target_: deepscreen.models.predictors.drug_vqa.DrugVQA + +conv_dim: 1 +lstm_hid_dim: 64 +d_a: 32 +r: 10 +n_chars_smi: 247 +n_chars_seq: 21 +dropout: 0.2 +in_channels: 8 +cnn_channels: 32 +cnn_layers: 4 +emb_dim: 30 +dense_hid: 64 + diff --git a/configs/model/predictor/graph_dta.yaml b/configs/model/predictor/graph_dta.yaml new file mode 100644 index 0000000000000000000000000000000000000000..35bfdf8adc5ab69b780e81c828a195e48e8ec1a0 --- /dev/null +++ b/configs/model/predictor/graph_dta.yaml @@ -0,0 +1,16 @@ +defaults: + - drug_encoder@gnn: gat + - _self_ + +_target_: deepscreen.models.predictors.graph_dta.GraphDTA + +gnn: + num_features: 34 + out_channels: 128 + dropout: 0.2 + +num_features_protein: 26 +n_filters: 32 +embed_dim: 128 +output_dim: 128 +dropout: 0.2 diff --git a/configs/model/predictor/hyper_attention_dti.yaml b/configs/model/predictor/hyper_attention_dti.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..87dde659688073d247953fce5ceef65a38786789 --- /dev/null +++ b/configs/model/predictor/hyper_attention_dti.yaml @@ -0,0 +1,8 @@ +_target_: deepscreen.models.predictors.hyper_attention_dti.HyperAttentionDTI + +protein_kernel: [4,8,12] +drug_kernel: [4,6,8] +conv: 40 +char_dim: 64 +protein_max_len: ${data.protein_featurizer.max_sequence_length} +drug_max_len: ${data.drug_featurizer.max_sequence_length} \ No newline at end of file diff --git a/configs/model/predictor/none.yaml b/configs/model/predictor/none.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/configs/model/predictor/protein_encoder/cnn.yaml b/configs/model/predictor/protein_encoder/cnn.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9363e5e6130a5292d173b90eba8430c12edef3d1 --- /dev/null +++ b/configs/model/predictor/protein_encoder/cnn.yaml @@ -0,0 +1,7 @@ +_target_: deepscreen.models.components.cnn.CNN + +max_sequence_length: ${data.protein_featurizer.max_sequence_length} +filters: [32, 64, 96] +kernels: [4, 8, 12] +in_channels: ${data.protein_featurizer.in_channels} +out_channels: 256 \ No newline at end of file diff --git a/configs/model/predictor/protein_encoder/cnn_deepdta.yaml b/configs/model/predictor/protein_encoder/cnn_deepdta.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8ac5f6d9064695cdfe2739946c1e55d55aa588d5 --- /dev/null +++ b/configs/model/predictor/protein_encoder/cnn_deepdta.yaml @@ -0,0 +1,7 @@ +_target_: deepscreen.models.components.cnn_deepdta.CNN_DeepDTA + +max_sequence_length: ${data.protein_featurizer.max_sequence_length} +filters: [32, 64, 96] +kernels: [4, 8, 12] +in_channels: ${data.protein_featurizer.in_channels} +out_channels: 128 \ No newline at end of file diff --git a/configs/model/predictor/protein_encoder/lstm.yaml b/configs/model/predictor/protein_encoder/lstm.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/configs/model/predictor/protein_encoder/tape_bert.yaml b/configs/model/predictor/protein_encoder/tape_bert.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6a64b46c70850a5d796a19615ba553e69cb26a39 --- /dev/null +++ b/configs/model/predictor/protein_encoder/tape_bert.yaml @@ -0,0 +1,3 @@ +_target_: tape.ProteinBertModel.from_pretrained + +pretrained_model_name_or_path: bert-base \ No newline at end of file diff --git a/configs/model/predictor/protein_encoder/transformer.yaml b/configs/model/predictor/protein_encoder/transformer.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4fb7c9761bc098da84773f56bc949ce7c7d34c5f --- /dev/null +++ b/configs/model/predictor/protein_encoder/transformer.yaml @@ -0,0 +1,12 @@ +_target_: deepscreen.models.components.transformer + +input_dim: 8420 +emb_size: 64 +max_position_size: 545 50 +dropout: 0.1 +n_layer: 2 +intermediate_size: 256 +num_attention_heads: 4 +attention_probs_dropout: 0.1 +hidden_dropout: 0.1 + diff --git a/configs/model/predictor/transformer_cpi.yaml b/configs/model/predictor/transformer_cpi.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c7d137ba108186e77e67878c4ad6ba40dbad52d3 --- /dev/null +++ b/configs/model/predictor/transformer_cpi.yaml @@ -0,0 +1,10 @@ +_target_: deepscreen.models.predictors.transformer_cpi.TransformerCPI + +protein_dim: 100 +atom_dim: 34 +hid_dim: 64 +n_layers: 3 +n_heads: 8 +pf_dim: 256 +dropout: 0.1 +kernel_size: 7 \ No newline at end of file diff --git a/configs/model/predictor/transformer_cpi_2.yaml b/configs/model/predictor/transformer_cpi_2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2f9439a99c02720650635f2024acb36ef45b18fa --- /dev/null +++ b/configs/model/predictor/transformer_cpi_2.yaml @@ -0,0 +1,14 @@ +_target_: deepscreen.models.predictors.transformer_cpi_2.TransformerCPI2 + 
+encoder: + _target_: deepscreen.models.predictors.transformer_cpi_2.Encoder + # /model/protein_encoder@pretrain: tape_bert + n_layers: 3 + pretrain: + _target_: tape.ProteinBertModel.from_pretrained + pretrained_model_name_or_path: bert-base + +decoder: + _target_: deepscreen.models.predictors.transformer_cpi_2.Decoder + n_layers: 3 + dropout: 0.1 diff --git a/configs/model/scheduler/default.yaml b/configs/model/scheduler/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a319ea4099d233cec67de9477bbce25744368a91 --- /dev/null +++ b/configs/model/scheduler/default.yaml @@ -0,0 +1,11 @@ +scheduler: + _target_: torch.optim.lr_scheduler.ReduceLROnPlateau + _partial_: true + + mode: min + factor: 0.1 + patience: 10 + +monitor: ${oc.select:callbacks.model_checkpoint.monitor,"val/loss"} +interval: "epoch" +frequency: 1 \ No newline at end of file diff --git a/configs/model/scheduler/none.yaml b/configs/model/scheduler/none.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/configs/model/scheduler/reduce_lr_on_plateau.yaml b/configs/model/scheduler/reduce_lr_on_plateau.yaml new file mode 100644 index 0000000000000000000000000000000000000000..221efd82d131473c6a51da966d153bb2ec85b69c --- /dev/null +++ b/configs/model/scheduler/reduce_lr_on_plateau.yaml @@ -0,0 +1,6 @@ +_target_: torch.optim.lr_scheduler.ReduceLROnPlateau +_partial_: true + +mode: min +factor: 0.1 +patience: 10 diff --git a/configs/paths/default.yaml b/configs/paths/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..862261ca7797da591715936cfd88a047c1722da3 --- /dev/null +++ b/configs/paths/default.yaml @@ -0,0 +1,19 @@ +# path to root directory +root_dir: . 
+ +# path to data directory +data_dir: ${paths.root_dir}/data + +# path to log directory +log_dir: ${paths.root_dir}/logs + +# path to resource directory +resource_dir: ${paths.root_dir}/resources + +# path to output directory, created dynamically by hydra +# path generation pattern is specified in `configs/hydra/default.yaml` +# use it to store all files generated during the run, like ckpts and metrics +output_dir: ${hydra:runtime.output_dir} + +# path to working directory +work_dir: ${hydra:runtime.cwd} diff --git a/configs/preset/bacpi.yaml.bak b/configs/preset/bacpi.yaml.bak new file mode 100644 index 0000000000000000000000000000000000000000..9dbea130fd0f6b6fe07f009c340f33f6f8f6c558 --- /dev/null +++ b/configs/preset/bacpi.yaml.bak @@ -0,0 +1,37 @@ +# @package _global_ +model: + predictor: + _target_: deepscreen.models.predictors.bacpi.BACPI + + n_atom: 20480 + n_amino: 8448 + comp_dim: 80 + prot_dim: 80 + latent_dim: 80 + gat_dim: 50 + num_head: 3 + dropout: 0.1 + alpha: 0.1 + window: 5 + layer_cnn: 3 + optimizer: + lr: 5e-4 + +data: + batch_size: 16 + + collator: + automatic_padding: True + + drug_featurizer: + _target_: deepscreen.models.predictors.bacpi.drug_featurizer + _partial_: true + radius: 2 + + protein_featurizer: + _target_: deepscreen.models.predictors.bacpi.split_sequence + _partial_: true + ngram: 3 +# collator: +# _target_: deepscreen.models.predictors.transformer_cpi_2.pack +# _partial_: true diff --git a/configs/preset/coa_dti_pro.yaml.bak b/configs/preset/coa_dti_pro.yaml.bak new file mode 100644 index 0000000000000000000000000000000000000000..9ad501a6176a630b8edb97653607ef223946367d --- /dev/null +++ b/configs/preset/coa_dti_pro.yaml.bak @@ -0,0 +1,28 @@ +# @package _global_ +defaults: + - override /data/protein_featurizer: none + +model: + predictor: + _target_: deepscreen.models.predictors.coa_dti_pro.CoaDTIPro + + n_fingerprint: 20480 + n_word: 26 + dim: 512 + layer_output: 3 + layer_coa: 1 + nhead: 8 + dropout: 0.1 + co_attention: 
'inter' + gcn_pooling: False + + esm_model_and_alphabet: + _target_: esm.pretrained.load_model_and_alphabet + model_name: resources/models/esm/esm1_t6_43M_UR50S.pt + +data: + drug_featurizer: + _target_: deepscreen.models.predictors.coa_dti_pro.drug_featurizer + _partial_: true + radius: 2 + batch_size: 1 diff --git a/configs/preset/deep_conv_dti.yaml b/configs/preset/deep_conv_dti.yaml new file mode 100644 index 0000000000000000000000000000000000000000..025b151cf1539205f245feb41f9fd60f133c4497 --- /dev/null +++ b/configs/preset/deep_conv_dti.yaml @@ -0,0 +1,23 @@ +# @package _global_ +defaults: + - override /data/drug_featurizer: ecfp + - override /data/protein_featurizer: label + +model: + predictor: + _target_: deepscreen.models.predictors.deep_conv_dti.DeepConvDTI + + activation: + _target_: torch.nn.ELU + + dropout: 0.0 + drug_layers: [512, 128] + protein_windows: [10, 15, 20, 25, 30] + n_filters: 128 + decay: 0.0001 + convolution: true + protein_layers: [128,] + fc_layers: [128,] + +data: + batch_size: 512 diff --git a/configs/preset/deep_dta.yaml b/configs/preset/deep_dta.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f72d5c942226d88e798cb6a2fd6fef080df69053 --- /dev/null +++ b/configs/preset/deep_dta.yaml @@ -0,0 +1,8 @@ +# @package _global_ +defaults: + - override /data/drug_featurizer: label + - override /data/protein_featurizer: label + - override /model/predictor: deep_dta + +data: + batch_size: 512 diff --git a/configs/preset/drug_ban.yaml b/configs/preset/drug_ban.yaml new file mode 100644 index 0000000000000000000000000000000000000000..537c9d79276383542e956ff4ef27178c4c358fc4 --- /dev/null +++ b/configs/preset/drug_ban.yaml @@ -0,0 +1,28 @@ +# @package _global_ +defaults: + - override /data/protein_featurizer: label + +model: + predictor: + _target_: deepscreen.models.predictors.drug_ban.DrugBAN + + drug_in_feats: 75 + drug_embedding: 128 + drug_hidden_feats: [128, 128, 128] + drug_padding: True + protein_emb_dim: 128 + 
num_filters: [128, 128, 128] + kernel_size: [3, 6, 9] + protein_padding: True + mlp_in_dim: 256 + mlp_hidden_dim: 512 + mlp_out_dim: 128 + ban_heads: 2 + +data: + drug_featurizer: + _target_: deepscreen.models.predictors.drug_ban.drug_featurizer + _partial_: true + max_drug_nodes: 330 + + batch_size: 512 diff --git a/configs/preset/drug_vqa.yaml b/configs/preset/drug_vqa.yaml new file mode 100644 index 0000000000000000000000000000000000000000..26361c45c25c936a1d9c9b28942343ba7a96a127 --- /dev/null +++ b/configs/preset/drug_vqa.yaml @@ -0,0 +1,21 @@ +# @package _global_ +defaults: + - override /data/drug_featurizer: tokenizer + - override /data/protein_featurizer: label + - override /model/predictor: drug_vqa + +model: + loss: + _target_: deepscreen.models.loss.multitask_loss.MultitaskWeightedLoss + loss_fns: + - ${task.loss} + - _target_: deepscreen.models.predictors.drug_vqa.AttentionL2Regularization + weights: [1, 0.001] + +data: + batch_size: 512 + drug_featurizer: + tokenizer: + _target_: deepscreen.data.featurizers.token.SmilesTokenizer + vocab_file: resources/vocabs/drug_vqa/combinedVoc-wholeFour.voc + regex_pattern: '(\[[^\[\]]{1,10}\])' \ No newline at end of file diff --git a/configs/preset/graph_dta.yaml b/configs/preset/graph_dta.yaml new file mode 100644 index 0000000000000000000000000000000000000000..159cb4e619f252663dccda18294600602674e01d --- /dev/null +++ b/configs/preset/graph_dta.yaml @@ -0,0 +1,8 @@ +# @package _global_ +defaults: + - override /data/drug_featurizer: graph + - override /data/protein_featurizer: label + - override /model/predictor: graph_dta + +data: + batch_size: 512 \ No newline at end of file diff --git a/configs/preset/hyper_attention_dti.yaml b/configs/preset/hyper_attention_dti.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ae63593517d0ce34310ba3f0dbf1201f66e7e3a5 --- /dev/null +++ b/configs/preset/hyper_attention_dti.yaml @@ -0,0 +1,8 @@ +# @package _global_ +defaults: + - override 
/data/drug_featurizer: label + - override /data/protein_featurizer: label + - override /model/predictor: hyper_attention_dti + +data: + batch_size: 32 \ No newline at end of file diff --git a/configs/preset/m_graph_dta.yaml b/configs/preset/m_graph_dta.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8609b499b94168c7c5da2b5779b310be5c769020 --- /dev/null +++ b/configs/preset/m_graph_dta.yaml @@ -0,0 +1,22 @@ +# @package _global_ +defaults: + - override /data/drug_featurizer: graph + - override /data/protein_featurizer: label + +model: + predictor: + _target_: deepscreen.models.predictors.m_graph_dta.MGraphDTA + block_num: 3 + vocab_protein_size: ${eval:'len(${data.protein_featurizer.charset})+1'} + embedding_size: 128 + filter_num: 32 + +data: + drug_featurizer: + atom_features: + _target_: deepscreen.models.predictors.m_graph_dta.atom_features + _partial_: true + batch_size: 512 + +trainer: + precision: 'bf16' \ No newline at end of file diff --git a/configs/preset/mol_trans.yaml b/configs/preset/mol_trans.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ed7fcf5ce9d43b0e047b9720cb544faf707fb64f --- /dev/null +++ b/configs/preset/mol_trans.yaml @@ -0,0 +1,39 @@ +# @package _global_ +defaults: + - override /data/drug_featurizer: fcs + - override /data/protein_featurizer: fcs + +data: + batch_size: 16 + drug_featurizer: + max_sequence_length: 205 + protein_featurizer: + max_sequence_length: 545 + +model: + predictor: + _target_: deepscreen.models.predictors.mol_trans.MolTrans + + input_dim_drug: 23532 + input_dim_target: 16693 + max_drug_seq: ${data.drug_featurizer.max_sequence_length} + max_protein_seq: ${data.protein_featurizer.max_sequence_length} + emb_size: 384 + dropout_rate: 0.1 + + # DenseNet + scale_down_ratio: 0.25 + growth_rate: 20 + transition_rate: 0.5 + num_dense_blocks: 4 + kernal_dense_size: 3 + + # Encoder + intermediate_size: 1536 + num_attention_heads: 12 + attention_probs_dropout_prob: 0.1 + 
hidden_dropout_prob: 0.1 + #flatten_dim: 293412 + + optimizer: + lr: 1e-6 \ No newline at end of file diff --git a/configs/preset/transformer_cpi.yaml b/configs/preset/transformer_cpi.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c218ea5a4c51454e9fa4eeb45f97150ac2b99aae --- /dev/null +++ b/configs/preset/transformer_cpi.yaml @@ -0,0 +1,21 @@ +# @package _global_ +defaults: + - override /data/drug_featurizer: mol_features + - override /data/protein_featurizer: word2vec + +model: + predictor: + _target_: deepscreen.models.predictors.transformer_cpi.TransformerCPI + protein_dim: 100 + hidden_dim: 64 + n_layers: 3 + kernel_size: 5 + dropout: 0.1 + n_heads: 8 + pf_dim: 256 + atom_dim: 34 + +data: + batch_size: 16 + collator: + automatic_padding: True diff --git a/configs/preset/transformer_cpi_2.yaml b/configs/preset/transformer_cpi_2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e1c224f88a6de1266ee4d836c15ff5e88b06f119 --- /dev/null +++ b/configs/preset/transformer_cpi_2.yaml @@ -0,0 +1,35 @@ +# @package _global_ +defaults: + - override /data/drug_featurizer: mol_features + - override /data/protein_featurizer: tokenizer + +model: + predictor: + _target_: deepscreen.models.predictors.transformer_cpi_2.TransformerCPI2 + + encoder: + _target_: deepscreen.models.predictors.transformer_cpi_2.Encoder + # /model/protein_encoder@pretrain: tape_bert + n_layers: 3 + pretrain: + _target_: tape.ProteinBertModel.from_pretrained + pretrained_model_name_or_path: resources/models/tape/bert-base/ # bert-base + + decoder: + _target_: deepscreen.models.predictors.transformer_cpi_2.Decoder + n_layers: 3 + dropout: 0.1 + +data: + batch_size: 16 + collator: + automatic_padding: True + + protein_featurizer: + tokenizer: + _target_: tape.TAPETokenizer.from_pretrained + vocab: iupac + +# collator: +# _target_: deepscreen.models.predictors.transformer_cpi_2.pack +# _partial_: true diff --git a/configs/sweep/ddp_multirun.yaml 
b/configs/sweep/ddp_multirun.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fbf6641a20019ae41cf8d231d01433b2562a9c77 --- /dev/null +++ b/configs/sweep/ddp_multirun.yaml @@ -0,0 +1,6 @@ +# @package _global_ +hydra: + sweeper: + params: + preset: graph_dta,deep_dta # drug_vqa,mol_trans,hyper_attention_dti,transformer_cpi_2 + experiment: kiba,davis,bindingdb \ No newline at end of file diff --git a/configs/sweep/dti_benchmark.yaml b/configs/sweep/dti_benchmark.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b404a86d52131cbde1d7b877be92a80d8000d5b8 --- /dev/null +++ b/configs/sweep/dti_benchmark.yaml @@ -0,0 +1,8 @@ +# @package _global_ +tags: ['sweep', 'benchmark'] +hydra: + sweeper: + params: + preset: transformer_cpi_2 #graph_dta,deep_dta #,mol_trans,hyper_attention_dti,m_graph_dta + experiment: other_protein_targets + diff --git a/configs/sweep/example_multirun_test.yaml b/configs/sweep/example_multirun_test.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a609000341c88314547dfe300ec59e059c0d9825 --- /dev/null +++ b/configs/sweep/example_multirun_test.yaml @@ -0,0 +1,7 @@ +# @package _global_ +hydra: + sweeper: + params: +# ckpt_path: "'C:/Users/libok/Documents/GitHub/DeepScreen/logs/train/runs/2023-11-07_23-21-55-442205_[debug]/checkpoints/epoch_000.ckpt','C:/Users/libok/Documents/GitHub/DeepScreen/logs/train/runs/2023-11-07_19-46-08-740035_[debug]/checkpoints/epoch_000.ckpt'" + data.data_file: dti_benchmark/random_split_update/davis_reserved_test.csv,dti_benchmark/random_split_update/kiba_reserved_test.csv + ckpt_path: "'C:/Users/libok/Documents/GitHub/DeepScreen/logs/test/multiruns/2023-11-10_10-31-15-339335_[multirun,test,dev]/metrics_summary.csv'" \ No newline at end of file diff --git a/configs/sweep/example_multirun_train.yaml b/configs/sweep/example_multirun_train.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b2d7edad480a88165ca51475c7fd076110d2fef0 --- 
/dev/null +++ b/configs/sweep/example_multirun_train.yaml @@ -0,0 +1,6 @@ +# @package _global_ +hydra: + sweeper: + params: + preset: graph_dta,deep_dta + experiment: chembl_random diff --git a/configs/sweep/example_submitit.yaml b/configs/sweep/example_submitit.yaml new file mode 100644 index 0000000000000000000000000000000000000000..510522a9cb855b9afd2ac3ebd6a0f50baea93459 --- /dev/null +++ b/configs/sweep/example_submitit.yaml @@ -0,0 +1,17 @@ +# @package _global_ +defaults: + - basic + - override /hydra/launcher: submitit_local + +hydra: + sweeper: + params: + preset: graph_dta,deep_dta + experiment: chembl_random + launcher: + timeout_min: 60 + cpus_per_task: 1 + gpus_per_node: 1 + tasks_per_node: 4 + mem_gb: 32 + nodes: 1 \ No newline at end of file diff --git a/configs/sweep/experiment_summary.csv b/configs/sweep/experiment_summary.csv new file mode 100644 index 0000000000000000000000000000000000000000..73e924007e4c3936dadc918f7cece50574336c3e --- /dev/null +++ b/configs/sweep/experiment_summary.csv @@ -0,0 +1,21 @@ +test/loss,test/auprc,test/auroc,test/f1_score,test/precision,test/recall,test/sensitivity,test/specificity,ckpt_path,job_status,preset,experiment,sweep,tags,model.optimizer.lr,local,data.batch_size +0.1525115370750427,0.735183835029602,0.9363504648208618,0.661261260509491,0.7816826701164246,0.5729898810386658,0.5729898810386658,0.984080135822296,"/gpfs/work/pha/daiyunhuang/WavyWaffle/logs/train/multiruns/2023-10-16_14-53-13-645337_[benchmark,ddta,gdta]/experiment=kiba;preset=deep_dta;sweep=slurm_test;tags=[benchmark,ddta,gdta]/checkpoints/epoch_041.ckpt",COMPLETED,deep_dta,kiba,slurm_test,"[benchmark,ddta,gdta]",,, 
+0.193635880947113,0.4078962802886963,0.8581838607788086,0.3916666805744171,0.5371428728103638,0.3081967234611511,0.3081967234611511,0.9788622260093688,"/gpfs/work/pha/daiyunhuang/WavyWaffle/logs/train/multiruns/2023-10-16_14-53-13-645337_[benchmark,ddta,gdta]/experiment=davis;preset=deep_dta;sweep=slurm_test;tags=[benchmark,ddta,gdta]/checkpoints/epoch_008.ckpt",COMPLETED,deep_dta,davis,slurm_test,"[benchmark,ddta,gdta]",,, +0.5104637742042542,0.7619075775146484,0.8163812756538391,0.6808292269706726,0.6954700350761414,0.6667922139167786,0.6667922139167786,0.789406955242157,"/gpfs/work/pha/daiyunhuang/WavyWaffle/logs/train/multiruns/2023-10-16_14-53-13-645337_[benchmark,ddta,gdta]/experiment=bindingdb;preset=deep_dta;sweep=slurm_test;tags=[benchmark,ddta,gdta]/checkpoints/epoch_003.ckpt",COMPLETED,deep_dta,bindingdb,slurm_test,"[benchmark,ddta,gdta]",,, +0.2221491634845733,0.3560383319854736,0.78690105676651,0.2659846544265747,0.604651153087616,0.1704917997121811,0.1704917997121811,0.9911273717880248,"/gpfs/work/pha/daiyunhuang/WavyWaffle/logs/train/multiruns/2023-10-16_14-53-13-645337_[benchmark,ddta,gdta]/experiment=davis;preset=graph_dta;sweep=slurm_test;tags=[benchmark,ddta,gdta]/checkpoints/epoch_054.ckpt",COMPLETED,graph_dta,davis,slurm_test,"[benchmark,ddta,gdta]",,, +0.2281630784273147,0.4483628869056701,0.8383051156997681,0.2730720639228821,0.7176079750061035,0.1686182618141174,0.1686182618141174,0.9933990836143494,"/gpfs/work/pha/daiyunhuang/WavyWaffle/logs/train/multiruns/2023-10-16_14-53-13-645337_[benchmark,ddta,gdta]/experiment=kiba;preset=graph_dta;sweep=slurm_test;tags=[benchmark,ddta,gdta]/checkpoints/epoch_065.ckpt",COMPLETED,graph_dta,kiba,slurm_test,"[benchmark,ddta,gdta]",,, 
+0.4569543302059173,0.828099250793457,0.8623664379119873,0.7299983501434326,0.7756586670875549,0.6894148588180542,0.6894148588180542,0.8561919927597046,"/gpfs/work/pha/daiyunhuang/WavyWaffle/logs/train/multiruns/2023-10-16_14-53-13-645337_[benchmark,ddta,gdta]/experiment=bindingdb;preset=graph_dta;sweep=slurm_test;tags=[benchmark,ddta,gdta]/checkpoints/epoch_034.ckpt",COMPLETED,graph_dta,bindingdb,slurm_test,"[benchmark,ddta,gdta]",,, +0.1881016194820404,0.5213174819946289,0.8876519203186035,0.5207100510597229,0.6534653306007385,0.4327868819236755,0.4327868819236755,0.9817327857017516,"/gpfs/work/pha/daiyunhuang/WavyWaffle/logs/train/multiruns/2023-10-16_18-08-40-028557_[sweep,benchmark]/experiment=davis;preset=hyper_attention_dti;sweep=slurm_test/checkpoints/epoch_017.ckpt",COMPLETED,hyper_attention_dti,davis,slurm_test,,,, +0.144176036119461,0.7704265117645264,0.9501243829727172,0.6897732019424438,0.7632575631141663,0.6291959285736084,0.6291959285736084,0.980585515499115,"/gpfs/work/pha/daiyunhuang/WavyWaffle/logs/train/multiruns/2023-10-16_18-08-40-028557_[sweep,benchmark]/experiment=kiba;preset=hyper_attention_dti;sweep=slurm_test/checkpoints/epoch_014.ckpt",COMPLETED,hyper_attention_dti,kiba,slurm_test,,,, +0.3445079922676086,0.915877878665924,0.934882640838623,0.8371888399124146,0.8385065793991089,0.8358752131462097,0.8358752131462097,0.8838841319084167,"/gpfs/work/pha/daiyunhuang/WavyWaffle/logs/train/multiruns/2023-10-16_18-08-40-028557_[sweep,benchmark]/experiment=bindingdb;preset=hyper_attention_dti;sweep=slurm_test/checkpoints/epoch_017.ckpt",COMPLETED,hyper_attention_dti,bindingdb,slurm_test,,,, 
+0.3915603458881378,0.3016493022441864,0.7910888195037842,0.085626907646656,0.6363636255264282,0.0459016375243663,0.0459016375243663,0.9979122877120972,"/gpfs/work/pha/daiyunhuang/WavyWaffle/logs/train/multiruns/2023-10-20_18-50-02-680654_[mol_trans,benchmark]/experiment=davis;preset=mol_trans;sweep=slurm_test/checkpoints/epoch_000.ckpt",COMPLETED,mol_trans,davis,slurm_test,,,, +0.1662539541721344,0.7397991418838501,0.9369943141937256,0.6784178614616394,0.7550238966941833,0.6159250736236572,0.6159250736236572,0.9801195859909058,"/gpfs/work/pha/daiyunhuang/WavyWaffle/logs/train/multiruns/2023-10-20_18-50-02-680654_[mol_trans,benchmark]/experiment=kiba;preset=mol_trans;sweep=slurm_test/checkpoints/epoch_008.ckpt",COMPLETED,mol_trans,kiba,slurm_test,,,, +0.2160931378602981,0.3230183720588684,0.796424388885498,0.3271028101444244,0.5691056847572327,0.2295081913471222,0.2295081913471222,0.9861690998077391,"/gpfs/work/pha/daiyunhuang/WavyWaffle/logs/train/multiruns/2023-10-20_20-20-11-193061_[tcpi2,kiba,davis]/experiment=davis;model.optimizer.lr=1e-06;preset=transformer_cpi_2;sweep=slurm_test;tags=[tcpi2,kiba,davis]/checkpoints/epoch_023.ckpt",COMPLETED,transformer_cpi_2,davis,slurm_test,"[tcpi2,kiba,davis]",1e-06,, +0.2221459150314331,0.5159928798675537,0.8539965152740479,0.4716618657112121,0.612983763217926,0.3832943141460418,0.3832943141460418,0.9759260416030884,"/gpfs/work/pha/daiyunhuang/WavyWaffle/logs/train/multiruns/2023-10-20_20-20-11-193061_[tcpi2,kiba,davis]/experiment=kiba;model.optimizer.lr=1e-06;preset=transformer_cpi_2;sweep=slurm_test;tags=[tcpi2,kiba,davis]/checkpoints/epoch_077.ckpt",COMPLETED,transformer_cpi_2,kiba,slurm_test,"[tcpi2,kiba,davis]",1e-06,, 
+0.1689525544643402,0.5410239696502686,0.9028607606887816,0.4678111672401428,0.6770186424255371,0.3573770523071289,0.3573770523071289,0.986430048942566,"/gpfs/work/pha/daiyunhuang/WavyWaffle/logs/train/multiruns/2023-10-20_21-25-54-050914_[deep_conv_dti,benchmark]/experiment=davis;local=gpu3090;preset=deep_conv_dti;tags=[deep_conv_dti,benchmark]/checkpoints/epoch_012.ckpt",COMPLETED,deep_conv_dti,davis,,"[deep_conv_dti,benchmark]",,gpu3090, +0.1273866593837738,0.8186066746711731,0.95548677444458,0.7409326434135437,0.8289855122566223,0.6697892546653748,0.6697892546653748,0.9862545728683472,"/gpfs/work/pha/daiyunhuang/WavyWaffle/logs/train/multiruns/2023-10-20_21-25-54-050914_[deep_conv_dti,benchmark]/experiment=kiba;local=gpu3090;preset=deep_conv_dti;tags=[deep_conv_dti,benchmark]/checkpoints/epoch_026.ckpt",COMPLETED,deep_conv_dti,kiba,,"[deep_conv_dti,benchmark]",,gpu3090, +0.3221746683120727,0.9174473285675048,0.9374850988388062,0.8440826535224915,0.8421315550804138,0.846042811870575,0.846042811870575,0.8856043219566345,"/gpfs/work/pha/daiyunhuang/WavyWaffle/logs/train/multiruns/2023-10-20_21-25-54-050914_[deep_conv_dti,benchmark]/experiment=bindingdb;local=gpu3090;preset=deep_conv_dti;tags=[deep_conv_dti,benchmark]/checkpoints/epoch_008.ckpt",COMPLETED,deep_conv_dti,bindingdb,,"[deep_conv_dti,benchmark]",,gpu3090, +0.3808988630771637,0.9017271995544434,0.92254638671875,0.8203105330467224,0.818312406539917,0.822318434715271,0.822318434715271,0.8683114647865295,"/gpfs/work/pha/daiyunhuang/WavyWaffle/logs/train/multiruns/2023-10-23_23-21-02-704802_[mol_trans,bindingdb]/experiment=bindingdb;local=gpu3090;preset=mol_trans;tags=[mol_trans,bindingdb]/checkpoints/epoch_064.ckpt",COMPLETED,mol_trans,bindingdb,,"[mol_trans,bindingdb]",,gpu3090, 
+0.202189102768898,0.3614169955253601,0.8401678204536438,0.330232560634613,0.5680000185966492,0.2327868789434433,0.2327868789434433,0.9859081506729126,"/gpfs/work/pha/daiyunhuang/WavyWaffle/logs/train/multiruns/2023-10-23_23-27-18-387556_[m_graph_dta,benchmark]/experiment=davis;local=gpu3090;preset=m_graph_dta;tags=[m_graph_dta,benchmark]/checkpoints/epoch_094.ckpt",COMPLETED,m_graph_dta,davis,,"[m_graph_dta,benchmark]",,gpu3090, +0.1921116858720779,0.6082723736763,0.8893994092941284,0.4803767800331116,0.7285714149475098,0.3583138287067413,0.3583138287067413,0.9867205023765564,"/gpfs/work/pha/daiyunhuang/WavyWaffle/logs/train/multiruns/2023-10-23_23-27-18-387556_[m_graph_dta,benchmark]/experiment=kiba;local=gpu3090;preset=m_graph_dta;tags=[m_graph_dta,benchmark]/checkpoints/epoch_246.ckpt",COMPLETED,m_graph_dta,kiba,,"[m_graph_dta,benchmark]",,gpu3090, +0.6574236750602722,0.5444095730781555,0.6328268051147461,0.483252614736557,0.5203881859779358,0.4510641098022461,0.4510641098022461,0.7001991868019104,"/gpfs/work/pha/daiyunhuang/WavyWaffle/logs/train/multiruns/2023-11-06_14-16-25-154180_[benchmark,tcpi2,bindingdb]/data.batch_size=32;experiment=bindingdb;preset=transformer_cpi_2;tags=[benchmark,tcpi2,bindingdb]/checkpoints/epoch_017.ckpt",COMPLETED,transformer_cpi_2,bindingdb,,"[benchmark,tcpi2,bindingdb]",,,32.0 diff --git a/configs/sweep/experiment_summary_bindingdb.csv b/configs/sweep/experiment_summary_bindingdb.csv new file mode 100644 index 0000000000000000000000000000000000000000..b2a3ee1ab0192c82662c03cf59ed07bda14b1862 --- /dev/null +++ b/configs/sweep/experiment_summary_bindingdb.csv @@ -0,0 +1,7 @@ +test/loss,test/auprc,test/auroc,test/f1_score,test/precision,test/recall,test/sensitivity,test/specificity,ckpt_path,job_status,preset,experiment,sweep,tags,model.optimizer.lr,local,data.batch_size 
+0.5104637742042542,0.7619075775146484,0.8163812756538391,0.6808292269706726,0.6954700350761414,0.6667922139167786,0.6667922139167786,0.789406955242157,"/gpfs/work/pha/daiyunhuang/WavyWaffle/logs/train/multiruns/2023-10-16_14-53-13-645337_[benchmark,ddta,gdta]/experiment=bindingdb;preset=deep_dta;sweep=slurm_test;tags=[benchmark,ddta,gdta]/checkpoints/epoch_003.ckpt",COMPLETED,deep_dta,bindingdb,slurm_test,"[benchmark,ddta,gdta]",,, +0.4569543302059173,0.828099250793457,0.8623664379119873,0.7299983501434326,0.7756586670875549,0.6894148588180542,0.6894148588180542,0.8561919927597046,"/gpfs/work/pha/daiyunhuang/WavyWaffle/logs/train/multiruns/2023-10-16_14-53-13-645337_[benchmark,ddta,gdta]/experiment=bindingdb;preset=graph_dta;sweep=slurm_test;tags=[benchmark,ddta,gdta]/checkpoints/epoch_034.ckpt",COMPLETED,graph_dta,bindingdb,slurm_test,"[benchmark,ddta,gdta]",,, +0.3445079922676086,0.915877878665924,0.934882640838623,0.8371888399124146,0.8385065793991089,0.8358752131462097,0.8358752131462097,0.8838841319084167,"/gpfs/work/pha/daiyunhuang/WavyWaffle/logs/train/multiruns/2023-10-16_18-08-40-028557_[sweep,benchmark]/experiment=bindingdb;preset=hyper_attention_dti;sweep=slurm_test/checkpoints/epoch_017.ckpt",COMPLETED,hyper_attention_dti,bindingdb,slurm_test,,,, +0.3221746683120727,0.9174473285675048,0.9374850988388062,0.8440826535224915,0.8421315550804138,0.846042811870575,0.846042811870575,0.8856043219566345,"/gpfs/work/pha/daiyunhuang/WavyWaffle/logs/train/multiruns/2023-10-20_21-25-54-050914_[deep_conv_dti,benchmark]/experiment=bindingdb;local=gpu3090;preset=deep_conv_dti;tags=[deep_conv_dti,benchmark]/checkpoints/epoch_008.ckpt",COMPLETED,deep_conv_dti,bindingdb,,"[deep_conv_dti,benchmark]",,gpu3090, 
+0.3808988630771637,0.9017271995544434,0.92254638671875,0.8203105330467224,0.818312406539917,0.822318434715271,0.822318434715271,0.8683114647865295,"/gpfs/work/pha/daiyunhuang/WavyWaffle/logs/train/multiruns/2023-10-23_23-21-02-704802_[mol_trans,bindingdb]/experiment=bindingdb;local=gpu3090;preset=mol_trans;tags=[mol_trans,bindingdb]/checkpoints/epoch_064.ckpt",COMPLETED,mol_trans,bindingdb,,"[mol_trans,bindingdb]",,gpu3090, +0.6574236750602722,0.5444095730781555,0.6328268051147461,0.483252614736557,0.5203881859779358,0.4510641098022461,0.4510641098022461,0.7001991868019104,"/gpfs/work/pha/daiyunhuang/WavyWaffle/logs/train/multiruns/2023-11-06_14-16-25-154180_[benchmark,tcpi2,bindingdb]/data.batch_size=32;experiment=bindingdb;preset=transformer_cpi_2;tags=[benchmark,tcpi2,bindingdb]/checkpoints/epoch_017.ckpt",COMPLETED,transformer_cpi_2,bindingdb,,"[benchmark,tcpi2,bindingdb]",,,32.0 diff --git a/configs/sweep/experiment_summary_davis.csv b/configs/sweep/experiment_summary_davis.csv new file mode 100644 index 0000000000000000000000000000000000000000..1cc9af384ea7c91cdc510d734862bbe56edc3f43 --- /dev/null +++ b/configs/sweep/experiment_summary_davis.csv @@ -0,0 +1,8 @@ +test/loss,test/auprc,test/auroc,test/f1_score,test/precision,test/recall,test/sensitivity,test/specificity,ckpt_path,job_status,preset,experiment,sweep,tags,model.optimizer.lr,local,data.batch_size +0.193635880947113,0.4078962802886963,0.8581838607788086,0.3916666805744171,0.5371428728103638,0.3081967234611511,0.3081967234611511,0.9788622260093688,"/gpfs/work/pha/daiyunhuang/WavyWaffle/logs/train/multiruns/2023-10-16_14-53-13-645337_[benchmark,ddta,gdta]/experiment=davis;preset=deep_dta;sweep=slurm_test;tags=[benchmark,ddta,gdta]/checkpoints/epoch_008.ckpt",COMPLETED,deep_dta,davis,slurm_test,"[benchmark,ddta,gdta]",,, 
+0.2221491634845733,0.3560383319854736,0.78690105676651,0.2659846544265747,0.604651153087616,0.1704917997121811,0.1704917997121811,0.9911273717880248,"/gpfs/work/pha/daiyunhuang/WavyWaffle/logs/train/multiruns/2023-10-16_14-53-13-645337_[benchmark,ddta,gdta]/experiment=davis;preset=graph_dta;sweep=slurm_test;tags=[benchmark,ddta,gdta]/checkpoints/epoch_054.ckpt",COMPLETED,graph_dta,davis,slurm_test,"[benchmark,ddta,gdta]",,, +0.1881016194820404,0.5213174819946289,0.8876519203186035,0.5207100510597229,0.6534653306007385,0.4327868819236755,0.4327868819236755,0.9817327857017516,"/gpfs/work/pha/daiyunhuang/WavyWaffle/logs/train/multiruns/2023-10-16_18-08-40-028557_[sweep,benchmark]/experiment=davis;preset=hyper_attention_dti;sweep=slurm_test/checkpoints/epoch_017.ckpt",COMPLETED,hyper_attention_dti,davis,slurm_test,,,, +0.3915603458881378,0.3016493022441864,0.7910888195037842,0.085626907646656,0.6363636255264282,0.0459016375243663,0.0459016375243663,0.9979122877120972,"/gpfs/work/pha/daiyunhuang/WavyWaffle/logs/train/multiruns/2023-10-20_18-50-02-680654_[mol_trans,benchmark]/experiment=davis;preset=mol_trans;sweep=slurm_test/checkpoints/epoch_000.ckpt",COMPLETED,mol_trans,davis,slurm_test,,,, +0.2160931378602981,0.3230183720588684,0.796424388885498,0.3271028101444244,0.5691056847572327,0.2295081913471222,0.2295081913471222,0.9861690998077391,"/gpfs/work/pha/daiyunhuang/WavyWaffle/logs/train/multiruns/2023-10-20_20-20-11-193061_[tcpi2,kiba,davis]/experiment=davis;model.optimizer.lr=1e-06;preset=transformer_cpi_2;sweep=slurm_test;tags=[tcpi2,kiba,davis]/checkpoints/epoch_023.ckpt",COMPLETED,transformer_cpi_2,davis,slurm_test,"[tcpi2,kiba,davis]",1e-06,, 
+0.1689525544643402,0.5410239696502686,0.9028607606887816,0.4678111672401428,0.6770186424255371,0.3573770523071289,0.3573770523071289,0.986430048942566,"/gpfs/work/pha/daiyunhuang/WavyWaffle/logs/train/multiruns/2023-10-20_21-25-54-050914_[deep_conv_dti,benchmark]/experiment=davis;local=gpu3090;preset=deep_conv_dti;tags=[deep_conv_dti,benchmark]/checkpoints/epoch_012.ckpt",COMPLETED,deep_conv_dti,davis,,"[deep_conv_dti,benchmark]",,gpu3090, +0.202189102768898,0.3614169955253601,0.8401678204536438,0.330232560634613,0.5680000185966492,0.2327868789434433,0.2327868789434433,0.9859081506729126,"/gpfs/work/pha/daiyunhuang/WavyWaffle/logs/train/multiruns/2023-10-23_23-27-18-387556_[m_graph_dta,benchmark]/experiment=davis;local=gpu3090;preset=m_graph_dta;tags=[m_graph_dta,benchmark]/checkpoints/epoch_094.ckpt",COMPLETED,m_graph_dta,davis,,"[m_graph_dta,benchmark]",,gpu3090, diff --git a/configs/sweep/experiment_summary_kiba.csv b/configs/sweep/experiment_summary_kiba.csv new file mode 100644 index 0000000000000000000000000000000000000000..fea61408c0c735bb75f05494db511650a4beff79 --- /dev/null +++ b/configs/sweep/experiment_summary_kiba.csv @@ -0,0 +1,8 @@ +test/loss,test/auprc,test/auroc,test/f1_score,test/precision,test/recall,test/sensitivity,test/specificity,ckpt_path,job_status,preset,experiment,sweep,tags,model.optimizer.lr,local,data.batch_size +0.1525115370750427,0.735183835029602,0.9363504648208618,0.661261260509491,0.7816826701164246,0.5729898810386658,0.5729898810386658,0.984080135822296,"/gpfs/work/pha/daiyunhuang/WavyWaffle/logs/train/multiruns/2023-10-16_14-53-13-645337_[benchmark,ddta,gdta]/experiment=kiba;preset=deep_dta;sweep=slurm_test;tags=[benchmark,ddta,gdta]/checkpoints/epoch_041.ckpt",COMPLETED,deep_dta,kiba,slurm_test,"[benchmark,ddta,gdta]",,, 
+0.2281630784273147,0.4483628869056701,0.8383051156997681,0.2730720639228821,0.7176079750061035,0.1686182618141174,0.1686182618141174,0.9933990836143494,"/gpfs/work/pha/daiyunhuang/WavyWaffle/logs/train/multiruns/2023-10-16_14-53-13-645337_[benchmark,ddta,gdta]/experiment=kiba;preset=graph_dta;sweep=slurm_test;tags=[benchmark,ddta,gdta]/checkpoints/epoch_065.ckpt",COMPLETED,graph_dta,kiba,slurm_test,"[benchmark,ddta,gdta]",,, +0.144176036119461,0.7704265117645264,0.9501243829727172,0.6897732019424438,0.7632575631141663,0.6291959285736084,0.6291959285736084,0.980585515499115,"/gpfs/work/pha/daiyunhuang/WavyWaffle/logs/train/multiruns/2023-10-16_18-08-40-028557_[sweep,benchmark]/experiment=kiba;preset=hyper_attention_dti;sweep=slurm_test/checkpoints/epoch_014.ckpt",COMPLETED,hyper_attention_dti,kiba,slurm_test,,,, +0.1662539541721344,0.7397991418838501,0.9369943141937256,0.6784178614616394,0.7550238966941833,0.6159250736236572,0.6159250736236572,0.9801195859909058,"/gpfs/work/pha/daiyunhuang/WavyWaffle/logs/train/multiruns/2023-10-20_18-50-02-680654_[mol_trans,benchmark]/experiment=kiba;preset=mol_trans;sweep=slurm_test/checkpoints/epoch_008.ckpt",COMPLETED,mol_trans,kiba,slurm_test,,,, +0.2221459150314331,0.5159928798675537,0.8539965152740479,0.4716618657112121,0.612983763217926,0.3832943141460418,0.3832943141460418,0.9759260416030884,"/gpfs/work/pha/daiyunhuang/WavyWaffle/logs/train/multiruns/2023-10-20_20-20-11-193061_[tcpi2,kiba,davis]/experiment=kiba;model.optimizer.lr=1e-06;preset=transformer_cpi_2;sweep=slurm_test;tags=[tcpi2,kiba,davis]/checkpoints/epoch_077.ckpt",COMPLETED,transformer_cpi_2,kiba,slurm_test,"[tcpi2,kiba,davis]",1e-06,, 
+0.1273866593837738,0.8186066746711731,0.95548677444458,0.7409326434135437,0.8289855122566223,0.6697892546653748,0.6697892546653748,0.9862545728683472,"/gpfs/work/pha/daiyunhuang/WavyWaffle/logs/train/multiruns/2023-10-20_21-25-54-050914_[deep_conv_dti,benchmark]/experiment=kiba;local=gpu3090;preset=deep_conv_dti;tags=[deep_conv_dti,benchmark]/checkpoints/epoch_026.ckpt",COMPLETED,deep_conv_dti,kiba,,"[deep_conv_dti,benchmark]",,gpu3090, +0.1921116858720779,0.6082723736763,0.8893994092941284,0.4803767800331116,0.7285714149475098,0.3583138287067413,0.3583138287067413,0.9867205023765564,"/gpfs/work/pha/daiyunhuang/WavyWaffle/logs/train/multiruns/2023-10-23_23-27-18-387556_[m_graph_dta,benchmark]/experiment=kiba;local=gpu3090;preset=m_graph_dta;tags=[m_graph_dta,benchmark]/checkpoints/epoch_246.ckpt",COMPLETED,m_graph_dta,kiba,,"[m_graph_dta,benchmark]",,gpu3090, diff --git a/configs/sweep/experiment_summary_test.csv b/configs/sweep/experiment_summary_test.csv new file mode 100644 index 0000000000000000000000000000000000000000..a1ce0b60c2e5c9e1cd678cc7a7a8867a2472d15f --- /dev/null +++ b/configs/sweep/experiment_summary_test.csv @@ -0,0 +1,4 @@ +test/loss,test/auprc,test/auroc,test/f1_score,test/precision,test/recall,test/sensitivity,test/specificity,ckpt_path,job_status,preset,experiment,sweep,tags,model.optimizer.lr,local,data.batch_size +0.193635880947113,0.4078962802886963,0.8581838607788086,0.3916666805744171,0.5371428728103638,0.3081967234611511,0.3081967234611511,0.9788622260093688,"/gpfs/work/pha/daiyunhuang/WavyWaffle/logs/train/multiruns/2023-10-16_14-53-13-645337_[benchmark,ddta,gdta]/experiment=davis;preset=deep_dta;sweep=slurm_test;tags=[benchmark,ddta,gdta]/checkpoints/epoch_008.ckpt",COMPLETED,deep_dta,davis,slurm_test,"[benchmark,ddta,gdta]",,, 
+0.2221491634845733,0.3560383319854736,0.78690105676651,0.2659846544265747,0.604651153087616,0.1704917997121811,0.1704917997121811,0.9911273717880248,"/gpfs/work/pha/daiyunhuang/WavyWaffle/logs/train/multiruns/2023-10-16_14-53-13-645337_[benchmark,ddta,gdta]/experiment=davis;preset=graph_dta;sweep=slurm_test;tags=[benchmark,ddta,gdta]/checkpoints/epoch_054.ckpt",COMPLETED,graph_dta,davis,slurm_test,"[benchmark,ddta,gdta]",,, +0.1689525544643402,0.5410239696502686,0.9028607606887816,0.4678111672401428,0.6770186424255371,0.3573770523071289,0.3573770523071289,0.986430048942566,"/gpfs/work/pha/daiyunhuang/WavyWaffle/logs/train/multiruns/2023-10-20_21-25-54-050914_[deep_conv_dti,benchmark]/experiment=davis;local=gpu3090;preset=deep_conv_dti;tags=[deep_conv_dti,benchmark]/checkpoints/epoch_012.ckpt",COMPLETED,deep_conv_dti,davis,,"[deep_conv_dti,benchmark]",,gpu3090, diff --git a/configs/task/DTA.yaml b/configs/task/DTA.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0de6eeb0179fcf63608df8a62719a5f6792baf1d --- /dev/null +++ b/configs/task/DTA.yaml @@ -0,0 +1,12 @@ +task: regression +num_classes: null + +out: + _target_: torch.nn.LazyLinear + out_features: 1 + +loss: + _target_: torch.nn.MSELoss + +activation: + _target_: torch.nn.Identity \ No newline at end of file diff --git a/configs/task/DTI.yaml b/configs/task/DTI.yaml new file mode 100644 index 0000000000000000000000000000000000000000..31eafe56289ebcb591ddd4254a943015c0d3e429 --- /dev/null +++ b/configs/task/DTI.yaml @@ -0,0 +1,12 @@ +task: binary +num_classes: null + +out: + _target_: torch.nn.LazyLinear + out_features: 1 + +loss: + _target_: torch.nn.BCEWithLogitsLoss + +activation: + _target_: torch.nn.Sigmoid \ No newline at end of file diff --git a/configs/task/binary.yaml b/configs/task/binary.yaml new file mode 100644 index 0000000000000000000000000000000000000000..31eafe56289ebcb591ddd4254a943015c0d3e429 --- /dev/null +++ b/configs/task/binary.yaml @@ -0,0 +1,12 @@ +task: binary 
+num_classes: null + +out: + _target_: torch.nn.LazyLinear + out_features: 1 + +loss: + _target_: torch.nn.BCEWithLogitsLoss + +activation: + _target_: torch.nn.Sigmoid \ No newline at end of file diff --git a/configs/task/multiclass.yaml b/configs/task/multiclass.yaml new file mode 100644 index 0000000000000000000000000000000000000000..686e9f9b96c98839ec4bdad32351d29aca026618 --- /dev/null +++ b/configs/task/multiclass.yaml @@ -0,0 +1,12 @@ +task: multiclass +num_classes: 3 + +out: + _target_: torch.nn.LazyLinear + out_features: ${num_classes} + +loss: + _target_: torch.nn.CrossEntropyLoss + +activation: + _target_: torch.nn.Softmax \ No newline at end of file diff --git a/configs/task/regression.yaml b/configs/task/regression.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0de6eeb0179fcf63608df8a62719a5f6792baf1d --- /dev/null +++ b/configs/task/regression.yaml @@ -0,0 +1,12 @@ +task: regression +num_classes: null + +out: + _target_: torch.nn.LazyLinear + out_features: 1 + +loss: + _target_: torch.nn.MSELoss + +activation: + _target_: torch.nn.Identity \ No newline at end of file diff --git a/configs/trainer/cpu.yaml b/configs/trainer/cpu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b7d6767e60c956567555980654f15e7bb673a41f --- /dev/null +++ b/configs/trainer/cpu.yaml @@ -0,0 +1,5 @@ +defaults: + - default + +accelerator: cpu +devices: 1 diff --git a/configs/trainer/default.yaml b/configs/trainer/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fd75aa3e6c1b2c9f38c6c666925d58069feb3e73 --- /dev/null +++ b/configs/trainer/default.yaml @@ -0,0 +1,19 @@ +_target_: lightning.Trainer + +default_root_dir: ${paths.output_dir} + +min_epochs: 1 +max_epochs: 50 + +accelerator: auto +precision: bf16 + +gradient_clip_val: 0.5 +gradient_clip_algorithm: norm + +# deterministic algorithms might make training slower but offers more reproducibility than only setting seeds +# True: use 
deterministic always, throwing an error on an operation that doesn't support deterministic +# warn: use deterministic when possible, throwing warnings on operations that don’t support deterministic +deterministic: warn + +inference_mode: True \ No newline at end of file diff --git a/configs/trainer/gpu.yaml b/configs/trainer/gpu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f9dc7f73609b2b2fe595d84288a3b556d8d3356c --- /dev/null +++ b/configs/trainer/gpu.yaml @@ -0,0 +1,6 @@ +defaults: + - default + +accelerator: gpu +devices: 1 +precision: 16-mixed \ No newline at end of file diff --git a/configs/webserver_inference.yaml b/configs/webserver_inference.yaml new file mode 100644 index 0000000000000000000000000000000000000000..85ab0acd3dd0d1ecdcd48f31ff6c5e1b734aeb2c --- /dev/null +++ b/configs/webserver_inference.yaml @@ -0,0 +1,31 @@ +# @package _global_ +defaults: + - model: dti_model # fixed for web server version + - task: null + - data: dti_data # fixed for web server version + - callbacks: tqdm_progress_bar + - trainer: default + - paths: default + - extras: null + - hydra: null + - _self_ + - preset: null + - experiment: null + - sweep: null + - debug: null + - optional local: default + - override model/metrics: null + +job_name: "webserver_inference" + +tags: null + +# passing checkpoint path is necessary for prediction +ckpt_path: ??? + +paths: + output_dir: null + work_dir: null + +data: + num_workers: 8 \ No newline at end of file