""" Optimzier Tests
These tests were adapted from PyTorch' optimizer tests.
"""
import math
import pytest
import functools
from copy import deepcopy
import torch
from torch.testing._internal.common_utils import TestCase
from torch.nn import Parameter
from timm.scheduler import PlateauLRScheduler
from timm.optim import create_optimizer_v2
import importlib
import os
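# Optional out-of-tree backend support: TORCH_BACKEND names a module to import before the
# tests run (left unset for stock CUDA/CPU builds), and TORCH_DEVICE selects the device used
# by the device-specific tests below (defaults to 'cuda'), e.g. `TORCH_DEVICE=cpu pytest ...`.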
torch_backend = os.environ.get('TORCH_BACKEND')
if torch_backend is not None:
importlib.import_module(torch_backend)
torch_device = os.environ.get('TORCH_DEVICE', 'cuda')
# HACK relying on internal PyTorch test functionality for comparisons that I don't want to write
torch_tc = TestCase()
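# Runs a constructed optimizer (plus any schedulers) for 200 steps on a small quadratic
# objective and asserts that the final loss is below its initial value.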
def _test_basic_cases_template(weight, bias, input, constructor, scheduler_constructors):
weight = Parameter(weight)
bias = Parameter(bias)
input = Parameter(input)
optimizer = constructor(weight, bias)
schedulers = []
for scheduler_constructor in scheduler_constructors:
schedulers.append(scheduler_constructor(optimizer))
# to check if the optimizer can be printed as a string
optimizer.__repr__()
def fn():
optimizer.zero_grad()
y = weight.mv(input)
if y.is_cuda and bias.is_cuda and y.get_device() != bias.get_device():
y = y.cuda(bias.get_device())
loss = (y + bias).pow(2).sum()
loss.backward()
return loss
initial_value = fn().item()
for _i in range(200):
for scheduler in schedulers:
if isinstance(scheduler, PlateauLRScheduler):
val_loss = fn()
scheduler.step(val_loss)
else:
scheduler.step()
optimizer.step(fn)
assert fn().item() < initial_value
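# Verifies optimizer state_dict round-tripping: a cloned optimizer loaded from the state dict
# tracks the original exactly, loading does not mutate the dict, state can be restored onto
# parameters moved to another device (when available), and deepcopy keeps public attributes.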
def _test_state_dict(weight, bias, input, constructor):
weight = Parameter(weight)
bias = Parameter(bias)
input = Parameter(input)
def fn_base(optimizer, weight, bias):
optimizer.zero_grad()
i = input_device if weight.device.type != 'cpu' else input
loss = (weight.mv(i) + bias).pow(2).sum()
loss.backward()
return loss
optimizer = constructor(weight, bias)
fn = functools.partial(fn_base, optimizer, weight, bias)
# Prime the optimizer
for _i in range(20):
optimizer.step(fn)
# Clone the weights and construct new optimizer for them
with torch.no_grad():
weight_c = Parameter(weight.clone().detach())
bias_c = Parameter(bias.clone().detach())
optimizer_c = constructor(weight_c, bias_c)
fn_c = functools.partial(fn_base, optimizer_c, weight_c, bias_c)
# Load state dict
state_dict = deepcopy(optimizer.state_dict())
state_dict_c = deepcopy(optimizer.state_dict())
optimizer_c.load_state_dict(state_dict_c)
# Run both optimizations in parallel
for _i in range(20):
optimizer.step(fn)
optimizer_c.step(fn_c)
torch_tc.assertEqual(weight, weight_c)
torch_tc.assertEqual(bias, bias_c)
# Make sure state dict is deterministic with equal but not identical parameters
torch_tc.assertEqual(optimizer.state_dict(), optimizer_c.state_dict())
# Make sure repeated parameters have identical representation in state dict
optimizer_c.param_groups.extend(optimizer_c.param_groups)
torch_tc.assertEqual(optimizer.state_dict()['param_groups'][-1], optimizer_c.state_dict()['param_groups'][-1])
# Check that state dict can be loaded even when we cast parameters
# to a different type and move to a different device.
if torch_device == 'cpu':
return
elif torch_device == 'cuda' and not torch.cuda.is_available():
return
with torch.no_grad():
input_device = Parameter(input.clone().detach().float().to(torch_device))
weight_device = Parameter(weight.clone().detach().to(torch_device))
bias_device = Parameter(bias.clone().detach().to(torch_device))
optimizer_device = constructor(weight_device, bias_device)
fn_device = functools.partial(fn_base, optimizer_device, weight_device, bias_device)
state_dict = deepcopy(optimizer.state_dict())
state_dict_c = deepcopy(optimizer.state_dict())
optimizer_device.load_state_dict(state_dict_c)
# Make sure state dict wasn't modified
torch_tc.assertEqual(state_dict, state_dict_c)
for _i in range(20):
optimizer.step(fn)
optimizer_device.step(fn_device)
torch_tc.assertEqual(weight, weight_device)
torch_tc.assertEqual(bias, bias_device)
# validate deepcopy() copies all public attributes
def getPublicAttr(obj):
return set(k for k in obj.__dict__ if not k.startswith('_'))
assert getPublicAttr(optimizer) == getPublicAttr(deepcopy(optimizer))
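# Exercises _test_state_dict and _test_basic_cases_template with contiguous, non-contiguous,
# and (when a non-CPU device is available) device-resident parameters.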
def _test_basic_cases(constructor, scheduler_constructors=None):
if scheduler_constructors is None:
scheduler_constructors = []
_test_state_dict(
torch.randn(10, 5),
torch.randn(10),
torch.randn(5),
constructor
)
_test_basic_cases_template(
torch.randn(10, 5),
torch.randn(10),
torch.randn(5),
constructor,
scheduler_constructors
)
# non-contiguous parameters
_test_basic_cases_template(
torch.randn(10, 5, 2)[..., 0],
torch.randn(10, 2)[..., 0],
torch.randn(5),
constructor,
scheduler_constructors
)
# CUDA
if torch_device == 'cpu':
return
elif torch_device == 'cuda' and not torch.cuda.is_available():
return
_test_basic_cases_template(
torch.randn(10, 5).to(torch_device),
torch.randn(10).to(torch_device),
torch.randn(5).to(torch_device),
constructor,
scheduler_constructors
)
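# Trains a tiny two-layer MLP with fixed initial weights for 20 steps via create_optimizer_v2
# and asserts the loss is strictly decreasing at every step.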
def _test_model(optimizer, params, device=torch.device('cpu')):
weight = torch.tensor(
[[-0.2109, -0.4976], [-0.1413, -0.3420], [-0.2524, 0.6976]],
device=device, requires_grad=True)
bias = torch.tensor([-0.1085, -0.2979, 0.6892], device=device, requires_grad=True)
weight2 = torch.tensor([[-0.0508, -0.3941, -0.2843]], device=device, requires_grad=True)
bias2 = torch.tensor([-0.0711], device=device, requires_grad=True)
input = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], device=device).reshape(3, 2)
model = torch.nn.Sequential(torch.nn.Linear(2, 3),
torch.nn.Sigmoid(),
torch.nn.Linear(3, 1),
torch.nn.Sigmoid())
model.to(device)
pretrained_dict = model.state_dict()
pretrained_dict['0.weight'] = weight
pretrained_dict['0.bias'] = bias
pretrained_dict['2.weight'] = weight2
pretrained_dict['2.bias'] = bias2
model.load_state_dict(pretrained_dict)
optimizer = create_optimizer_v2(model, opt=optimizer, **params)
prev_loss = float('inf')
for i in range(20):
optimizer.zero_grad()
output = model(input)
loss = output.sum()
loss.backward()
loss = loss.item()
assert loss < prev_loss
prev_loss = loss
optimizer.step()
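# Rosenbrock function and its analytic gradient; the global minimum is at (1, 1).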
def rosenbrock(tensor):
x, y = tensor
return (1 - x) ** 2 + 100 * (y - x ** 2) ** 2
def drosenbrock(tensor):
x, y = tensor
return torch.tensor((-400 * x * (y - x ** 2) - 2 * (1 - x), 200 * (y - x ** 2)))
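# Minimizes the Rosenbrock function with 2000 steps of cyclic coordinate descent, alternating
# the x and y gradient components each step (built from an uncoalesced sparse tensor), and
# asserts the parameters finish no further from the (1, 1) solution than they started.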
def _test_rosenbrock(constructor, scheduler_constructors=None):
if scheduler_constructors is None:
scheduler_constructors = []
params_t = torch.tensor([1.5, 1.5])
params = Parameter(params_t)
optimizer = constructor([params])
schedulers = []
for scheduler_constructor in scheduler_constructors:
schedulers.append(scheduler_constructor(optimizer))
solution = torch.tensor([1, 1])
initial_dist = params.clone().detach().dist(solution)
def eval(params, w):
# Depending on w, provide only the x or y gradient
optimizer.zero_grad()
loss = rosenbrock(params)
loss.backward()
grad = drosenbrock(params.clone().detach())
# NB: We torture test the optimizer by returning an
# uncoalesced sparse tensor
if w:
i = torch.LongTensor([[0, 0]])
x = grad[0]
v = torch.tensor([x / 4., x - x / 4.])
else:
i = torch.LongTensor([[1, 1]])
y = grad[1]
v = torch.tensor([y - y / 4., y / 4.])
x = torch.sparse.DoubleTensor(i, v, torch.Size([2])).to(dtype=v.dtype)
with torch.no_grad():
params.grad = x.to_dense()
return loss
for i in range(2000):
# Do cyclic coordinate descent
w = i % 2
optimizer.step(functools.partial(eval, params, w))
for scheduler in schedulers:
if isinstance(scheduler, PlateauLRScheduler):
scheduler.step(rosenbrock(params))
else:
scheduler.step()
torch_tc.assertLessEqual(params.clone().detach().dist(solution), initial_dist)
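# Helpers that wrap the weight/bias parameters into param-group dicts so the tests also cover
# optimizers created from parameter group lists with per-group overrides (e.g. lr).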
def _build_params_dict(weight, bias, **kwargs):
return [{'params': [weight]}, dict(params=[bias], **kwargs)]
def _build_params_dict_single(weight, bias, **kwargs):
return [dict(params=bias, **kwargs)]
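# Each test below builds optimizers by name via create_optimizer_v2 and runs the shared
# basic-case, Rosenbrock, and (for most families) small-model checks defined above, using
# learning rates chosen per optimizer family.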
#@pytest.mark.parametrize('optimizer', ['sgd', 'momentum'])
# FIXME momentum variant frequently fails in GitHub runner, but never local after many attempts
@pytest.mark.parametrize('optimizer', ['sgd'])
def test_sgd(optimizer):
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict(weight, bias, lr=1e-2),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=1e-2),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=1e-2), optimizer)
)
# _test_basic_cases(
# lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3),
# [lambda opt: StepLR(opt, gamma=0.9, step_size=10)]
# )
# _test_basic_cases(
# lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3),
# [lambda opt: WarmUpLR(opt, warmup_factor=0.4, warmup_iters=4, warmup_method="linear")]
# )
# _test_basic_cases(
# lambda weight, bias: optimizer([weight, bias], lr=1e-3),
# [lambda opt: WarmUpLR(opt, warmup_factor=0.4, warmup_iters=4, warmup_method="constant")]
# )
# _test_basic_cases(
# lambda weight, bias: optimizer([weight, bias], lr=1e-3),
# [lambda opt: StepLR(opt, gamma=0.9, step_size=10),
# lambda opt: WarmUpLR(opt, warmup_factor=0.4, warmup_iters=4)]
# )
# _test_basic_cases(
# lambda weight, bias: optimizer([weight, bias], lr=1e-3),
# [lambda opt: StepLR(opt, gamma=0.9, step_size=10),
# lambda opt: ReduceLROnPlateau(opt)]
# )
# _test_basic_cases(
# lambda weight, bias: optimizer([weight, bias], lr=1e-3),
# [lambda opt: StepLR(opt, gamma=0.99, step_size=10),
# lambda opt: ExponentialLR(opt, gamma=0.99),
# lambda opt: ReduceLROnPlateau(opt)]
# )
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=3e-3, momentum=1)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=3e-3, momentum=1, weight_decay=.1)
)
_test_rosenbrock(
lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
)
_test_model(optimizer, dict(lr=1e-3))
@pytest.mark.parametrize('optimizer', ['adamw', 'adam', 'nadam', 'adamax'])
def test_adam(optimizer):
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_rosenbrock(
lambda params: create_optimizer_v2(params, optimizer, lr=5e-2)
)
_test_model(optimizer, dict(lr=5e-2))
@pytest.mark.parametrize('optimizer', ['adabelief'])
def test_adabelief(optimizer):
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3, weight_decay=1)
)
_test_rosenbrock(
lambda params: create_optimizer_v2(params, optimizer, lr=5e-2)
)
_test_model(optimizer, dict(lr=5e-2))
@pytest.mark.parametrize('optimizer', ['radam', 'radabelief'])
def test_rectified(optimizer):
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_rosenbrock(
lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
)
_test_model(optimizer, dict(lr=1e-3))
@pytest.mark.parametrize('optimizer', ['adadelta', 'adagrad'])
def test_adaother(optimizer):
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3, weight_decay=1)
)
_test_rosenbrock(
lambda params: create_optimizer_v2(params, optimizer, lr=1e-1)
)
_test_model(optimizer, dict(lr=5e-2))
@pytest.mark.parametrize('optimizer', ['adafactor'])
def test_adafactor(optimizer):
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(_build_params_dict_single(weight, bias), optimizer)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3, weight_decay=1)
)
_test_rosenbrock(
lambda params: create_optimizer_v2(params, optimizer, lr=5e-2)
)
_test_model(optimizer, dict(lr=5e-2))
@pytest.mark.parametrize('optimizer', ['lamb', 'lambc'])
def test_lamb(optimizer):
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict(weight, bias, lr=1e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=1e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=1e-3), optimizer)
)
_test_rosenbrock(
lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
)
_test_model(optimizer, dict(lr=1e-3))
@pytest.mark.parametrize('optimizer', ['lars', 'larc', 'nlars', 'nlarc'])
def test_lars(optimizer):
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict(weight, bias, lr=1e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=1e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=1e-3), optimizer)
)
_test_rosenbrock(
lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
)
_test_model(optimizer, dict(lr=1e-3))
@pytest.mark.parametrize('optimizer', ['madgrad', 'madgradw'])
def test_madgrad(optimizer):
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
)
_test_rosenbrock(
lambda params: create_optimizer_v2(params, optimizer, lr=1e-2)
)
_test_model(optimizer, dict(lr=1e-2))
@pytest.mark.parametrize('optimizer', ['novograd'])
def test_novograd(optimizer):
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
)
_test_rosenbrock(
lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
)
_test_model(optimizer, dict(lr=1e-3))
@pytest.mark.parametrize('optimizer', ['rmsprop', 'rmsproptf'])
def test_rmsprop(optimizer):
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
)
_test_rosenbrock(
lambda params: create_optimizer_v2(params, optimizer, lr=1e-2)
)
_test_model(optimizer, dict(lr=1e-2))
@pytest.mark.parametrize('optimizer', ['adamp'])
def test_adamp(optimizer):
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
)
_test_rosenbrock(
lambda params: create_optimizer_v2(params, optimizer, lr=5e-2)
)
_test_model(optimizer, dict(lr=5e-2))
@pytest.mark.parametrize('optimizer', ['sgdp'])
def test_sgdp(optimizer):
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
)
_test_rosenbrock(
lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
)
_test_model(optimizer, dict(lr=1e-3))
@pytest.mark.parametrize('optimizer', ['lookahead_sgd', 'lookahead_momentum'])
def test_lookahead_sgd(optimizer):
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
)
_test_rosenbrock(
lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
)
@pytest.mark.parametrize('optimizer', ['lookahead_adamw', 'lookahead_adam'])
def test_lookahead_adam(optimizer):
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
)
_test_rosenbrock(
lambda params: create_optimizer_v2(params, optimizer, lr=5e-2)
)
@pytest.mark.parametrize('optimizer', ['lookahead_radam'])
def test_lookahead_radam(optimizer):
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
)
_test_rosenbrock(
lambda params: create_optimizer_v2(params, optimizer, lr=1e-4)
)