#!/usr/bin/env python

# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
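"""
Distributed smoke tests for core Accelerate features: process execution
controls, RNG synchronization, DataLoader preparation (sharded, split and
dispatched batches), and training parity between a single-process reference
run and the prepared (optionally mixed-precision) distributed setup.
"""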
import contextlib
import io

import torch
from torch.utils.data import DataLoader

from accelerate import Accelerator
from accelerate.data_loader import prepare_data_loader
from accelerate.state import AcceleratorState
from accelerate.test_utils import RegressionDataset, RegressionModel, are_the_same_tensors
from accelerate.utils import (
    DistributedType,
    gather,
    is_bf16_available,
    is_torch_version,
    set_seed,
    synchronize_rng_states,
)

def print_main(state):
    print(f"Printing from the main process {state.process_index}")


def print_local_main(state):
    print(f"Printing from the local main process {state.local_process_index}")


def print_last(state):
    print(f"Printing from the last process {state.process_index}")


def print_on(state, process_idx):
    print(f"Printing from process {process_idx}: {state.process_index}")


def process_execution_check():
    accelerator = Accelerator()
    num_processes = accelerator.num_processes
    with accelerator.main_process_first():
        idx = torch.tensor(accelerator.process_index).to(accelerator.device)
    idxs = accelerator.gather(idx)
    if num_processes > 1:
        assert idxs[0] == 0, "Main process was not first."

    # Test the decorators
    f = io.StringIO()
    with contextlib.redirect_stdout(f):
        accelerator.on_main_process(print_main)(accelerator.state)
    result = f.getvalue().rstrip()
    if accelerator.is_main_process:
        assert result == "Printing from the main process 0", f"{result} != Printing from the main process 0"
    else:
        assert f.getvalue().rstrip() == "", f'{result} != ""'
    f.truncate(0)
    f.seek(0)

    with contextlib.redirect_stdout(f):
        accelerator.on_local_main_process(print_local_main)(accelerator.state)
    if accelerator.is_local_main_process:
        assert f.getvalue().rstrip() == "Printing from the local main process 0"
    else:
        assert f.getvalue().rstrip() == ""
    f.truncate(0)
    f.seek(0)

    with contextlib.redirect_stdout(f):
        accelerator.on_last_process(print_last)(accelerator.state)
    if accelerator.is_last_process:
        assert f.getvalue().rstrip() == f"Printing from the last process {accelerator.state.num_processes - 1}"
    else:
        assert f.getvalue().rstrip() == ""
    f.truncate(0)
    f.seek(0)

    for process_idx in range(num_processes):
        with contextlib.redirect_stdout(f):
            accelerator.on_process(print_on, process_index=process_idx)(accelerator.state, process_idx)
        if accelerator.process_index == process_idx:
            assert f.getvalue().rstrip() == f"Printing from process {process_idx}: {accelerator.process_index}"
        else:
            assert f.getvalue().rstrip() == ""
        f.truncate(0)
        f.seek(0)


def init_state_check():
    # Test we can instantiate this twice in a row.
    state = AcceleratorState()
    if state.local_process_index == 0:
        print("Testing, testing. 1, 2, 3.")
        print(state)


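# Check that `synchronize_rng_states` leaves the torch RNG state (and, on multi-GPU,
# the CUDA RNG state, plus a dedicated `torch.Generator`) identical on all processes.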
def rng_sync_check():
    state = AcceleratorState()
    synchronize_rng_states(["torch"])
    assert are_the_same_tensors(torch.get_rng_state()), "RNG states improperly synchronized on CPU."
    if state.distributed_type == DistributedType.MULTI_GPU:
        synchronize_rng_states(["cuda"])
        assert are_the_same_tensors(torch.cuda.get_rng_state()), "RNG states improperly synchronized on GPU."
    generator = torch.Generator()
    synchronize_rng_states(["generator"], generator=generator)
    assert are_the_same_tensors(generator.get_state()), "RNG states improperly synchronized in generator."

    if state.local_process_index == 0:
        print("All RNG states are properly synchronized.")


def dl_preparation_check():
    state = AcceleratorState()
    length = 32 * state.num_processes

    dl = DataLoader(range(length), batch_size=8)
    dl = prepare_data_loader(dl, state.device, state.num_processes, state.process_index, put_on_device=True)
    result = []
    for batch in dl:
        result.append(gather(batch))
    result = torch.cat(result)

    print(state.process_index, result, type(dl))
    assert torch.equal(result.cpu(), torch.arange(0, length).long()), "Wrong non-shuffled dataloader result."

    dl = DataLoader(range(length), batch_size=8)
    dl = prepare_data_loader(
        dl,
        state.device,
        state.num_processes,
        state.process_index,
        put_on_device=True,
        split_batches=True,
    )
    result = []
    for batch in dl:
        result.append(gather(batch))
    result = torch.cat(result)
    assert torch.equal(result.cpu(), torch.arange(0, length).long()), "Wrong non-shuffled dataloader result."

    if state.process_index == 0:
        print("Non-shuffled dataloader passing.")

    dl = DataLoader(range(length), batch_size=8, shuffle=True)
    dl = prepare_data_loader(dl, state.device, state.num_processes, state.process_index, put_on_device=True)
    result = []
    for batch in dl:
        result.append(gather(batch))
    result = torch.cat(result).tolist()
    result.sort()
    assert result == list(range(length)), "Wrong shuffled dataloader result."

    dl = DataLoader(range(length), batch_size=8, shuffle=True)
    dl = prepare_data_loader(
        dl,
        state.device,
        state.num_processes,
        state.process_index,
        put_on_device=True,
        split_batches=True,
    )
    result = []
    for batch in dl:
        result.append(gather(batch))
    result = torch.cat(result).tolist()
    result.sort()
    assert result == list(range(length)), "Wrong shuffled dataloader result."

    if state.local_process_index == 0:
        print("Shuffled dataloader passing.")


def central_dl_preparation_check():
    state = AcceleratorState()
    length = 32 * state.num_processes

    dl = DataLoader(range(length), batch_size=8)
    dl = prepare_data_loader(
        dl, state.device, state.num_processes, state.process_index, put_on_device=True, dispatch_batches=True
    )
    result = []
    for batch in dl:
        result.append(gather(batch))
    result = torch.cat(result)
    assert torch.equal(result.cpu(), torch.arange(0, length).long()), "Wrong non-shuffled dataloader result."

    dl = DataLoader(range(length), batch_size=8)
    dl = prepare_data_loader(
        dl,
        state.device,
        state.num_processes,
        state.process_index,
        put_on_device=True,
        split_batches=True,
        dispatch_batches=True,
    )
    result = []
    for batch in dl:
        result.append(gather(batch))
    result = torch.cat(result)
    assert torch.equal(result.cpu(), torch.arange(0, length).long()), "Wrong non-shuffled dataloader result."

    if state.process_index == 0:
        print("Non-shuffled central dataloader passing.")

    dl = DataLoader(range(length), batch_size=8, shuffle=True)
    dl = prepare_data_loader(
        dl, state.device, state.num_processes, state.process_index, put_on_device=True, dispatch_batches=True
    )
    result = []
    for batch in dl:
        result.append(gather(batch))
    result = torch.cat(result).tolist()
    result.sort()
    assert result == list(range(length)), "Wrong shuffled dataloader result."

    dl = DataLoader(range(length), batch_size=8, shuffle=True)
    dl = prepare_data_loader(
        dl,
        state.device,
        state.num_processes,
        state.process_index,
        put_on_device=True,
        split_batches=True,
        dispatch_batches=True,
    )
    result = []
    for batch in dl:
        result.append(gather(batch))
    result = torch.cat(result).tolist()
    result.sort()
    assert result == list(range(length)), "Wrong shuffled dataloader result."

    if state.local_process_index == 0:
        print("Shuffled central dataloader passing.")


def mock_training(length, batch_size, generator):
    set_seed(42)
    generator.manual_seed(42)
    train_set = RegressionDataset(length=length)
    train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
    model = RegressionModel()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    for epoch in range(3):
        for batch in train_dl:
            model.zero_grad()
            output = model(batch["x"])
            loss = torch.nn.functional.mse_loss(output, batch["y"])
            loss.backward()
            optimizer.step()
    return train_set, model


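# Check that training with `Accelerator` (plain, with `split_batches=True`, and with
# fp16/bf16 mixed precision when available) reproduces the reference model from `mock_training`.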
def training_check():
    state = AcceleratorState()
    generator = torch.Generator()
    batch_size = 8
    length = batch_size * 4 * state.num_processes

    train_set, old_model = mock_training(length, batch_size * state.num_processes, generator)
    assert are_the_same_tensors(old_model.a), "Did not obtain the same model on both processes."
    assert are_the_same_tensors(old_model.b), "Did not obtain the same model on both processes."

    accelerator = Accelerator()
    train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
    model = RegressionModel()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer)
    set_seed(42)
    generator.manual_seed(42)
    for epoch in range(3):
        for batch in train_dl:
            model.zero_grad()
            output = model(batch["x"])
            loss = torch.nn.functional.mse_loss(output, batch["y"])
            accelerator.backward(loss)
            optimizer.step()

    model = accelerator.unwrap_model(model).cpu()
    assert torch.allclose(old_model.a, model.a), "Did not obtain the same model on CPU or distributed training."
    assert torch.allclose(old_model.b, model.b), "Did not obtain the same model on CPU or distributed training."

    accelerator.print("Training yielded the same results on one CPU or distributed setup with no batch split.")

    accelerator = Accelerator(split_batches=True)
    train_dl = DataLoader(train_set, batch_size=batch_size * state.num_processes, shuffle=True, generator=generator)
    model = RegressionModel()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer)
    set_seed(42)
    generator.manual_seed(42)
    for _ in range(3):
        for batch in train_dl:
            model.zero_grad()
            output = model(batch["x"])
            loss = torch.nn.functional.mse_loss(output, batch["y"])
            accelerator.backward(loss)
            optimizer.step()

    model = accelerator.unwrap_model(model).cpu()
    assert torch.allclose(old_model.a, model.a), "Did not obtain the same model on CPU or distributed training."
    assert torch.allclose(old_model.b, model.b), "Did not obtain the same model on CPU or distributed training."

    accelerator.print("Training yielded the same results on one CPU or distributed setup with batch split.")

    if torch.cuda.is_available():
        # Mostly a test that FP16 doesn't crash as the operation inside the model is not converted to FP16
        print("FP16 training check.")
        AcceleratorState._reset_state()
        accelerator = Accelerator(mixed_precision="fp16")
        train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
        model = RegressionModel()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

        train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer)
        set_seed(42)
        generator.manual_seed(42)
        for _ in range(3):
            for batch in train_dl:
                model.zero_grad()
                output = model(batch["x"])
                loss = torch.nn.functional.mse_loss(output, batch["y"])
                accelerator.backward(loss)
                optimizer.step()

        model = accelerator.unwrap_model(model).cpu()
        assert torch.allclose(old_model.a, model.a), "Did not obtain the same model on CPU or distributed training."
        assert torch.allclose(old_model.b, model.b), "Did not obtain the same model on CPU or distributed training."

    # BF16 support is only for CPU + TPU, and some GPUs
    if is_bf16_available():
        # Mostly a test that BF16 doesn't crash as the operation inside the model is not converted to BF16
        print("BF16 training check.")
        AcceleratorState._reset_state()
        accelerator = Accelerator(mixed_precision="bf16")
        train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
        model = RegressionModel()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

        train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer)
        set_seed(42)
        generator.manual_seed(42)
        for _ in range(3):
            for batch in train_dl:
                model.zero_grad()
                output = model(batch["x"])
                loss = torch.nn.functional.mse_loss(output, batch["y"])
                accelerator.backward(loss)
                optimizer.step()

        model = accelerator.unwrap_model(model).cpu()
        assert torch.allclose(old_model.a, model.a), "Did not obtain the same model on CPU or distributed training."
        assert torch.allclose(old_model.b, model.b), "Did not obtain the same model on CPU or distributed training."


def main():
    accelerator = Accelerator()
    state = accelerator.state
    if state.local_process_index == 0:
        print("**Initialization**")
    init_state_check()

    if state.local_process_index == 0:
        print("\n**Test process execution**")
    process_execution_check()

    if state.local_process_index == 0:
        print("\n**Test random number generator synchronization**")
    rng_sync_check()

    if state.local_process_index == 0:
        print("\n**DataLoader integration test**")
    dl_preparation_check()
    if state.distributed_type != DistributedType.TPU and is_torch_version(">=", "1.8.0"):
        central_dl_preparation_check()

    # Trainings are not exactly the same in DeepSpeed and CPU mode
    if state.distributed_type == DistributedType.DEEPSPEED:
        return

    if state.local_process_index == 0:
        print("\n**Training integration test**")
    training_check()


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()