ToolPlanner

Paper Link

ToolPlanner: A Tool Augmented LLM for Multi Granularity Instructions with Path Planning and Feedback


Requirements

accelerate==0.24.0
datasets==2.13.0
deepspeed==0.9.2
Flask==1.1.2
Flask_Cors==4.0.0
huggingface_hub==0.16.4
jsonlines==3.1.0
nltk==3.7
numpy==1.24.3
openai==0.27.7
pandas==2.0.3
peft==0.6.0.dev0
psutil==5.8.0
pydantic==1.10.8
pygraphviz==1.11
PyYAML==6.0.1
Requests==2.31.0
scikit_learn==1.0.2
scipy==1.11.4
sentence_transformers==2.2.2
tenacity==8.2.3
termcolor==2.4.0
torch==2.0.1
tqdm==4.65.0
transformers==4.28.1
trl==0.7.3.dev0
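
Because every dependency is pinned, a quick check before installing can catch a package that is listed more than once with different versions. The following is a minimal sketch, not part of the repository: the function name find_conflicting_pins is hypothetical, and it only understands simple name==version lines.

```python
from collections import defaultdict


def find_conflicting_pins(lines):
    """Return {package: sorted versions} for packages pinned to more than one version."""
    pins = defaultdict(set)
    for raw in lines:
        line = raw.strip()
        if not line or line.startswith("#") or "==" not in line:
            continue  # skip blanks, comments, and entries without an exact pin
        name, version = line.split("==", 1)
        # Package names are case-insensitive and treat "-" and "_" alike.
        pins[name.strip().lower().replace("_", "-")].add(version.strip())
    return {pkg: sorted(vs) for pkg, vs in pins.items() if len(vs) > 1}


sample = ["numpy==1.24.3", "Flask==1.1.2", "flask==1.1.2", "tqdm==4.65.0", "tqdm==4.66.0"]
print(find_conflicting_pins(sample))  # {'tqdm': ['4.65.0', '4.66.0']}
```

To check the real file, pass an open file handle, for example find_conflicting_pins(open("requirements.txt")).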

Data

| path | data description |
| --- | --- |
| /data/category/dataset | MGToolBench: pairwise_responses |
| /data/category/answer | MGToolBench: Multi-Level Instruction Split |
| /data/category/coarse_instruction | Self-Instruct Data: multi-granularity instructions |
| /data/test_sample | Test Sample: test dataset |
| /data/category/toolenv | Tool Environment: Tools, APIs, and their documentation. |
| /data/category/inference | Output: solution trees path |
| /data/category/converted_answer | Output: converted_answer path |
| /data/category/retrieval/G3_category | Supplementary: Category & Tool & API Name |
| /data/retrieval/G3_clear | Supplementary: corpus for separate retriever |

Download Data and Checkpoints

Download the following data and checkpoints, then unzip the archives.

| path | data description | data name | url |
| --- | --- | --- | --- |
| /data/category/dataset | MGToolBench: pairwise_responses | G3_1107_gensample_Reward_pair.json | https://huggingface.co/datasets/wuqinzhuo/ToolPlanner |
| /data/category/toolenv | Tool Environment: Tools, APIs, and their documentation. | toolenv.zip | https://huggingface.co/datasets/wuqinzhuo/ToolPlanner |
| /data/category/inference | Output: solution trees path | inference.zip | https://huggingface.co/datasets/wuqinzhuo/ToolPlanner |

| path | model description | model name | url |
| --- | --- | --- | --- |
| [ToolPlanner root path] | Stage 1 SFT model | ToolPlanner_Stage1_1020 | https://huggingface.co/wuqinzhuo/ToolPlanner_Stage1_1020 |
| [ToolPlanner root path] | Stage 2 RL model | ToolPlanner_Stage2_1107 | https://huggingface.co/wuqinzhuo/ToolPlanner_Stage2_1107/ |
| [ToolPlanner root path] | Baseline ToolLLaMA | ToolLLaMA-7b | https://github.com/OpenBMB/ToolBench |
| [ToolPlanner root path] | Retrieval model for test, using MGToolBench data | model_1122_G3_tag_trace_multilevel | https://huggingface.co/wuqinzhuo/model_1122_G3_tag_trace_multilevel |
| [ToolPlanner root path] | Retrieval model for test, using ToolBench data | retriever_model_G3_clear | https://huggingface.co/wuqinzhuo/retriever_model_G3_clear |
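
The Hugging Face repositories listed above can also be fetched programmatically. The sketch below is one possible approach using snapshot_download from huggingface_hub (already pinned in the requirements). The helper names and the hf_downloads/ToolPlanner staging directory are assumptions, and the internal layout of the dataset repository is not described here, so the downloaded files still have to be unzipped and moved to the paths listed above by hand. ToolLLaMA-7b is distributed through the ToolBench GitHub project and is not covered.

```python
def planned_downloads():
    """List (repo_id, repo_type, local_dir) for each Hugging Face repo named above."""
    return [
        # Staging directory is an assumption; move/unzip files into data/category/... afterwards.
        ("wuqinzhuo/ToolPlanner", "dataset", "hf_downloads/ToolPlanner"),
        # Checkpoints are placed directly under the ToolPlanner root path.
        ("wuqinzhuo/ToolPlanner_Stage1_1020", "model", "ToolPlanner_Stage1_1020"),
        ("wuqinzhuo/ToolPlanner_Stage2_1107", "model", "ToolPlanner_Stage2_1107"),
        ("wuqinzhuo/model_1122_G3_tag_trace_multilevel", "model",
         "model_1122_G3_tag_trace_multilevel"),
        ("wuqinzhuo/retriever_model_G3_clear", "model", "retriever_model_G3_clear"),
    ]


def download_all():
    """Download every planned repository (requires network access)."""
    # Imported lazily so the plan can be inspected without huggingface_hub installed.
    from huggingface_hub import snapshot_download

    for repo_id, repo_type, local_dir in planned_downloads():
        snapshot_download(repo_id=repo_id, repo_type=repo_type, local_dir=local_dir)
```

Calling download_all() from the ToolPlanner root path starts the downloads; planned_downloads() alone only lists what would be fetched.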

Model

Install

pip install -r requirements.txt

Train ToolPlanner, Stage 1 SFT

Script

bash scripts/category/train_model_1020_stage1.sh 

Code

export PYTHONPATH=./
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7

torchrun --nproc_per_node=8 --master_port=20001 toolbench/train/train_long_seq.py \
    --model_name_or_path ToolLLaMA-7b  \
    --data_path  data/category/answer/G3_plan_gen_train_1020_G3_3tag_whole_prefixTagTraceAll.json \
    --eval_data_path  data/category/answer/G3_plan_gen_eval_1020_G3_3tag_whole_prefixTagTraceAll.json \
    --conv_template tool-llama-single-round \
    --bf16 True \
    --output_dir ToolPlanner_Stage1 \
    --num_train_epochs 2 \
    --per_device_train_batch_size 2 \
    --per_device_eval_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --evaluation_strategy "epoch" \
    --prediction_loss_only \
    --save_strategy "epoch" \
    --save_total_limit 8 \
    --learning_rate 5e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.04 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --fsdp "full_shard auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
    --tf32 True \
    --model_max_length 8192 \
    --gradient_checkpointing True \
    --lazy_preprocess True \
    --report_to none
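
The batch-related flags above interact: the number of sequences behind each optimizer step is the product of the process count, the per-device batch size, and the gradient-accumulation steps. The short sketch below works through that arithmetic for the Stage 1 settings; the helper name is hypothetical.

```python
def effective_batch_size(num_gpus, per_device_batch, grad_accum_steps):
    """Sequences contributing to one optimizer step under data-parallel training."""
    return num_gpus * per_device_batch * grad_accum_steps


# Stage 1: --nproc_per_node=8, --per_device_train_batch_size 2,
# --gradient_accumulation_steps 8
stage1 = effective_batch_size(num_gpus=8, per_device_batch=2, grad_accum_steps=8)
print(stage1)  # 128
```

When training on fewer GPUs, raising --gradient_accumulation_steps by the same factor keeps this product unchanged.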

Train ToolPlanner, Stage 2 Reinforcement Learning

Script

bash scripts/category/train_model_1107_stage2.sh 

Code

export PYTHONPATH=./
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7

export MODEL_PATH="ToolPlanner_Stage1_1020"
export SAVE_PATH="ToolPlanner_Stage2"
export DATA_PATH="data/category/dataset/G3_1107_gensample_Reward_pair.json"
export MASTER_ADDR="localhost"
export MASTER_PORT="20010"
export WANDB_DISABLED=true
wandb offline

torchrun --nproc_per_node=8 --master_port=20001 toolbench/train/train_long_seq_RRHF.py \
    --model_name_or_path $MODEL_PATH \
    --data_path $DATA_PATH \
    --bf16 True \
    --output_dir $SAVE_PATH \
    --num_train_epochs 3 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 8 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 100 \
    --save_total_limit 3 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --fsdp "full_shard auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
    --gradient_checkpointing True \
    --tf32 True \
    --model_max_length 8192 \
    --rrhf_weight 1
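
The script name and the --rrhf_weight flag suggest an RRHF-style objective over the pairwise responses in DATA_PATH. The sketch below follows the published RRHF formulation (a ranking penalty plus the SFT loss on the best response). It is an assumption that train_long_seq_RRHF.py combines the two terms this way, and the function and argument names are illustrative only.

```python
def rrhf_objective(avg_logprobs, rewards, sft_nll, rrhf_weight=1.0):
    """Combine an RRHF-style ranking penalty with the SFT loss on the best response.

    avg_logprobs: length-normalised log-probability the model assigns to each candidate.
    rewards:      reward score of each candidate (higher is better).
    sft_nll:      negative log-likelihood of the highest-reward candidate.
    """
    rank_loss = 0.0
    for i in range(len(rewards)):
        for j in range(len(rewards)):
            # Penalise each pair where the lower-reward candidate i is scored
            # above the higher-reward candidate j by the model.
            if rewards[i] < rewards[j] and avg_logprobs[i] > avg_logprobs[j]:
                rank_loss += avg_logprobs[i] - avg_logprobs[j]
    return rrhf_weight * rank_loss + sft_nll


# The model prefers the worse response (index 0), so the pair is penalised by 0.5.
print(rrhf_objective([-0.5, -1.0], rewards=[0.0, 1.0], sft_nll=2.0))  # 2.5
```

Under this reading, --rrhf_weight 1 gives the ranking term the same weight as the SFT term.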

Inference, Generate Solution Tree

Script

bash scripts/category/inference/inference_cuda_model_method_output_input_tag.sh <GPU_Id> <model_name> <method_name> <decode_method> <output_path> <test_sample> <retriever_path> <TOOLBENCH_KEY>

ToolBench Key

Go to ToolBench to apply for a ToolBench Key.

Decode_Method

| Model | Method |
| --- | --- |
| Full Model | Mix_Whole3Tag_MixWhole3TagTrace_3TagRepla_PureRepla_MixWhole3Retri_MixWhole3TagTraceGen_DFS_woFilter_w2 |
| Separate Retriever | Mix_Whole3Tag_MixWhole3TagTrace_MixWhole3Retri_MixWhole3TagTraceGen_DFS_woFilter_w2 |
| Without Solution Planning | Mix_Whole3Tag_MixWhole3TagTrace_MixWhole3Retri_MixWhole3Gen_DFS_woFilter_w2 |
| Without Tag Extraction | Mix_Whole3Tag_MixWhole3TagTrace_MixTagTraceRetri_MixTagTraceGen_DFS_woFilter_w2 |
| Without Tag & Solution | Mix_Whole3Tag_MixWhole3TagTrace_MixRetri_MixGen_DFS_woFilter_w2 |
| Chain-based Method | Mix_Whole3Tag_MixWhole3TagTrace_3TagRepla_PureRepla_MixWhole3Retri_MixWhole3TagTraceGen_CoT@5 |

Example

bash scripts/category/inference/inference_cuda_model_method_output_input_tag.sh 6,7 ToolPlanner_Stage2_1107 Mix_Whole3Tag_MixWhole3TagTrace_3TagRepla_PureRepla_MixWhole3Retri_MixWhole3TagTraceGen_DFS_woFilter_w2 data/category/inference/plan_1107_G3_gensample_RRHF_Desc_1122_level_23 data/test_sample/G3_query_100_opendomain.json model_1122_G3_tag_trace_multilevel TOOLBENCH_KEY

bash scripts/category/inference/inference_cuda_model_method_output_input_tag.sh 1,3 ToolPlanner_Stage2_1107 Mix_Whole3Tag_MixWhole3TagTrace_3TagRepla_PureRepla_MixWhole3Retri_MixWhole3TagTraceGen_DFS_woFilter_w2 data/category/inference/plan_1107_G3_gensample_RRHF_Cate_1122_level_23 data/test_sample/G3_query_100_level_cate.json model_1122_G3_tag_trace_multilevel TOOLBENCH_KEY
bash scripts/category/inference/inference_cuda_model_method_output_input_tag.sh 2,4 ToolPlanner_Stage2_1107 Mix_Whole3Tag_MixWhole3TagTrace_3TagRepla_PureRepla_MixWhole3Retri_MixWhole3TagTraceGen_DFS_woFilter_w2 data/category/inference/plan_1107_G3_gensample_RRHF_Tool_1122_level_23 data/test_sample/G3_query_100_level_tool.json model_1122_G3_tag_trace_multilevel TOOLBENCH_KEY
bash scripts/category/inference/inference_cuda_model_method_output_input_tag.sh 5,4 ToolPlanner_Stage2_1107 Mix_Whole3Tag_MixWhole3TagTrace_3TagRepla_PureRepla_MixWhole3Retri_MixWhole3TagTraceGen_DFS_woFilter_w2 data/category/inference/plan_1107_G3_gensample_RRHF_API_1122_level_23 data/test_sample/G3_query_100_level_api.json model_1122_G3_tag_trace_multilevel TOOLBENCH_KEY
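
The four example commands differ only in the GPU ids, the output directory, and the test-sample file. As a convenience, the sketch below regenerates them from a small table. The argument order is copied from the examples above; the helper name and its defaults are hypothetical, and it only builds the command strings without running anything.

```python
import shlex

SCRIPT = "scripts/category/inference/inference_cuda_model_method_output_input_tag.sh"
FULL_MODEL_METHOD = (
    "Mix_Whole3Tag_MixWhole3TagTrace_3TagRepla_PureRepla_"
    "MixWhole3Retri_MixWhole3TagTraceGen_DFS_woFilter_w2"
)
# (output-directory tag, test-sample file) pairs taken from the examples above.
TEST_SETS = [
    ("Desc", "G3_query_100_opendomain.json"),
    ("Cate", "G3_query_100_level_cate.json"),
    ("Tool", "G3_query_100_level_tool.json"),
    ("API", "G3_query_100_level_api.json"),
]


def inference_commands(gpu_ids, toolbench_key, model="ToolPlanner_Stage2_1107",
                       retriever="model_1122_G3_tag_trace_multilevel"):
    """Return one shell command per test set, in the order of TEST_SETS."""
    commands = []
    for tag, sample in TEST_SETS:
        output = f"data/category/inference/plan_1107_G3_gensample_RRHF_{tag}_1122_level_23"
        args = ["bash", SCRIPT, gpu_ids, model, FULL_MODEL_METHOD, output,
                f"data/test_sample/{sample}", retriever, toolbench_key]
        # Quote each argument so keys with special characters survive the shell.
        commands.append(" ".join(shlex.quote(a) for a in args))
    return commands


for cmd in inference_commands("6,7", "TOOLBENCH_KEY"):
    print(cmd)
```

Swapping FULL_MODEL_METHOD for another entry from the Decode_Method table produces the commands for the corresponding ablation.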

Eval

Script

Use the generated results to evaluate Match Rate and Pass Rate.

bash scripts/category/eval/eval_match_pass_rate.sh api name2 <output_path>

Example

bash scripts/category/eval/eval_match_pass_rate.sh api name2 data/category/inference/plan_1107_G3_gensample_RRHF_Cate_1122_level_23
bash scripts/category/eval/eval_match_pass_rate.sh api name2 data/category/inference/plan_1107_G3_gensample_RRHF_Tool_1122_level_23
bash scripts/category/eval/eval_match_pass_rate.sh api name2 data/category/inference/plan_1107_G3_gensample_RRHF_API_1122_level_23
bash scripts/category/eval/eval_match_pass_rate.sh api name2 data/category/inference/plan_1107_G3_gensample_RRHF_Desc_1122_level_23

Script

Use the generated results to evaluate Win Rate.

Modify the generate(prompt, name) function in "ToolPlanner/toolbench/tooleval/new_eval_win_rate_cut_list.py" so that it calls your own ChatGPT API.
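
As a starting point, a replacement generate function could look like the sketch below, written against the openai==0.27.7 interface pinned in the requirements. Several things here are assumptions rather than facts about the evaluation script: that the caller expects the reply text as a string, that the name argument can be ignored, and the chat_fn and model parameters, which are added only to make the function easy to test.

```python
def generate(prompt, name, chat_fn=None, model="gpt-3.5-turbo"):
    """Send one prompt to a chat model and return the reply text.

    `name` is accepted only to match the generate(prompt, name) signature; its
    meaning in the evaluation script is not documented here, so this sketch
    ignores it. `chat_fn` and `model` are illustrative additions.
    """
    if chat_fn is None:
        # openai==0.27.7 interface; the key is read from OPENAI_API_KEY.
        import openai
        chat_fn = openai.ChatCompletion.create
    response = chat_fn(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        temperature=0,
    )
    return response["choices"][0]["message"]["content"]
```

Because the extra parameters have defaults, existing calls of the form generate(prompt, name) keep working unchanged.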

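bash scripts/inference/convert_preprocess_win_rate.sh DFS <cate_output_path> <cate_converted_path> <tool_output_path> <tool_converted_path> <api_output_path> <api_converted_path> <desc_output_path> <desc_converted_path>
bash scripts/inference/eval_win_rate_cut_list.sh <converted_path>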

Example

bash scripts/inference/convert_preprocess_win_rate.sh DFS data/category/inference/plan_1107_G3_gensample_RRHF_Cate_1122_level_23 data/category/converted_answer/plan_1107_G3_gensample_RRHF_Cate_1122_level_23.json data/category/inference/plan_1107_G3_gensample_RRHF_Tool_1122_level_23 data/category/converted_answer/plan_1107_G3_gensample_RRHF_Tool_1122_level_23.json data/category/inference/plan_1107_G3_gensample_RRHF_API_1122_level_23 data/category/converted_answer/plan_1107_G3_gensample_RRHF_API_1122_level_23.json data/category/inference/plan_1107_G3_gensample_RRHF_Desc_1122_level_23 data/category/converted_answer/plan_1107_G3_gensample_RRHF_Desc_1122_level_23.json
bash scripts/inference/eval_win_rate_cut_list.sh data/category/converted_answer/plan_1107_G3_gensample_RRHF_Cate_1122_level_23.json
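
The conversion command above is long because it passes one (solution-tree directory, converted-answer file) pair per instruction level. The sketch below rebuilds the same argument list from the level names, which makes the pairing easier to see. The helper name is hypothetical, and treating the first argument (DFS) as a method selector is an assumption.

```python
LEVELS = ["Cate", "Tool", "API", "Desc"]


def convert_command(levels=LEVELS, method="DFS"):
    """Build the convert_preprocess_win_rate.sh call with one (input, output) pair per level."""
    args = ["bash", "scripts/inference/convert_preprocess_win_rate.sh", method]
    for level in levels:
        run = f"plan_1107_G3_gensample_RRHF_{level}_1122_level_23"
        args.append(f"data/category/inference/{run}")              # solution-tree directory
        args.append(f"data/category/converted_answer/{run}.json")  # converted answer file
    return " ".join(args)


print(convert_command())
```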

Citation

@misc{wu2024toolplannertoolaugmentedllm,
      title={ToolPlanner: A Tool Augmented LLM for Multi Granularity Instructions with Path Planning and Feedback}, 
      author={Qinzhuo Wu and Wei Liu and Jian Luan and Bin Wang},
      year={2024},
      eprint={2409.14826},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2409.14826}, 
}

License

The dataset of this project is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0) license.

The source code of this project is licensed under the Apache 2.0 license.

Summary of Terms

  • Attribution: You must give appropriate credit, provide a link to the license, and indicate if changes were made.
  • NonCommercial: You may not use the material for commercial purposes.
  • ShareAlike: If you remix, transform, or build upon the material, you must distribute your contributions under the same license as the original.

Copyright

Copyright (C) 2024 Xiaomi Corporation. The dataset included in this project is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0) license. The source code included in this project is licensed under the Apache 2.0 license.

Citation (EMNLP 2024)

If you use our benchmark or cite this paper, please use the reference below:

@inproceedings{wu2024toolplanner,
  title={ToolPlanner: A Tool Augmented LLM for Multi Granularity Instructions with Path Planning and Feedback},
  author={Wu, Qinzhuo and Liu, Wei and Luan, Jian and Wang, Bin},
  booktitle={Proceedings of the 2024 Conference on Empirical Methods in Natural Language Processing},
  pages={18315--18339},
  year={2024}
}