DrFetWartz's picture
Upload folder using huggingface_hub
ffaa9fc
#!/usr/bin/env python
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import io
import torch
from torch.utils.data import DataLoader
from accelerate import Accelerator
from accelerate.data_loader import prepare_data_loader
from accelerate.state import AcceleratorState
from accelerate.test_utils import RegressionDataset, RegressionModel, are_the_same_tensors
from accelerate.utils import (
DistributedType,
gather,
is_bf16_available,
is_torch_version,
set_seed,
synchronize_rng_states,
)
def print_main(state):
    """Print a marker line tagged with the process index of ``state``.

    Used in this file as the payload wrapped by ``Accelerator.on_main_process``.

    Args:
        state: Any object exposing a ``process_index`` attribute.
    """
    print(f"Printing from the main process {state.process_index}")
def print_local_main(state):
    """Print a marker line tagged with the local process index of ``state``.

    Used in this file as the payload wrapped by
    ``Accelerator.on_local_main_process``.

    Args:
        state: Any object exposing a ``local_process_index`` attribute.
    """
    print(f"Printing from the local main process {state.local_process_index}")
def print_last(state):
    """Print a marker line tagged with the process index of ``state``.

    Used in this file as the payload wrapped by ``Accelerator.on_last_process``.

    Args:
        state: Any object exposing a ``process_index`` attribute.
    """
    print(f"Printing from the last process {state.process_index}")
def print_on(state, process_idx):
    """Print the requested process index next to the index actually running.

    Used in this file as the payload wrapped by ``Accelerator.on_process``.

    Args:
        state: Any object exposing a ``process_index`` attribute.
        process_idx: The index the caller expected the function to run on.
    """
    print(f"Printing from process {process_idx}: {state.process_index}")
def _captured_stdout(func, *args):
    """Call ``func(*args)`` and return what it wrote to stdout, right-stripped.

    A fresh buffer is used for every call so no output can leak from one
    capture into the next.
    """
    buffer = io.StringIO()
    with contextlib.redirect_stdout(buffer):
        func(*args)
    return buffer.getvalue().rstrip()


def process_execution_check():
    """Check ``main_process_first`` and the process-restricting decorators.

    First gathers every process index and, when more than one process is
    running, asserts that the first gathered entry is 0. Then wraps the
    ``print_*`` helpers with ``on_main_process``, ``on_local_main_process``,
    ``on_last_process`` and ``on_process`` and asserts that each one prints its
    marker line on the targeted process and prints nothing elsewhere.

    Raises:
        AssertionError: If any of the checks above fails.
    """
    accelerator = Accelerator()
    num_processes = accelerator.num_processes
    with accelerator.main_process_first():
        idx = torch.tensor(accelerator.process_index).to(accelerator.device)
    # The gather is issued only after leaving the `main_process_first` context.
    idxs = accelerator.gather(idx)
    if num_processes > 1:
        assert idxs[0] == 0, "Main process was not first."

    # Test the decorators
    result = _captured_stdout(accelerator.on_main_process(print_main), accelerator.state)
    if accelerator.is_main_process:
        assert result == "Printing from the main process 0", f"{result} != Printing from the main process 0"
    else:
        assert result == "", f'{result} != ""'

    result = _captured_stdout(accelerator.on_local_main_process(print_local_main), accelerator.state)
    if accelerator.is_local_main_process:
        expected = "Printing from the local main process 0"
        assert result == expected, f"{result} != {expected}"
    else:
        assert result == "", f'{result} != ""'

    result = _captured_stdout(accelerator.on_last_process(print_last), accelerator.state)
    if accelerator.is_last_process:
        expected = f"Printing from the last process {accelerator.state.num_processes - 1}"
        assert result == expected, f"{result} != {expected}"
    else:
        assert result == "", f'{result} != ""'

    for process_idx in range(num_processes):
        result = _captured_stdout(
            accelerator.on_process(print_on, process_index=process_idx), accelerator.state, process_idx
        )
        if accelerator.process_index == process_idx:
            expected = f"Printing from process {process_idx}: {accelerator.process_index}"
            assert result == expected, f"{result} != {expected}"
        else:
            assert result == "", f'{result} != ""'
