Niksa Praljak commited on
Commit
7f6f0e9
·
1 Parent(s): 013cb42

include the ProteoScribe cuda check

Browse files
Files changed (1) hide show
  1. run_ProteoScribe_sample.py +5 -1
run_ProteoScribe_sample.py CHANGED
@@ -150,7 +150,11 @@ if __name__ == '__main__':
150
  # Load and convert JSON config
151
  config_dict = load_json_config(config_args_parser.json_path)
152
  config_args = convert_to_namespace(config_dict)
153
-
 
 
 
 
154
  # load test dataset
155
  embedding_dataset = torch.load(config_args_parser.input_path)
156
 
 
150
  # Load and convert JSON config
151
  config_dict = load_json_config(config_args_parser.json_path)
152
  config_args = convert_to_namespace(config_dict)
153
+
154
+ # Set device if not specified in config
155
+ if not hasattr(config_args, 'device'):
156
+ config_args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
157
+
158
  # load test dataset
159
  embedding_dataset = torch.load(config_args_parser.input_path)
160