Spaces:
Paused
Paused
# | |
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# | |
from abc import ABC | |
import builtins | |
import json | |
import os | |
from copy import deepcopy | |
from functools import partial | |
from typing import List, Dict, Tuple, Union | |
import pandas as pd | |
from agent import settings | |
from agent.settings import flow_logger, DEBUG | |
# Reserved attribute / conf-key names used by ComponentParamBase for its own
# bookkeeping. ComponentParamBase.as_dict() skips attributes with these names.

# Set of deprecated parameter keys that were actually supplied in a raw conf.
_FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params"
# Set of parameter keys considered deprecated (consulted during update()).
_DEPRECATED_PARAMS = "_deprecated_params"
# Set of (dotted) parameter keys that a raw conf supplied for known attributes.
_USER_FEEDED_PARAMS = "_user_feeded_params"
# Flag read from a conf by update(); treated as True when the key is absent,
# and stored as False on the instance once a raw conf has been applied.
_IS_RAW_CONF = "_is_raw_conf"
class ComponentParamBase(ABC):
    """Base class for the parameter object attached to a component.

    Holds plain attributes that can be filled from a (possibly nested)
    configuration dict via :meth:`update`, serialised via :meth:`as_dict`,
    and checked by subclasses through :meth:`check` using the ``check_*``
    static helpers defined here.
    """

    def __init__(self):
        # Name of the attribute on this object that holds the component output.
        self.output_var_name = "output"
        self.message_history_window_size = 22

    def set_name(self, name: str):
        """Store ``name`` on the instance and return ``self`` for chaining."""
        self._name = name
        return self

    def check(self):
        """Validate the parameters; subclasses must override.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError("Parameter Object should be checked.")

    @classmethod
    def _get_or_init_deprecated_params_set(cls):
        """Return the set of deprecated parameter keys, creating it if absent."""
        if not hasattr(cls, _DEPRECATED_PARAMS):
            setattr(cls, _DEPRECATED_PARAMS, set())
        return getattr(cls, _DEPRECATED_PARAMS)

    def _get_or_init_feeded_deprecated_params_set(self, conf=None):
        """Return the set of deprecated keys that were supplied by the user.

        On first access the set is created empty, or seeded from
        ``conf[_FEEDED_DEPRECATED_PARAMS]`` when ``conf`` is given.
        """
        if not hasattr(self, _FEEDED_DEPRECATED_PARAMS):
            initial = set() if conf is None else set(conf[_FEEDED_DEPRECATED_PARAMS])
            setattr(self, _FEEDED_DEPRECATED_PARAMS, initial)
        return getattr(self, _FEEDED_DEPRECATED_PARAMS)

    def _get_or_init_user_feeded_params_set(self, conf=None):
        """Return the set of keys supplied by the user.

        On first access the set is created empty, or seeded from
        ``conf[_USER_FEEDED_PARAMS]`` when ``conf`` is given.
        """
        if not hasattr(self, _USER_FEEDED_PARAMS):
            initial = set() if conf is None else set(conf[_USER_FEEDED_PARAMS])
            setattr(self, _USER_FEEDED_PARAMS, initial)
        return getattr(self, _USER_FEEDED_PARAMS)

    def get_user_feeded(self):
        """Return the set of user-supplied parameter keys."""
        return self._get_or_init_user_feeded_params_set()

    def get_feeded_deprecated_params(self):
        """Return the set of user-supplied keys that are deprecated."""
        return self._get_or_init_feeded_deprecated_params_set()

    @property
    def _deprecated_params_set(self):
        """Mapping ``{key: True}`` for every user-supplied deprecated key.

        Exposed as a property because the ``_warn_*`` helpers call ``.get()``
        on it directly.
        """
        return {name: True for name in self.get_feeded_deprecated_params()}

    def __str__(self):
        return json.dumps(self.as_dict(), ensure_ascii=False)

    def as_dict(self):
        """Recursively convert this object's attributes into a plain dict.

        Bookkeeping attributes are skipped, DataFrames are converted with
        ``to_dict()``, and truthy attributes whose type name is not a builtin
        name are converted recursively through their ``__dict__``.
        """
        skipped = (
            _FEEDED_DEPRECATED_PARAMS,
            _DEPRECATED_PARAMS,
            _USER_FEEDED_PARAMS,
            _IS_RAW_CONF,
        )
        builtin_names = set(dir(builtins))

        def _recursive_convert_obj_to_dict(obj):
            ret_dict = {}
            for attr_name in list(obj.__dict__):
                if attr_name in skipped:
                    continue
                attr = getattr(obj, attr_name)
                # DataFrame truthiness is ambiguous, so handle it before `if attr`.
                if isinstance(attr, pd.DataFrame):
                    ret_dict[attr_name] = attr.to_dict()
                    continue
                if attr and type(attr).__name__ not in builtin_names:
                    ret_dict[attr_name] = _recursive_convert_obj_to_dict(attr)
                else:
                    ret_dict[attr_name] = attr
            return ret_dict

        return _recursive_convert_obj_to_dict(self)

    def update(self, conf, allow_redundant=False):
        """Apply the configuration dict ``conf`` onto this object in place.

        Args:
            conf: mapping of attribute name to value; nested mappings update
                nested parameter objects. The ``_IS_RAW_CONF`` key (default
                ``True``) selects whether user-feeded bookkeeping is recorded
                from the keys or restored from ``conf`` itself.
            allow_redundant: consulted only if redundant keys are collected;
                currently unknown keys are always set as new attributes.

        Returns:
            ``self`` after the update.

        Raises:
            ValueError: if nesting exceeds ``settings.PARAM_MAXDEPTH``.
        """
        update_from_raw_conf = conf.get(_IS_RAW_CONF, True)
        if update_from_raw_conf:
            deprecated_params_set = self._get_or_init_deprecated_params_set()
            feeded_deprecated_params_set = (
                self._get_or_init_feeded_deprecated_params_set()
            )
            user_feeded_params_set = self._get_or_init_user_feeded_params_set()
            setattr(self, _IS_RAW_CONF, False)
        else:
            feeded_deprecated_params_set = (
                self._get_or_init_feeded_deprecated_params_set(conf)
            )
            user_feeded_params_set = self._get_or_init_user_feeded_params_set(conf)

        builtin_names = set(dir(builtins))

        def _recursive_update_param(param, config, depth, prefix):
            if depth > settings.PARAM_MAXDEPTH:
                raise ValueError("Param define nesting too deep!!!, can not parse it")

            inst_variables = param.__dict__
            # Never populated: unknown keys are accepted (see below), so the
            # redundancy error after the loop is effectively disabled.
            redundant_attrs = []
            for config_key, config_value in config.items():
                if config_key not in inst_variables:
                    # Unknown key: attach it as a new attribute instead of
                    # rejecting it.
                    setattr(param, config_key, config_value)
                    continue

                full_config_key = f"{prefix}{config_key}"

                if update_from_raw_conf:
                    # Record which keys the user supplied, and which of those
                    # are deprecated.
                    user_feeded_params_set.add(full_config_key)
                    if full_config_key in deprecated_params_set:
                        feeded_deprecated_params_set.add(full_config_key)

                attr = getattr(param, config_key)
                if type(attr).__name__ in builtin_names or attr is None:
                    setattr(param, config_key, config_value)
                else:
                    # Nested parameter object: descend with a dotted prefix.
                    sub_params = _recursive_update_param(
                        attr, config_value, depth + 1, prefix=f"{prefix}{config_key}."
                    )
                    setattr(param, config_key, sub_params)

            if not allow_redundant and redundant_attrs:
                raise ValueError(
                    f"cpn `{getattr(self, '_name', type(self))}` has redundant parameters: `{[redundant_attrs]}`"
                )

            return param

        return _recursive_update_param(param=self, config=conf, depth=0, prefix="")

    def extract_not_builtin(self):
        """Return a nested dict skeleton of the truthy non-builtin attributes."""
        builtin_names = set(dir(builtins))

        def _get_not_builtin_types(obj):
            ret_dict = {}
            for variable in obj.__dict__:
                attr = getattr(obj, variable)
                if attr and type(attr).__name__ not in builtin_names:
                    ret_dict[variable] = _get_not_builtin_types(attr)
            return ret_dict

        return _get_not_builtin_types(self)

    def validate(self):
        """Validate attributes against ``param_validation/<ClassName>.json``.

        The JSON file is looked up next to this module. If it cannot be read
        or parsed, validation is silently skipped.

        Raises:
            ValueError: if a value violates every restriction listed for it.
        """
        # Note: both are stored on the instance, as in the original code.
        self.builtin_types = dir(builtins)
        self.func = {
            "ge": self._greater_equal_than,
            "le": self._less_equal_than,
            "in": self._in,
            "not_in": self._not_in,
            "range": self._range,
        }
        home_dir = os.path.abspath(os.path.dirname(os.path.realpath(__file__)))
        param_validation_path = os.path.join(
            home_dir, "param_validation", type(self).__name__ + ".json"
        )

        try:
            with open(param_validation_path, "r") as fin:
                validation_json = json.loads(fin.read())
        except (OSError, ValueError):
            # Missing/unreadable file or malformed JSON: validation is
            # best-effort. KeyboardInterrupt/SystemExit are no longer swallowed.
            return

        self._validate_param(self, validation_json)

    def _validate_param(self, param_obj, validation_json):
        """Recursively check ``param_obj`` attributes against ``validation_json``.

        A value is accepted as soon as one of its listed operators accepts it.

        Raises:
            ValueError: if no listed operator accepts the value.
        """
        default_section = type(param_obj).__name__
        for variable in param_obj.__dict__:
            attr = getattr(param_obj, variable)

            if type(attr).__name__ in self.builtin_types or attr is None:
                if variable not in validation_json:
                    continue
                # NOTE(review): membership is tested on the top level of the
                # JSON but the lookup goes through the class-name section;
                # confirm against the actual validation file layout.
                validation_dict = validation_json[default_section][variable]
                value_legal = any(
                    self.func[op_type](attr, validation_dict[op_type])
                    for op_type in validation_dict
                )
                if not value_legal:
                    raise ValueError(
                        "Plase check runtime conf, {} = {} does not match user-parameter restriction".format(
                            variable, attr
                        )
                    )
            elif variable in validation_json:
                self._validate_param(attr, validation_json)

    # ------------------------------------------------------------------
    # check_* helpers: raise ValueError (prefixed with ``descr``) on failure.
    # They are static so both ``self.check_x(...)`` and
    # ``ComponentParamBase.check_x(...)`` work.
    # ------------------------------------------------------------------

    @staticmethod
    def check_string(param, descr):
        """Require ``param`` to be exactly of type ``str``."""
        if type(param).__name__ not in ["str"]:
            raise ValueError(
                descr + " {} not supported, should be string type".format(param)
            )

    @staticmethod
    def check_empty(param, descr):
        """Require ``param`` to be truthy."""
        if not param:
            raise ValueError(
                descr + " does not support empty value."
            )

    @staticmethod
    def check_positive_integer(param, descr):
        """Require ``param`` to be an ``int`` (not ``bool``) greater than 0."""
        if type(param).__name__ not in ["int", "long"] or param <= 0:
            raise ValueError(
                descr + " {} not supported, should be positive integer".format(param)
            )

    @staticmethod
    def check_positive_number(param, descr):
        """Require ``param`` to be an ``int`` or ``float`` greater than 0."""
        if type(param).__name__ not in ["float", "int", "long"] or param <= 0:
            raise ValueError(
                descr + " {} not supported, should be positive numeric".format(param)
            )

    @staticmethod
    def check_nonnegative_number(param, descr):
        """Require ``param`` to be an ``int`` or ``float`` that is >= 0."""
        if type(param).__name__ not in ["float", "int", "long"] or param < 0:
            raise ValueError(
                descr
                + " {} not supported, should be non-negative numeric".format(param)
            )

    @staticmethod
    def check_decimal_float(param, descr):
        """Require ``param`` to be an ``int`` or ``float`` within [0, 1]."""
        if type(param).__name__ not in ["float", "int"] or param < 0 or param > 1:
            raise ValueError(
                descr
                + " {} not supported, should be a float number in range [0, 1]".format(
                    param
                )
            )

    @staticmethod
    def check_boolean(param, descr):
        """Require ``param`` to be exactly of type ``bool``."""
        if type(param).__name__ != "bool":
            raise ValueError(
                descr + " {} not supported, should be bool type".format(param)
            )

    @staticmethod
    def check_open_unit_interval(param, descr):
        """Require ``param`` to be a ``float`` strictly between 0 and 1."""
        if type(param).__name__ not in ["float"] or param <= 0 or param >= 1:
            raise ValueError(
                descr + " should be a numeric number between 0 and 1 exclusively"
            )

    @staticmethod
    def check_valid_value(param, descr, valid_values):
        """Require ``param`` to be a member of ``valid_values``."""
        if param not in valid_values:
            raise ValueError(
                descr
                + " {} is not supported, it should be in {}".format(param, valid_values)
            )

    @staticmethod
    def check_defined_type(param, descr, types):
        """Require the type *name* of ``param`` to be listed in ``types``."""
        if type(param).__name__ not in types:
            raise ValueError(
                descr + " {} not supported, should be one of {}".format(param, types)
            )

    @staticmethod
    def check_and_change_lower(param, valid_list, descr=""):
        """Return ``param.lower()`` if it is a ``str`` whose lower-case form
        is in ``valid_list``.

        Raises:
            ValueError: if ``param`` is not a ``str`` or is not in the list.
        """
        if type(param).__name__ != "str":
            raise ValueError(
                descr
                + " {} not supported, should be one of {}".format(param, valid_list)
            )

        lower_param = param.lower()
        if lower_param in valid_list:
            return lower_param
        raise ValueError(
            descr
            + " {} not supported, should be one of {}".format(param, valid_list)
        )

    # ------------------------------------------------------------------
    # Operators used by validate(); static so the dispatch table built there
    # can call them as ``op(value, restriction)``.
    # ------------------------------------------------------------------

    @staticmethod
    def _greater_equal_than(value, limit):
        """``value >= limit`` with a ``settings.FLOAT_ZERO`` tolerance."""
        return value >= limit - settings.FLOAT_ZERO

    @staticmethod
    def _less_equal_than(value, limit):
        """``value <= limit`` with a ``settings.FLOAT_ZERO`` tolerance."""
        return value <= limit + settings.FLOAT_ZERO

    @staticmethod
    def _range(value, ranges):
        """True if ``value`` lies in any ``(left, right)`` pair of ``ranges``
        (inclusive, with a ``settings.FLOAT_ZERO`` tolerance)."""
        return any(
            left_limit - settings.FLOAT_ZERO
            <= value
            <= right_limit + settings.FLOAT_ZERO
            for left_limit, right_limit in ranges
        )

    @staticmethod
    def _in(value, right_value_list):
        """True if ``value`` is in ``right_value_list``."""
        return value in right_value_list

    @staticmethod
    def _not_in(value, wrong_value_list):
        """True if ``value`` is not in ``wrong_value_list``."""
        return value not in wrong_value_list

    def _warn_deprecated_param(self, param_name, descr):
        """Log a warning if the user supplied the deprecated ``param_name``."""
        if self._deprecated_params_set.get(param_name):
            flow_logger.warning(
                f"{descr} {param_name} is deprecated and ignored in this version."
            )

    def _warn_to_deprecate_param(self, param_name, descr, new_param):
        """Log a warning if the user supplied the soon-to-be-deprecated
        ``param_name``.

        Returns:
            True if the warning was emitted, False otherwise.
        """
        if self._deprecated_params_set.get(param_name):
            flow_logger.warning(
                f"{descr} {param_name} will be deprecated in future release; "
                f"please use {new_param} instead."
            )
            return True
        return False
class ComponentBase(ABC):
    """Base class for a runnable component living on a canvas.

    A component holds a reference to its canvas, its own id and a parameter
    object. Its output is stored on the parameter object under the attribute
    named by ``param.output_var_name``.
    """

    component_name: str

    def __str__(self):
        """
        {
            "component_name": "Begin",
            "params": {}
        }
        """
        return """{{
            "component_name": "{}",
            "params": {}
        }}""".format(self.component_name,
                     self._param
                     )

    def __init__(self, canvas, id, param: "ComponentParamBase"):
        """Bind the component to ``canvas`` and validate ``param``.

        ``param.check()`` is invoked immediately, so whatever it raises
        propagates from the constructor.
        """
        self._canvas = canvas
        self._id = id
        self._param = param
        self._param.check()

    def run(self, history, **kwargs):
        """Log the call, execute :meth:`_run` and store its result as output.

        If ``_run`` raises, the error text is stored as the output (a
        one-row DataFrame with a ``content`` column) and the exception is
        re-raised.
        """
        flow_logger.info("{}, history: {}, kwargs: {}".format(self, json.dumps(history, ensure_ascii=False),
                                                              json.dumps(kwargs, ensure_ascii=False)))
        try:
            res = self._run(history, **kwargs)
            self.set_output(res)
        except Exception as e:
            self.set_output(pd.DataFrame([{"content": str(e)}]))
            raise  # bare raise keeps the original traceback intact

        return res

    def _run(self, history, **kwargs):
        """Component logic; subclasses must override."""
        raise NotImplementedError()

    def output(self, allow_partial=True) -> Tuple[str, Union[pd.DataFrame, partial]]:
        """Return ``(output_var_name, output)``.

        Args:
            allow_partial: when True a stored ``functools.partial`` is
                returned as is; when False it is called and iterated, and
                the last yielded item (wrapped in a DataFrame if needed) is
                returned. ``None`` is returned as the value if it yields
                nothing.

        Any stored value that is neither a ``partial`` nor a DataFrame is
        wrapped into a DataFrame first.
        """
        var_name = self._param.output_var_name
        o = getattr(self._param, var_name)
        if not isinstance(o, (partial, pd.DataFrame)):
            if not isinstance(o, list):
                o = [o]
            o = pd.DataFrame(o)

        # From here on `o` is either a partial or a DataFrame.
        if allow_partial or not isinstance(o, partial):
            return var_name, o

        outs = None
        for oo in o():
            if not isinstance(oo, pd.DataFrame):
                outs = pd.DataFrame(oo if isinstance(oo, list) else [oo])
            else:
                outs = oo
        return var_name, outs

    def reset(self):
        """Clear the stored output."""
        setattr(self._param, self._param.output_var_name, None)

    def set_output(self, v: pd.DataFrame):
        """Store ``v`` as this component's output."""
        setattr(self._param, self._param.output_var_name, v)

    def get_input(self):
        """Collect upstream outputs from the last (up to) two canvas path
        segments, newest first, and return them as one DataFrame.

        Rows are de-duplicated on the ``content`` column when it exists.
        An empty DataFrame is returned when nothing was collected.
        """
        upstream_outs = []
        reversed_cpnts = []
        if len(self._canvas.path) > 1:
            reversed_cpnts.extend(self._canvas.path[-2])
        reversed_cpnts.extend(self._canvas.path[-1])

        if DEBUG:
            print(self.component_name, reversed_cpnts[::-1])
        for u in reversed_cpnts[::-1]:
            if self.get_component_name(u) in ["switch"]:
                continue
            # A "generate" component always takes retrieval output, whether
            # or not the retrieval component is a direct upstream.
            if self.component_name.lower() == "generate" and self.get_component_name(u) == "retrieval":
                o = self._canvas.get_component(u)["obj"].output(allow_partial=False)[1]
                if o is not None:
                    upstream_outs.append(o)
                    # NOTE(review): nesting of this `continue` was
                    # reconstructed from flattened source; confirm.
                    continue
            if u not in self._canvas.get_component(self._id)["upstream"]:
                continue
            if self.component_name.lower().find("switch") < 0 \
                    and self.get_component_name(u) in ["relevant", "categorize"]:
                continue
            if u.lower().find("answer") >= 0:
                # Upstream is an answer component: use the latest user
                # message from the canvas history instead of its output.
                for r, c in self._canvas.history[::-1]:
                    if r == "user":
                        upstream_outs.append(pd.DataFrame([{"content": c}]))
                        break
                break
            # NOTE(review): the pairing of this if/else and the position of
            # the final `break` were reconstructed from flattened source;
            # confirm against the original file.
            if self.component_name.lower().find("answer") >= 0:
                if self.get_component_name(u) in ["relevant"]:
                    continue
            else:
                o = self._canvas.get_component(u)["obj"].output(allow_partial=False)[1]
                if o is not None:
                    upstream_outs.append(o)
            break

        if upstream_outs:
            df = pd.concat(upstream_outs, ignore_index=True)
            if "content" in df:
                df = df.drop_duplicates(subset=['content']).reset_index(drop=True)
            return df
        return pd.DataFrame()

    def get_stream_input(self):
        """Return the output of the newest path component that is neither a
        "switch" nor an "answer"; ``None`` if there is no such component."""
        reversed_cpnts = []
        if len(self._canvas.path) > 1:
            reversed_cpnts.extend(self._canvas.path[-2])
        reversed_cpnts.extend(self._canvas.path[-1])

        for u in reversed_cpnts[::-1]:
            if self.get_component_name(u) in ["switch", "answer"]:
                continue
            return self._canvas.get_component(u)["obj"].output()[1]

    @staticmethod
    def be_output(v):
        """Wrap ``v`` into a one-row DataFrame with a ``content`` column."""
        return pd.DataFrame([{"content": v}])

    def get_component_name(self, cpn_id):
        """Return the lower-cased ``component_name`` of component ``cpn_id``."""
        return self._canvas.get_component(cpn_id)["obj"].component_name.lower()