Issue while finetuning embedding model because of use_reentrant = True
Hey,
as discussed in the other thread (which I closed), you recommended using gradient_checkpointing on "smaller hardware". I had tried that before and it failed with several errors; the first one was:
venv/lib/python3.10/site-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.
Later on, the script finally failed with:
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons: 1) Use of a module parameter outside the `forward` function. Please make sure model parameters are not shared across multiple concurrent forward-backward passes. or try to use _set_static_graph() as a workaround if this module graph does not change during training loop.2) Reused parameters in multiple reentrant backward passes. For example, if you use multiple `checkpoint` functions to wrap the same part of your model, it would result in the same set of parameters been used by different reentrant backward passes multiple times, and hence marking a variable ready multiple times. DDP does not support such use cases in default. You can try to use _set_static_graph() as a workaround if your module graph does not change over iterations.
Parameter at index 387 with name model.encoder.layer.23.output.LayerNorm.weight has been marked as ready twice. This means that multiple autograd engine hooks have fired for this particular parameter during this iteration.
My command looks like this (default code):
#!/bin/bash
NCCL_P2P_DISABLE="0" NCCL_IB_DISABLE="0" ./venv/bin/torchrun --nproc_per_node 2 \
-m FlagEmbedding.baai_general_embedding.finetune.run \
--output_dir "./embedding-models/bge-m3-fda-full" \
--model_name_or_path "BAAI/bge-m3" \
--train_data "./training-data/good-for-fda.jsonl" \
--learning_rate 5e-6 \
--num_train_epochs 4 \
--per_device_train_batch_size 6 \
--dataloader_drop_last True \
--bf16 True \
--normlized True \
--warmup_steps 0 \
--gradient_checkpointing True \
--temperature 0.01 \
--query_max_len 1024 \
--seed 13317 \
--passage_max_len 1280 \
--train_group_size 6 \
--save_strategy "epoch" \
--save_total_limit 2 \
--negatives_cross_device \
--logging_steps 25 \
--query_instruction_for_retrieval ""
This error only happens when --gradient_checkpointing True is set.
So obviously I patched the relevant code and set use_reentrant=False. Now it works.
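For reference, the gist of the patch looks roughly like this (a minimal sketch against a recent transformers version that accepts gradient_checkpointing_kwargs; it is not the exact FlagEmbedding code I touched):
"""
# Minimal sketch of the workaround, assuming the underlying model is a
# Hugging Face PreTrainedModel and a transformers version that supports
# gradient_checkpointing_kwargs (not the exact FlagEmbedding code).
from transformers import AutoModel

model = AutoModel.from_pretrained("BAAI/bge-m3")

# Enable checkpointing with the non-reentrant variant instead of the
# reentrant default, which is what conflicts with DDP here.
model.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs={"use_reentrant": False}
)
"""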
Nevertheless, you should probably know about this and allow setting that flag via the CLI (as far as I saw, that is not currently possible).
Best,
Damian
Sorry for the late reply, our holiday has just ended.
Directly enabling gradient_checkpointing leads to this issue. We resolve it by using DeepSpeed, so DeepSpeed must be used whenever gradient_checkpointing is enabled:
"""
NCCL_P2P_DISABLE="0" NCCL_IB_DISABLE="0" ./venv/bin/torchrun --nproc_per_node 2 \
-m FlagEmbedding.baai_general_embedding.finetune.run \
--output_dir "./embedding-models/bge-m3-fda-full" \
--model_name_or_path "BAAI/bge-m3" \
--train_data "./training-data/good-for-fda.jsonl" \
--learning_rate 5e-6 \
--num_train_epochs 4 \
--per_device_train_batch_size 6 \
--dataloader_drop_last True \
--bf16 True \
--normlized True \
--warmup_steps 0 \
--gradient_checkpointing True \
--temperature 0.01 \
--query_max_len 1024 \
--seed 13317 \
--passage_max_len 1280 \
--train_group_size 6 \
--save_strategy "epoch" \
--save_total_limit 2 \
--negatives_cross_device \
--logging_steps 25 \
--query_instruction_for_retrieval "" \
--deepspeed "./ds_config.json"
"""
You can use the ds_config.json we provided: https://github.com/FlagOpen/FlagEmbedding/blob/master/examples/finetune/ds_config.json.
If you have an alternative solution, you are welcome to submit a PR.
Ah,
the fact that DeepSpeed has to be used wasn't documented! Thanks for that info.
If I have some spare time, I am happy to contribute an additional flag "--deactivate_reentrant", which would also solve the issue without any patching.
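Roughly, I imagine the flag being wired up like this (just a sketch; the class and helper names are hypothetical, and it assumes a transformers version that has gradient_checkpointing_kwargs, not the actual FlagEmbedding argument classes):
"""
# Hypothetical sketch of how a --deactivate_reentrant flag could be wired in;
# the dataclass and helper names are illustrative, not real FlagEmbedding code.
from dataclasses import dataclass, field
from transformers import TrainingArguments

@dataclass
class MyTrainingArguments(TrainingArguments):
    deactivate_reentrant: bool = field(
        default=False,
        metadata={"help": "Use use_reentrant=False for gradient checkpointing."},
    )

def apply_reentrant_flag(args: MyTrainingArguments) -> MyTrainingArguments:
    # Translate the flag into the kwargs that recent transformers versions
    # already pass through to gradient checkpointing.
    if args.deactivate_reentrant:
        args.gradient_checkpointing_kwargs = {"use_reentrant": False}
    return args
"""
That way the non-reentrant variant could be selected from the command line without touching the library code.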