def init_state_check():
    """Instantiate ``AcceleratorState`` again and print it.

    The point of the check is that constructing the state a second time (it
    already exists by the time ``main`` calls this) does not fail.
    """
    # Test we can instantiate this twice in a row.
    state = AcceleratorState()
    if state.local_process_index == 0:
        print("Testing, testing. 1, 2, 3.")
    # Every process prints its own state, not only the local main one.
    print(state)
def rng_sync_check():
    """Check that RNG states are reported identical after synchronization.

    Synchronizes, in turn, the global torch RNG, the CUDA RNG (only when the
    distributed type is ``MULTI_GPU``) and a standalone ``torch.Generator``,
    asserting after each step that ``are_the_same_tensors`` accepts the
    resulting state.

    Raises:
        AssertionError: If any RNG state fails the ``are_the_same_tensors`` check.
    """
    state = AcceleratorState()

    synchronize_rng_states(["torch"])
    assert are_the_same_tensors(torch.get_rng_state()), "RNG states improperly synchronized on CPU."

    # The CUDA RNG is only checked in the multi-GPU setup.
    if state.distributed_type == DistributedType.MULTI_GPU:
        synchronize_rng_states(["cuda"])
        assert are_the_same_tensors(torch.cuda.get_rng_state()), "RNG states improperly synchronized on GPU."

    generator = torch.Generator()
    synchronize_rng_states(["generator"], generator=generator)
    assert are_the_same_tensors(generator.get_state()), "RNG states improperly synchronized in generator."

    if state.local_process_index == 0:
        print("All rng are properly synched.")
def dl_preparation_check():
    """Check that prepared dataloaders yield every sample exactly once.

    Builds a dataloader over ``range(32 * num_processes)`` with batch size 8,
    prepares it with ``prepare_data_loader`` and gathers every batch. Four
    configurations are exercised: non-shuffled and shuffled, each with and
    without ``split_batches=True``. Non-shuffled results must equal
    ``arange(length)`` in order; shuffled results must equal it once sorted.

    Raises:
        AssertionError: If the gathered samples differ from the expected ones.
    """
    state = AcceleratorState()
    length = 32 * state.num_processes

    # Non-shuffled, default batching.
    dl = DataLoader(range(length), batch_size=8)
    dl = prepare_data_loader(dl, state.device, state.num_processes, state.process_index, put_on_device=True)
    result = torch.cat([gather(batch) for batch in dl])

    print(state.process_index, result, type(dl))
    assert torch.equal(result.cpu(), torch.arange(0, length).long()), "Wrong non-shuffled dataloader result."

    # Non-shuffled, split batches.
    dl = DataLoader(range(length), batch_size=8)
    dl = prepare_data_loader(
        dl,
        state.device,
        state.num_processes,
        state.process_index,
        put_on_device=True,
        split_batches=True,
    )
    result = torch.cat([gather(batch) for batch in dl])
    assert torch.equal(result.cpu(), torch.arange(0, length).long()), "Wrong non-shuffled dataloader result."

    if state.process_index == 0:
        print("Non-shuffled dataloader passing.")

    # Shuffled, default batching: order is unknown, so compare the sorted values.
    dl = DataLoader(range(length), batch_size=8, shuffle=True)
    dl = prepare_data_loader(dl, state.device, state.num_processes, state.process_index, put_on_device=True)
    result = sorted(torch.cat([gather(batch) for batch in dl]).tolist())
    assert result == list(range(length)), "Wrong shuffled dataloader result."

    # Shuffled, split batches.
    dl = DataLoader(range(length), batch_size=8, shuffle=True)
    dl = prepare_data_loader(
        dl,
        state.device,
        state.num_processes,
        state.process_index,
        put_on_device=True,
        split_batches=True,
    )
    result = sorted(torch.cat([gather(batch) for batch in dl]).tolist())
    assert result == list(range(length)), "Wrong shuffled dataloader result."

    if state.local_process_index == 0:
        print("Shuffled dataloader passing.")
