# my_custom_olmoe/configuration_custom.py | |
# 注意:根据你的 transformers 版本,导入官方 OLMoE 配置的路径可能需要调整 | |
from transformers.models.olmoe.configuration_olmoe import OlmoeConfig | |
class DenseBackwardOLMoEConfig(OlmoeConfig):
    """OLMoE configuration for the custom "DenseBackward" model variant.

    Extends ``OlmoeConfig`` with a ``model_marker`` field and changes two
    defaults: ``intermediate_size`` (1024) and ``torch_dtype`` ("bfloat16").
    Both are ordinary keyword arguments. Values supplied by the caller, or
    passed back in when a saved config is rebuilt from a dict, are respected
    instead of being overwritten.
    """

    # Override model_type so this config can be told apart from the stock OLMoE one.
    model_type = "DenseBackward_olmoe"
    # auto_map enables AutoClass support: it maps the Auto* entry points to the
    # custom configuration / modeling classes of this model.
    auto_map = {
        "AutoConfig": "configuration_custom.DenseBackwardOLMoEConfig",
        "AutoModelForCausalLM": "modeling_custom.DenseBackwardOLMoEForCausalLM",
    }

    def __init__(
        self,
        model_marker="DenseBackward_olmoe_marker",
        *,
        intermediate_size=1024,
        torch_dtype="bfloat16",
        **kwargs,
    ):
        """Create the config.

        Args:
            model_marker: Tag stored as ``self.model_marker``.
            intermediate_size: Forwarded to ``OlmoeConfig``. Defaults to 1024.
                It used to be hard-coded after the parent initializer, which
                ignored any caller-supplied value.
            torch_dtype: Stored as ``self.torch_dtype``. Defaults to
                ``"bfloat16"``.
            **kwargs: Forwarded unchanged to ``OlmoeConfig.__init__``.
        """
        # Let the parent store intermediate_size through its own parameter,
        # so an explicit or deserialized value is kept.
        super().__init__(intermediate_size=intermediate_size, **kwargs)
        self.model_marker = model_marker
        # Assigned after the parent initializer so this value takes precedence
        # over whatever the parent set for torch_dtype.
        self.torch_dtype = torch_dtype
# Quick manual smoke test: build a config and print it.
def main():
    """Build a sample DenseBackwardOLMoEConfig and print it."""
    sample_kwargs = {
        "model_marker": "DenseBackward_olmoe_marker",
        "torch_dtype": "bfloat16",
    }
    sample_config = DenseBackwardOLMoEConfig(**sample_kwargs)
    print(sample_config)


if __name__ == "__main__":
    main()