Niksa Praljak
commited on
Commit
·
7f6f0e9
1
Parent(s):
013cb42
include the ProteoScribe cuda check
Browse files
run_ProteoScribe_sample.py
CHANGED
@@ -150,7 +150,11 @@ if __name__ == '__main__':
|
|
150 |
# Load and convert JSON config
|
151 |
config_dict = load_json_config(config_args_parser.json_path)
|
152 |
config_args = convert_to_namespace(config_dict)
|
153 |
-
|
|
|
|
|
|
|
|
|
154 |
# load test dataset
|
155 |
embedding_dataset = torch.load(config_args_parser.input_path)
|
156 |
|
|
|
150 |
# Load and convert JSON config
|
151 |
config_dict = load_json_config(config_args_parser.json_path)
|
152 |
config_args = convert_to_namespace(config_dict)
|
153 |
+
|
154 |
+
# Set device if not specified in config
|
155 |
+
if not hasattr(config_args, 'device'):
|
156 |
+
config_args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
157 |
+
|
158 |
# load test dataset
|
159 |
embedding_dataset = torch.load(config_args_parser.input_path)
|
160 |
|