// Llama-3.1-8B-DALv0.1/venv/lib/python3.12/site-packages/torch/include/ATen/native/FusedSGD.h
namespace at { | |
namespace native { | |
using fused_sgd_fn = void (*)( | |
const at::Tensor& param, | |
const at::Tensor& grad, | |
const at::Tensor& momentum_buffer, | |
const double weight_decay, | |
const double momentum, | |
const double lr, | |
const double dampening, | |
const bool nesterov, | |
const bool maximize, | |
const bool is_first_step, | |
const float* grad_scale_ptr); | |
DECLARE_DISPATCH(fused_sgd_fn, fused_sgd_stub); | |
} | |
} | |