def central_dl_preparation_check():
    """Check prepared dataloaders when ``dispatch_batches=True`` is requested.

    Same four configurations as ``dl_preparation_check`` (non-shuffled and
    shuffled, each with and without ``split_batches=True``), but every call to
    ``prepare_data_loader`` also passes ``dispatch_batches=True``. Non-shuffled
    results must equal ``arange(length)`` in order; shuffled results must equal
    it once sorted.

    Raises:
        AssertionError: If the gathered samples differ from the expected ones.
    """
    state = AcceleratorState()
    length = 32 * state.num_processes

    # Non-shuffled, default batching.
    dl = DataLoader(range(length), batch_size=8)
    dl = prepare_data_loader(
        dl, state.device, state.num_processes, state.process_index, put_on_device=True, dispatch_batches=True
    )
    result = torch.cat([gather(batch) for batch in dl])
    assert torch.equal(result.cpu(), torch.arange(0, length).long()), "Wrong non-shuffled dataloader result."

    # Non-shuffled, split batches.
    dl = DataLoader(range(length), batch_size=8)
    dl = prepare_data_loader(
        dl,
        state.device,
        state.num_processes,
        state.process_index,
        put_on_device=True,
        split_batches=True,
        dispatch_batches=True,
    )
    result = torch.cat([gather(batch) for batch in dl])
    assert torch.equal(result.cpu(), torch.arange(0, length).long()), "Wrong non-shuffled dataloader result."

    if state.process_index == 0:
        print("Non-shuffled central dataloader passing.")

    # Shuffled, default batching: order is unknown, so compare the sorted values.
    dl = DataLoader(range(length), batch_size=8, shuffle=True)
    dl = prepare_data_loader(
        dl, state.device, state.num_processes, state.process_index, put_on_device=True, dispatch_batches=True
    )
    result = sorted(torch.cat([gather(batch) for batch in dl]).tolist())
    assert result == list(range(length)), "Wrong shuffled dataloader result."

    # Shuffled, split batches.
    dl = DataLoader(range(length), batch_size=8, shuffle=True)
    dl = prepare_data_loader(
        dl,
        state.device,
        state.num_processes,
        state.process_index,
        put_on_device=True,
        split_batches=True,
        dispatch_batches=True,
    )
    result = sorted(torch.cat([gather(batch) for batch in dl]).tolist())
    assert result == list(range(length)), "Wrong shuffled dataloader result."

    if state.local_process_index == 0:
        print("Shuffled central dataloader passing.")
def mock_training(length, batch_size, generator):
    """Train a ``RegressionModel`` with plain PyTorch as a reference run.

    Seeds everything with 42, builds a ``RegressionDataset`` of ``length``
    samples and runs three epochs of SGD (lr=0.1) on an MSE loss, without
    involving an ``Accelerator``.

    Args:
        length: Number of samples in the generated dataset.
        batch_size: Batch size of the shuffled training dataloader.
        generator: ``torch.Generator`` driving the shuffling; it is re-seeded
            to 42 in place.

    Returns:
        A ``(train_set, model)`` tuple: the dataset that was generated and the
        trained model.
    """
    set_seed(42)
    generator.manual_seed(42)
    train_set = RegressionDataset(length=length)
    train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
    model = RegressionModel()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    for _ in range(3):
        for batch in train_dl:
            model.zero_grad()
            output = model(batch["x"])
            loss = torch.nn.functional.mse_loss(output, batch["y"])
            loss.backward()
            optimizer.step()
    return train_set, model
def _train_with_accelerator(accelerator, train_set, batch_size, generator):
    """Train a fresh ``RegressionModel`` for three epochs through ``accelerator``.

    The dataloader, model and optimizer are built and prepared first; seeding
    happens only afterwards, right before the training loop.

    Returns:
        The unwrapped trained model, moved to CPU.
    """
    train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
    model = RegressionModel()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer)
    set_seed(42)
    generator.manual_seed(42)
    for _ in range(3):
        for batch in train_dl:
            model.zero_grad()
            output = model(batch["x"])
            loss = torch.nn.functional.mse_loss(output, batch["y"])
            accelerator.backward(loss)
            optimizer.step()
    return accelerator.unwrap_model(model).cpu()


