// Llama-3.1-8B-DALv0.1/venv/lib/python3.12/site-packages/torch/include/ATen/native/FusedAdam.h
namespace at { | |
namespace native { | |
enum class ADAM_MODE : uint8_t { ORIGINAL = 0, ADAMW = 1 }; | |
using fused_adam_fn = void (*)( | |
const at::Tensor& param, | |
const at::Tensor& grad, | |
const at::Tensor& exp_avg, | |
const at::Tensor& exp_avg_sq, | |
const at::Tensor& max_exp_avg_sq, | |
const at::Tensor& state_step, | |
const double lr, | |
const double beta1, | |
const double beta2, | |
const double weight_decay, | |
const double eps, | |
const bool amsgrad, | |
const bool maximize, | |
const float* grad_scale_ptr, | |
const ADAM_MODE); | |
DECLARE_DISPATCH(fused_adam_fn, fused_adam_stub); | |
} | |
} | |