// Llama-3.1-8B-DALv0.1/venv/lib/python3.12/site-packages/torch/include/ATen/native/FusedAdagrad.h
namespace at { | |
namespace native { | |
using fused_adagrad_fn = void (*)( | |
const at::Tensor& param, | |
const at::Tensor& grad, | |
const at::Tensor& state_sum, | |
const at::Tensor& state_step, | |
const double lr, | |
const double lr_decay, | |
const double weight_decay, | |
const double eps, | |
const bool maximize, | |
const float* grad_scale_ptr); | |
DECLARE_DISPATCH(fused_adagrad_fn, fused_adagrad_stub); | |
} | |
} | |