def _assert_same_model(old_model, model):
    """Assert that both models ended up with close ``a`` and ``b`` parameters."""
    assert torch.allclose(old_model.a, model.a), "Did not obtain the same model on CPU or distributed training."
    assert torch.allclose(old_model.b, model.b), "Did not obtain the same model on CPU or distributed training."


def training_check():
    """Check that training through ``Accelerator`` matches a plain PyTorch run.

    A reference model is trained with ``mock_training`` using a batch size of
    ``8 * num_processes``. The same training is then repeated through an
    ``Accelerator`` in up to four configurations, and the resulting parameters
    are compared against the reference each time:

    * default settings, per-process batch size 8;
    * ``split_batches=True``, dataloader batch size ``8 * num_processes``;
    * ``mixed_precision="fp16"``, only when CUDA is available;
    * ``mixed_precision="bf16"``, only when ``is_bf16_available()`` is true.

    Raises:
        AssertionError: If the reference model fails the
            ``are_the_same_tensors`` check or an accelerated run ends with
            different parameters.
    """
    state = AcceleratorState()
    generator = torch.Generator()
    batch_size = 8
    length = batch_size * 4 * state.num_processes

    train_set, old_model = mock_training(length, batch_size * state.num_processes, generator)
    assert are_the_same_tensors(old_model.a), "Did not obtain the same model on both processes."
    assert are_the_same_tensors(old_model.b), "Did not obtain the same model on both processes."

    accelerator = Accelerator()
    model = _train_with_accelerator(accelerator, train_set, batch_size, generator)
    _assert_same_model(old_model, model)
    accelerator.print("Training yielded the same results on one CPU or distributed setup with no batch split.")

    # With split batches the dataloader is given the full (global) batch size.
    accelerator = Accelerator(split_batches=True)
    model = _train_with_accelerator(accelerator, train_set, batch_size * state.num_processes, generator)
    _assert_same_model(old_model, model)
    accelerator.print("Training yielded the same results on one CPU or distributes setup with batch split.")

    if torch.cuda.is_available():
        # Mostly a test that FP16 doesn't crash as the operation inside the model is not converted to FP16
        print("FP16 training check.")
        # The state is reset before an Accelerator with another precision is built.
        AcceleratorState._reset_state()
        accelerator = Accelerator(mixed_precision="fp16")
        model = _train_with_accelerator(accelerator, train_set, batch_size, generator)
        _assert_same_model(old_model, model)

    # BF16 support is only for CPU + TPU, and some GPU
    if is_bf16_available():
        # Mostly a test that BF16 doesn't crash as the operation inside the model is not converted to BF16
        print("BF16 training check.")
        AcceleratorState._reset_state()
        accelerator = Accelerator(mixed_precision="bf16")
        model = _train_with_accelerator(accelerator, train_set, batch_size, generator)
        _assert_same_model(old_model, model)
def main():
    """Run every check of this script in sequence.

    Order: state initialization, process execution, RNG synchronization,
    dataloader preparation (the dispatching variant is skipped on TPU and on
    torch older than 1.8.0) and finally the training comparison, which is
    skipped entirely under DeepSpeed. Section banners are printed by the local
    main process only.
    """
    accelerator = Accelerator()
    state = accelerator.state
    if state.local_process_index == 0:
        print("**Initialization**")
    init_state_check()

    if state.local_process_index == 0:
        print("\n**Test process execution**")
    process_execution_check()

    if state.local_process_index == 0:
        print("\n**Test random number generator synchronization**")
    rng_sync_check()

    if state.local_process_index == 0:
        print("\n**DataLoader integration test**")
    dl_preparation_check()
    if state.distributed_type != DistributedType.TPU and is_torch_version(">=", "1.8.0"):
        central_dl_preparation_check()

    # Trainings are not exactly the same in DeepSpeed and CPU mode
    if state.distributed_type == DistributedType.DEEPSPEED:
        return

    if state.local_process_index == 0:
        print("\n**Training integration test**")
    training_check()
def _mp_fn(index):
    """Entry point for xla_spawn (TPUs).

    Args:
        index: Supplied by the spawner; not used here.
    """
    main()
# Run the full check suite when the file is executed as a script.
if __name__ == "__main__":
    main()