Spaces:
Sleeping
Sleeping
File size: 2,800 Bytes
4b9251f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
import configparser
import ast
class Config:
    """Attribute-style wrapper around a nested configuration dictionary.

    ``config_data`` maps section names to dictionaries of options
    (``read_ini`` produces the sections "pipeline" and "llm"). Sections are
    read and assigned as attributes, e.g. ``config.pipeline``.
    """

    def __init__(self, config_data):
        """Store ``config_data``, a mapping of section name to option dict."""
        self.config_data = config_data

    def __getattr__(self, name):
        """Return the section called ``name``, or an empty dict if absent.

        Only invoked when normal attribute lookup fails.

        Raises:
            AttributeError: for ``config_data`` itself and for dunder names.
        """
        # ``config_data`` is missing on instances created without __init__
        # (copy/pickle use __new__); looking it up here would recurse forever.
        # Dunder probes (__deepcopy__, __setstate__, ...) must fail cleanly
        # instead of receiving a dict that the caller then tries to call.
        if name == "config_data" or (name.startswith("__") and name.endswith("__")):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return self.config_data.get(name, {})

    def __setattr__(self, name, value):
        """Store ``value`` as section ``name``; ``config_data`` is set normally."""
        if name == 'config_data':
            super().__setattr__(name, value)
        else:
            self.config_data[name] = value

    def __repr__(self):
        return f"Config(\n{'pipeline':>12}: {self.pipeline}\n{'llm':>12}: {self.llm}\n)"

    @classmethod
    def read_ini(cls, file_path):
        """Build a Config from the "pipeline" and "llm" sections of an .ini file.

        Raises:
            KeyError: if either section is missing. ``ConfigParser.read``
                silently skips unreadable files, so a wrong ``file_path``
                also surfaces as this error.
        """
        parser = configparser.ConfigParser()
        parser.read(file_path)
        return cls({"pipeline": cls._parse_pipeline_section(parser["pipeline"]),
                    "llm": cls._parse_llm_section(parser["llm"])})

    @staticmethod
    def _parse_llm_section(section):
        """Parse each value as a Python literal, keeping the raw string otherwise."""
        parsed_data = {}
        for key, value in section.items():
            try:
                parsed_data[key] = ast.literal_eval(value)
            except (ValueError, SyntaxError, TypeError):
                # Plain strings are not valid literals: "gpt-3.5-turbo" raises
                # ValueError, while "hello world", URLs and empty values raise
                # SyntaxError. All of them should stay as the raw string.
                parsed_data[key] = value
        return parsed_data

    @staticmethod
    def _parse_pipeline_section(section):
        """Parse pipeline options; keys without a converter stay strings.

        ``words_per_batch`` becomes an int. ``preprocess`` is split on ", "
        into a list, with the literal value "None" meaning an empty list.
        """
        eval_func = {
            "words_per_batch": int,
            "preprocess": lambda x: x.split(", ") if x.split(", ") != ["None"] else []
        }
        return {key: eval_func.get(key, str)(value) for key, value in section.items()}

    def serialize(self, save_path: str = None):
        """Convert Config object to .ini file. If save_path is not specified, return string

        The value of ``openai_api_key`` is always written as ``None`` so the
        secret never reaches the output. When ``save_path`` is given the text
        is written there and nothing is returned.
        """
        lines = []
        for section, options in self.config_data.items():
            lines.append(f"[{section}]")
            for key, value in options.items():
                # turn list back to str; items may be non-strings, and an
                # empty list is written as "None" because that is the marker
                # _parse_pipeline_section reads back as [].
                if isinstance(value, list):
                    value = ", ".join(map(str, value)) if value else "None"
                # don't serialize the api key
                if key == "openai_api_key":
                    value = None
                lines.append(f"{key} = {value}")
            # blank line after every section
            lines.append("")
        serialized_config = "".join(f"{line}\n" for line in lines)
        if save_path:
            with open(save_path, 'w') as f:
                f.write(serialized_config)
        else:
            return serialized_config
if __name__ == "__main__":
    # example usage: pass the path to a config.ini as the first command-line
    # argument; without one, the original developer-local path is used.
    import sys

    default_config_file = "/Users/johaunh/Documents/PhD/Projects/Text2KG/config/config.ini"
    config_file = sys.argv[1] if len(sys.argv) > 1 else default_config_file
    config = Config.read_ini(config_file)

    # access pipeline parameters, then LLM parameters
    for heading, params in (("Pipeline Parameters:", config.pipeline),
                            ("\nLLM Parameters:", config.llm)):
        print(heading)
        for k, v in params.items():
            print(f"{k}: {v}")
|