#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
from abc import ABC
import builtins
import json
import os
from copy import deepcopy
from functools import partial
from typing import List, Dict, Tuple, Union

import pandas as pd

from agent import settings
from agent.settings import flow_logger, DEBUG

# Names of the bookkeeping attributes stored on parameter objects/classes.
# They are skipped by ``ComponentParamBase.as_dict``.
_FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params"
_DEPRECATED_PARAMS = "_deprecated_params"
_USER_FEEDED_PARAMS = "_user_feeded_params"
_IS_RAW_CONF = "_is_raw_conf"


class ComponentParamBase(ABC):
    """Base class for component parameter objects.

    Provides dict conversion, recursive update from a configuration dict,
    optional JSON-driven validation and a set of ``check_*`` helpers that
    raise ``ValueError`` on invalid values.
    """

    def __init__(self):
        # Name of the attribute on this object that holds the component output.
        self.output_var_name = "output"
        self.message_history_window_size = 22

    def set_name(self, name: str):
        """Store ``name`` on the instance and return ``self`` for chaining."""
        self._name = name
        return self

    def check(self):
        """Validate the parameters; subclasses must override this."""
        raise NotImplementedError("Parameter Object should be checked.")

    @classmethod
    def _get_or_init_deprecated_params_set(cls):
        """Return the class-level deprecated-parameter set, creating it if absent."""
        if not hasattr(cls, _DEPRECATED_PARAMS):
            setattr(cls, _DEPRECATED_PARAMS, set())
        return getattr(cls, _DEPRECATED_PARAMS)

    def _get_or_init_feeded_deprecated_params_set(self, conf=None):
        """Return the instance set of deprecated params that were fed in.

        On first access the set is created empty, or seeded from
        ``conf[_FEEDED_DEPRECATED_PARAMS]`` when ``conf`` is given.
        """
        if not hasattr(self, _FEEDED_DEPRECATED_PARAMS):
            if conf is None:
                setattr(self, _FEEDED_DEPRECATED_PARAMS, set())
            else:
                setattr(
                    self,
                    _FEEDED_DEPRECATED_PARAMS,
                    set(conf[_FEEDED_DEPRECATED_PARAMS]),
                )
        return getattr(self, _FEEDED_DEPRECATED_PARAMS)

    def _get_or_init_user_feeded_params_set(self, conf=None):
        """Return the instance set of user-fed param names.

        On first access the set is created empty, or seeded from
        ``conf[_USER_FEEDED_PARAMS]`` when ``conf`` is given.
        """
        if not hasattr(self, _USER_FEEDED_PARAMS):
            if conf is None:
                setattr(self, _USER_FEEDED_PARAMS, set())
            else:
                setattr(self, _USER_FEEDED_PARAMS, set(conf[_USER_FEEDED_PARAMS]))
        return getattr(self, _USER_FEEDED_PARAMS)

    def get_user_feeded(self):
        """Return the set of parameter names recorded as user-fed."""
        return self._get_or_init_user_feeded_params_set()

    def get_feeded_deprecated_params(self):
        """Return the set of deprecated parameter names recorded as fed in."""
        return self._get_or_init_feeded_deprecated_params_set()

    @property
    def _deprecated_params_set(self):
        """Mapping ``name -> True`` for every fed deprecated parameter."""
        return {name: True for name in self.get_feeded_deprecated_params()}

    def __str__(self):
        """Return ``as_dict()`` serialised as JSON (non-ASCII kept as-is)."""
        return json.dumps(self.as_dict(), ensure_ascii=False)

    def as_dict(self):
        """Recursively convert this object's attributes into a plain dict.

        Bookkeeping attributes are skipped, DataFrames are converted with
        ``to_dict()``, and truthy attributes whose type name is not a builtin
        name are converted recursively through their ``__dict__``.
        """

        def _recursive_convert_obj_to_dict(obj):
            ret_dict = {}
            for attr_name in list(obj.__dict__):
                if attr_name in [_FEEDED_DEPRECATED_PARAMS, _DEPRECATED_PARAMS, _USER_FEEDED_PARAMS, _IS_RAW_CONF]:
                    continue
                # get attr
                attr = getattr(obj, attr_name)
                # DataFrames are handled first: their truthiness is ambiguous
                # and would raise in the check below.
                if isinstance(attr, pd.DataFrame):
                    ret_dict[attr_name] = attr.to_dict()
                    continue
                if attr and type(attr).__name__ not in dir(builtins):
                    ret_dict[attr_name] = _recursive_convert_obj_to_dict(attr)
                else:
                    ret_dict[attr_name] = attr

            return ret_dict

        return _recursive_convert_obj_to_dict(self)

    def update(self, conf, allow_redundant=False):
        """Recursively apply ``conf`` onto this parameter object.

        Args:
            conf: Mapping of attribute names to values; nested mappings update
                nested parameter objects. ``conf[_IS_RAW_CONF]`` (default
                ``True``) selects whether user-fed bookkeeping is recorded.
            allow_redundant: When false, a non-empty list of redundant keys
                raises ``ValueError``.

        Returns:
            This object, after the update.

        Raises:
            ValueError: If nesting exceeds ``settings.PARAM_MAXDEPTH`` or
                redundant keys were collected while ``allow_redundant`` is
                false.
        """
        update_from_raw_conf = conf.get(_IS_RAW_CONF, True)
        if update_from_raw_conf:
            deprecated_params_set = self._get_or_init_deprecated_params_set()
            feeded_deprecated_params_set = (
                self._get_or_init_feeded_deprecated_params_set()
            )
            user_feeded_params_set = self._get_or_init_user_feeded_params_set()
            setattr(self, _IS_RAW_CONF, False)
        else:
            feeded_deprecated_params_set = (
                self._get_or_init_feeded_deprecated_params_set(conf)
            )
            user_feeded_params_set = self._get_or_init_user_feeded_params_set(conf)

        def _recursive_update_param(param, config, depth, prefix):
            if depth > settings.PARAM_MAXDEPTH:
                raise ValueError("Param define nesting too deep!!!, can not parse it")

            inst_variables = param.__dict__
            redundant_attrs = []
            for config_key, config_value in config.items():
                # redundant attr
                if config_key not in inst_variables:
                    # Unknown keys are set on the object regardless of origin;
                    # recording them as redundant is currently disabled.
                    setattr(param, config_key, config_value)
                    # redundant_attrs.append(config_key)
                    continue

                full_config_key = f"{prefix}{config_key}"

                if update_from_raw_conf:
                    # add user feeded params
                    user_feeded_params_set.add(full_config_key)

                    # update user feeded deprecated param set
                    if full_config_key in deprecated_params_set:
                        feeded_deprecated_params_set.add(full_config_key)

                # supported attr
                attr = getattr(param, config_key)
                if type(attr).__name__ in dir(builtins) or attr is None:
                    setattr(param, config_key, config_value)
                else:
                    # recursive set obj attr
                    sub_params = _recursive_update_param(
                        attr, config_value, depth + 1, prefix=f"{prefix}{config_key}."
                    )
                    setattr(param, config_key, sub_params)

            if not allow_redundant and redundant_attrs:
                raise ValueError(
                    f"cpn `{getattr(self, '_name', type(self))}` has redundant parameters: `{redundant_attrs}`"
                )

            return param

        return _recursive_update_param(param=self, config=conf, depth=0, prefix="")

    def extract_not_builtin(self):
        """Return a nested dict of the truthy attributes whose type is not a builtin."""

        def _get_not_builtin_types(obj):
            ret_dict = {}
            for variable in obj.__dict__:
                attr = getattr(obj, variable)
                if attr and type(attr).__name__ not in dir(builtins):
                    ret_dict[variable] = _get_not_builtin_types(attr)

            return ret_dict

        return _get_not_builtin_types(self)

    def validate(self):
        """Validate attributes against ``param_validation/<ClassName>.json``.

        The JSON file is looked up next to this module. If it cannot be read
        or parsed, validation is skipped silently.

        Raises:
            ValueError: Propagated from ``_validate_param`` when a value
                violates its restriction.
        """
        self.builtin_types = dir(builtins)
        # Dispatch table from restriction operator name to its predicate.
        self.func = {
            "ge": self._greater_equal_than,
            "le": self._less_equal_than,
            "in": self._in,
            "not_in": self._not_in,
            "range": self._range,
        }
        home_dir = os.path.abspath(os.path.dirname(os.path.realpath(__file__)))
        param_validation_path = os.path.join(
            home_dir, "param_validation", type(self).__name__ + ".json"
        )

        try:
            with open(param_validation_path, "r") as fin:
                validation_json = json.loads(fin.read())
        except (OSError, ValueError):
            # Missing/unreadable file or malformed JSON: nothing to validate.
            # (JSONDecodeError and UnicodeDecodeError are ValueError subclasses.)
            return

        self._validate_param(self, validation_json)

    def _validate_param(self, param_obj, validation_json):
        """Check ``param_obj`` attributes against ``validation_json``.

        A builtin-typed (or ``None``) attribute is legal when at least one of
        its configured operators accepts the value. Other attributes are
        validated recursively.

        Raises:
            ValueError: If no configured operator accepts a value.
        """
        default_section = type(param_obj).__name__
        var_list = param_obj.__dict__

        for variable in var_list:
            attr = getattr(param_obj, variable)

            if type(attr).__name__ in self.builtin_types or attr is None:
                # NOTE(review): membership is tested on the top-level JSON but
                # the restriction is read from validation_json[default_section];
                # confirm against the actual validation file layout.
                if variable not in validation_json:
                    continue

                validation_dict = validation_json[default_section][variable]
                value = getattr(param_obj, variable)
                value_legal = any(
                    self.func[op_type](value, validation_dict[op_type])
                    for op_type in validation_dict
                )

                if not value_legal:
                    raise ValueError(
                        "Plase check runtime conf, {} = {} does not match user-parameter restriction".format(
                            variable, value
                        )
                    )

            elif variable in validation_json:
                self._validate_param(attr, validation_json)

    @staticmethod
    def check_string(param, descr):
        """Raise ``ValueError`` unless ``param`` is exactly of type ``str``."""
        if type(param).__name__ not in ["str"]:
            raise ValueError(
                descr + " {} not supported, should be string type".format(param)
            )

    @staticmethod
    def check_empty(param, descr):
        """Raise ``ValueError`` when ``param`` is falsy."""
        if not param:
            raise ValueError(
                descr + " does not support empty value."
            )

    @staticmethod
    def check_positive_integer(param, descr):
        """Raise ``ValueError`` unless ``param`` is an ``int`` greater than 0."""
        if type(param).__name__ not in ["int", "long"] or param <= 0:
            raise ValueError(
                descr + " {} not supported, should be positive integer".format(param)
            )

    @staticmethod
    def check_positive_number(param, descr):
        """Raise ``ValueError`` unless ``param`` is an int/float greater than 0."""
        if type(param).__name__ not in ["float", "int", "long"] or param <= 0:
            raise ValueError(
                descr + " {} not supported, should be positive numeric".format(param)
            )

    @staticmethod
    def check_nonnegative_number(param, descr):
        """Raise ``ValueError`` unless ``param`` is an int/float that is >= 0."""
        if type(param).__name__ not in ["float", "int", "long"] or param < 0:
            raise ValueError(
                descr
                + " {} not supported, should be non-negative numeric".format(param)
            )

    @staticmethod
    def check_decimal_float(param, descr):
        """Raise ``ValueError`` unless ``param`` is an int/float within [0, 1]."""
        if type(param).__name__ not in ["float", "int"] or param < 0 or param > 1:
            raise ValueError(
                descr
                + " {} not supported, should be a float number in range [0, 1]".format(
                    param
                )
            )

    @staticmethod
    def check_boolean(param, descr):
        """Raise ``ValueError`` unless ``param`` is exactly of type ``bool``."""
        if type(param).__name__ != "bool":
            raise ValueError(
                descr + " {} not supported, should be bool type".format(param)
            )

    @staticmethod
    def check_open_unit_interval(param, descr):
        """Raise ``ValueError`` unless ``param`` is a ``float`` strictly inside (0, 1)."""
        if type(param).__name__ not in ["float"] or param <= 0 or param >= 1:
            raise ValueError(
                descr + " should be a numeric number between 0 and 1 exclusively"
            )

    @staticmethod
    def check_valid_value(param, descr, valid_values):
        """Raise ``ValueError`` unless ``param`` is contained in ``valid_values``."""
        if param not in valid_values:
            raise ValueError(
                descr
                + " {} is not supported, it should be in {}".format(param, valid_values)
            )

    @staticmethod
    def check_defined_type(param, descr, types):
        """Raise ``ValueError`` unless ``type(param).__name__`` is in ``types``."""
        if type(param).__name__ not in types:
            raise ValueError(
                descr + " {} not supported, should be one of {}".format(param, types)
            )

    @staticmethod
    def check_and_change_lower(param, valid_list, descr=""):
        """Return ``param`` lower-cased if it is a ``str`` found in ``valid_list``.

        Raises:
            ValueError: If ``param`` is not a ``str`` or its lower-cased form
                is not in ``valid_list``.
        """
        if type(param).__name__ != "str":
            raise ValueError(
                descr
                + " {} not supported, should be one of {}".format(param, valid_list)
            )

        lower_param = param.lower()
        if lower_param in valid_list:
            return lower_param
        else:
            raise ValueError(
                descr
                + " {} not supported, should be one of {}".format(param, valid_list)
            )

    @staticmethod
    def _greater_equal_than(value, limit):
        """Return whether ``value >= limit`` within ``settings.FLOAT_ZERO`` tolerance."""
        return value >= limit - settings.FLOAT_ZERO

    @staticmethod
    def _less_equal_than(value, limit):
        """Return whether ``value <= limit`` within ``settings.FLOAT_ZERO`` tolerance."""
        return value <= limit + settings.FLOAT_ZERO

    @staticmethod
    def _range(value, ranges):
        """Return whether ``value`` lies in any ``(left, right)`` pair of ``ranges``.

        Both bounds are inclusive and widened by ``settings.FLOAT_ZERO``.
        """
        return any(
            left_limit - settings.FLOAT_ZERO
            <= value
            <= right_limit + settings.FLOAT_ZERO
            for left_limit, right_limit in ranges
        )

    @staticmethod
    def _in(value, right_value_list):
        """Return whether ``value`` is in ``right_value_list``."""
        return value in right_value_list

    @staticmethod
    def _not_in(value, wrong_value_list):
        """Return whether ``value`` is absent from ``wrong_value_list``."""
        return value not in wrong_value_list

    def _warn_deprecated_param(self, param_name, descr):
        """Log a warning if ``param_name`` is a fed deprecated parameter."""
        if self._deprecated_params_set.get(param_name):
            flow_logger.warning(
                f"{descr} {param_name} is deprecated and ignored in this version."
            )

    def _warn_to_deprecate_param(self, param_name, descr, new_param):
        """Warn that ``param_name`` will be deprecated in favour of ``new_param``.

        Returns:
            ``True`` if the warning was logged (the parameter was fed),
            ``False`` otherwise.
        """
        if self._deprecated_params_set.get(param_name):
            flow_logger.warning(
                f"{descr} {param_name} will be deprecated in future release; "
                f"please use {new_param} instead."
            )
            return True
        return False


