Issues running example_fsdp.py

#85
by TheFishInTheAir - opened

I've been trying to get the provided FSDP example working on a TPU v5e-8 VM. I'm running into what I believe are version conflicts, and I was wondering whether someone could detail the environment the test was originally run in (specifically, which versions of which libraries were used).

For FSDP it seems like Transformers uses features in torch_xla that aren't available in the current stable release (2.2.0 at the time of writing).

AttributeError: module 'torch_xla.distributed.spmd' has no attribute 'set_global_mesh'

set_global_mesh is available in torch_xla 2.3.0 and later, but I haven't had any luck getting those versions or the nightly builds working either.
With newer unstable versions the following check is failing:

    XLA_CHECK(dim1 == dim2 || dim1 == 1 || dim2 == 1 ||
              dim1 == xla::Shape::kUnboundedSize ||
              dim2 == xla::Shape::kUnboundedSize);

Of course, the errors themselves are not related to Gemma but to pytorch_xla and Transformers. I have included them to make clear what I am experiencing.
If someone could explain how to set up a working environment for the example script, it would be much appreciated! Thank you!

Current versions installed for reference:

torch                    2.2.2
torch-xla                2.2.0
libtpu-nightly           0.1.dev20231130+default
transformers             4.39.3
trl                      0.8.1
accelerate               0.29.1
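
A report like the one above can be gathered in one step, together with a probe for the attribute named in the AttributeError. This is only a minimal sketch: the helper names `installed_versions` and `has_attribute` are hypothetical, the package list simply mirrors the one above, and the probe treats a module that cannot be imported as not having the attribute.

```python
import importlib
from importlib import metadata

# Distribution names taken from the version list in this post.
PACKAGES = ["torch", "torch-xla", "libtpu-nightly", "transformers", "trl", "accelerate"]


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def has_attribute(module_name, attribute):
    """Return True only if the module imports and exposes the attribute."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        # Module (or one of its parents) is not importable in this environment.
        return False
    return hasattr(module, attribute)


if __name__ == "__main__":
    for name, version in installed_versions(PACKAGES).items():
        print(f"{name:<24} {version or 'not installed'}")
    print(
        "set_global_mesh available:",
        has_attribute("torch_xla.distributed.spmd", "set_global_mesh"),
    )
```

Running it before and after switching torch-xla builds shows whether the installed build actually exposes set_global_mesh.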

For now you will need to install torch and torch-xla nightly. However, all the features will be available in the upcoming 2.3 release.

Unfortunately, I'm still running into issues when using the nightly builds. There seem to be compatibility issues between newer versions of torch-xla and transformers (see https://github.com/huggingface/transformers/issues/30091). Are there versions that are known to work?

TheFishInTheAir changed discussion status to closed

My colleague and I built an example using Llama2 7B that should work as well: https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/official/training/tpuv5e_llama2_pytorch_finetuning_and_serving.ipynb

Here are the versions we used to get around the accelerate issue:
RUN pip install --upgrade pip
RUN pip install transformers==4.38.2 -U
RUN pip install datasets==2.18.0
RUN pip install trl==0.8.1 peft==0.10.0
RUN pip install accelerate==0.28.0
RUN pip install --upgrade google-cloud-storage
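
To confirm that an existing environment matches these pins, a comparison along the following lines may help. It is a sketch under assumptions: `find_mismatches` is a hypothetical helper, the pins are copied from the lines above, and versions are compared as exact strings rather than parsed.

```python
from importlib import metadata

# Pins copied from the install lines above.
PINS = {
    "transformers": "4.38.2",
    "datasets": "2.18.0",
    "trl": "0.8.1",
    "peft": "0.10.0",
    "accelerate": "0.28.0",
}


def find_mismatches(pins, lookup=metadata.version):
    """Return {name: (expected, found)} for packages that differ from their pin."""
    mismatches = {}
    for name, expected in pins.items():
        try:
            found = lookup(name)
        except metadata.PackageNotFoundError:
            found = None  # package is not installed at all
        if found != expected:
            mismatches[name] = (expected, found)
    return mismatches


if __name__ == "__main__":
    for name, (expected, found) in find_mismatches(PINS).items():
        print(f"{name}: expected {expected}, found {found or 'not installed'}")
```

No output means every pinned package is installed at the listed version.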
