Spaces:
Sleeping
Sleeping
File size: 2,646 Bytes
7a9a856 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
import sys
import yaml
import pandas as pd
from utils import *
def load_config(filename):
    """Parse the YAML file at *filename* and return the loaded document.

    The file is decoded as UTF-8 explicitly instead of relying on the
    platform's default text encoding, so the same configuration file is read
    identically on every machine.

    Args:
        filename: Path of the YAML file to read.

    Returns:
        The object produced by ``yaml.safe_load`` for the file, or ``None``
        when the file is not valid YAML (the parser error is printed).

    Raises:
        OSError: If the file cannot be opened.
    """
    with open(filename, 'r', encoding='utf-8') as stream:
        try:
            return yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            # Best-effort: report the parse error and hand None back to the
            # caller rather than aborting here.
            print(exc)
            return None
def save_results(dsm_result, spm_result, occurrence_matrix, dsm_result_path, spm_result_path, occurrence_matrix_path):
    """Write the three result tables to CSV, one file per table.

    Each result object is written through its ``to_csv`` method with
    ``index=False``, so the row index is not emitted as a column.

    Args:
        dsm_result: Table written to ``dsm_result_path``.
        spm_result: Table written to ``spm_result_path``.
        occurrence_matrix: Table written to ``occurrence_matrix_path``.
        dsm_result_path: Destination for ``dsm_result``.
        spm_result_path: Destination for ``spm_result``.
        occurrence_matrix_path: Destination for ``occurrence_matrix``.
    """
    # Pair every table with its destination; files are written in this order.
    outputs = (
        (dsm_result, dsm_result_path),
        (spm_result, spm_result_path),
        (occurrence_matrix, occurrence_matrix_path),
    )
    for table, destination in outputs:
        table.to_csv(destination, index=False)
def _parse_overrides(command_args):
    """Collect ``key=value`` command-line arguments into a dict of overrides."""
    overrides = {}
    for arg in command_args:
        # Split on the first '=' only, so a value may itself contain '='
        # (a plain split('=') would fail to unpack such an argument).
        key, separator, value = arg.partition('=')
        if separator:
            overrides[key] = value
        else:
            print(f"Ignoring invalid argument: {arg}")
    return overrides


def _apply_overrides(params, overrides):
    """Replace, in place, the values of keys already present in *params*."""
    for key, value in overrides.items():
        # Only existing settings can be overridden; unknown keys are skipped.
        if key in params:
            # NOTE(review): overrides are raw command-line strings, so a
            # numeric setting replaced here becomes a str -- confirm that
            # SPM/DSM accept that or convert the value.
            params[key] = value
    return params


def main():
    """Run SPM and DSM with settings from ``config.yaml`` and save the results.

    Command-line arguments of the form ``key=value`` replace the value of a
    key that already exists in the ``SPM`` or ``DSM`` section of the
    configuration; arguments without ``=`` are reported and ignored. The
    output locations are read from the top-level ``dsm_result_path``,
    ``spm_result_path`` and ``occurrence_matrix_path`` entries.

    Raises:
        SystemExit: If the configuration could not be loaded as a mapping.
    """
    config_file = 'config.yaml'
    # With no extra arguments this is simply an empty dict, so one code path
    # serves both the "with overrides" and "without overrides" cases.
    overridden_params = _parse_overrides(sys.argv[1:])

    config = load_config(config_file)
    if not isinstance(config, dict):
        # load_config yields None when the YAML cannot be parsed; anything
        # that is not a mapping would otherwise fail on .get() below.
        sys.exit(f"Could not load a settings mapping from {config_file}")

    # SPM parameters
    spm_params = _apply_overrides(config.get('SPM', {}), overridden_params)
    spm_result, occurrence_matrix = SPM(spm_params)

    # DSM parameters
    dsm_params = _apply_overrides(config.get('DSM', {}), overridden_params)
    # DSM returns five values; only the last one is saved.
    _, _, _, _, dsm_result = DSM(dsm_params)

    # Fetching paths from config and saving results
    save_results(
        dsm_result,
        spm_result,
        occurrence_matrix,
        config.get('dsm_result_path'),
        config.get('spm_result_path'),
        config.get('occurrence_matrix_path'),
    )
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|