class ComponentBase(ABC):
    """Base class for canvas components.

    Holds the owning canvas, the component id and its parameter object, and
    provides output storage plus helpers to collect upstream outputs.
    """

    component_name: str

    def __str__(self):
        """
        {
            "component_name": "Begin",
            "params": {}
        }
        """
        return """{{ "component_name": "{}", "params": {} }}""".format(self.component_name,
                                                                       self._param
                                                                       )

    def __init__(self, canvas, id, param: ComponentParamBase):
        """Bind the component to ``canvas`` and run ``param.check()``."""
        self._canvas = canvas
        self._id = id
        self._param = param
        self._param.check()

    def run(self, history, **kwargs):
        """Execute ``_run`` and store its result as the component output.

        If ``_run`` raises, a one-row DataFrame carrying the error text is
        stored as the output and the exception is re-raised.
        """
        flow_logger.info("{}, history: {}, kwargs: {}".format(self, json.dumps(history, ensure_ascii=False),
                                                              json.dumps(kwargs, ensure_ascii=False)))
        try:
            res = self._run(history, **kwargs)
            self.set_output(res)
        except Exception as e:
            self.set_output(pd.DataFrame([{"content": str(e)}]))
            # Bare raise keeps the original traceback intact.
            raise
        else:
            return res

    def _run(self, history, **kwargs):
        """Component logic; subclasses must override this."""
        raise NotImplementedError()

    def output(self, allow_partial=True) -> Tuple[str, Union[pd.DataFrame, partial]]:
        """Return ``(output_var_name, output_value)``.

        A stored value that is neither a ``partial`` nor a DataFrame is wrapped
        into a DataFrame. When ``allow_partial`` is false and the stored value
        is a ``partial``, it is called and iterated, and the last yielded item
        (converted to a DataFrame if needed) is returned; ``None`` if it
        yields nothing.
        """
        o = getattr(self._param, self._param.output_var_name)
        if not isinstance(o, (partial, pd.DataFrame)):
            o = pd.DataFrame(o if isinstance(o, list) else [o])

        # From here on ``o`` is either a partial or a DataFrame.
        if allow_partial or not isinstance(o, partial):
            return self._param.output_var_name, o

        outs = None
        for oo in o():
            if not isinstance(oo, pd.DataFrame):
                outs = pd.DataFrame(oo if isinstance(oo, list) else [oo])
            else:
                outs = oo
        return self._param.output_var_name, outs

    def reset(self):
        """Clear the stored output by setting it to ``None``."""
        setattr(self._param, self._param.output_var_name, None)

    def set_output(self, v: pd.DataFrame):
        """Store ``v`` as the component output."""
        setattr(self._param, self._param.output_var_name, v)

    def get_input(self):
        """Collect upstream outputs from the last two canvas path steps.

        Components are visited newest first. The collected DataFrames are
        concatenated and, when a ``content`` column exists, de-duplicated on
        it. Returns an empty DataFrame when nothing was collected.
        """
        upstream_outs = []
        reversed_cpnts = []
        if len(self._canvas.path) > 1:
            reversed_cpnts.extend(self._canvas.path[-2])
        reversed_cpnts.extend(self._canvas.path[-1])

        if DEBUG:
            print(self.component_name, reversed_cpnts[::-1])
        for u in reversed_cpnts[::-1]:
            if self.get_component_name(u) in ["switch"]:
                continue
            # A "generate" component gathers every "retrieval" output, even
            # from components that are not its direct upstream.
            if self.component_name.lower() == "generate" and self.get_component_name(u) == "retrieval":
                o = self._canvas.get_component(u)["obj"].output(allow_partial=False)[1]
                if o is not None:
                    upstream_outs.append(o)
                    continue
            if u not in self._canvas.get_component(self._id)["upstream"]:
                continue
            if self.component_name.lower().find("switch") < 0 \
                    and self.get_component_name(u) in ["relevant", "categorize"]:
                continue
            if u.lower().find("answer") >= 0:
                # Upstream is an answer component: use the latest user message.
                for r, c in self._canvas.history[::-1]:
                    if r == "user":
                        upstream_outs.append(pd.DataFrame([{"content": c}]))
                        break
                break
            if self.component_name.lower().find("answer") >= 0:
                # In this branch an "answer" component only skips "relevant"
                # upstreams; it does not collect the upstream output.
                if self.get_component_name(u) in ["relevant"]:
                    continue
            else:
                o = self._canvas.get_component(u)["obj"].output(allow_partial=False)[1]
                if o is not None:
                    upstream_outs.append(o)
            break

        if upstream_outs:
            df = pd.concat(upstream_outs, ignore_index=True)
            if "content" in df:
                df = df.drop_duplicates(subset=['content']).reset_index(drop=True)
            return df
        return pd.DataFrame()

    def get_stream_input(self):
        """Return the output of the newest path component that is not switch/answer.

        Returns ``None`` implicitly when no such component exists.
        """
        reversed_cpnts = []
        if len(self._canvas.path) > 1:
            reversed_cpnts.extend(self._canvas.path[-2])
        reversed_cpnts.extend(self._canvas.path[-1])

        for u in reversed_cpnts[::-1]:
            if self.get_component_name(u) in ["switch", "answer"]:
                continue
            return self._canvas.get_component(u)["obj"].output()[1]

    @staticmethod
    def be_output(v):
        """Wrap ``v`` into a one-row DataFrame with a ``content`` column."""
        return pd.DataFrame([{"content": v}])

    def get_component_name(self, cpn_id):
        """Return the lower-cased ``component_name`` of the component ``cpn_id``."""
        return self._canvas.get_component(cpn_id)["obj"].component_name.lower()