include_training_kernels.h
-
namespace ONNX_LIGHT_NAMESPACE
-
namespace onnx_backend_test
-
namespace kernel
-
class Adam
- #include <include_training_kernels.h>
Reference implementation of ai.onnx.preview.training::Adam (v1). Computes one iteration of the Adam stochastic-gradient-based optimization algorithm for an arbitrary number N >= 1 of optimized tensors. For each optimized tensor X_i (with gradient G_i, accumulated gradient V_i and accumulated squared gradient H_i), the kernel evaluates the pseudo-code documented in the ONNX schema:

    G_reg      = norm_coefficient * X + G
    V_new      = alpha * V + (1 - alpha) * G_reg
    H_new      = beta * H + (1 - beta) * G_reg * G_reg
    H_sqrt     = sqrt(H_new) + epsilon
    R_adjusted = T > 0 ? R * sqrt(1 - beta^T) / (1 - alpha^T) : R
    X_new      = X - R_adjusted * V_new / H_sqrt
    X_final    = (1 - norm_coefficient_post) * X_new

and returns {X_final_1..N, V_new_1..N, H_new_1..N}. All optimized tensors share the same scalar inputs R (FLOAT) and T (INT64). Only FLOAT tensors are supported.

Public Functions
-
inline explicit Adam(const KernelContext &ctx)
-
std::vector<Tensor> operator()(const Tensor &R, const Tensor &T, const std::vector<Tensor> &Xs, const std::vector<Tensor> &Gs, const std::vector<Tensor> &Vs, const std::vector<Tensor> &Hs, float alpha = 0.9f, float beta = 0.999f, float epsilon = 1e-6f, float norm_coefficient = 0.0f, float norm_coefficient_post = 0.0f) const
Computes one Adam iteration for N == Xs.size() optimized tensors and allocates fresh output tensors, returned in the layout {X_final_1..N, V_new_1..N, H_new_1..N}. Xs, Gs, Vs and Hs must all have the same length, and Xs[i], Gs[i], Vs[i] and Hs[i] must be FLOAT tensors of the same shape. R must be a scalar FLOAT tensor and T a scalar INT64 tensor. The trailing alpha, beta, epsilon, norm_coefficient and norm_coefficient_post parameters mirror the attributes of the Adam ONNX schema, and their defaults match the schema defaults.
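To make the per-element arithmetic concrete, here is a standalone sketch of the pseudo-code for a single optimized tensor. It uses plain std::vector<float> instead of the library's Tensor type. The names AdamStep and adam_step are hypothetical and not part of include_training_kernels.h, and doing all arithmetic in float is an assumption rather than a documented property of the kernel.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical plain-vector result for ONE optimized tensor; the documented
// kernel operates on Tensor objects and processes N tensors per call.
struct AdamStep {
    std::vector<float> x_final;  // X_final
    std::vector<float> v_new;    // V_new
    std::vector<float> h_new;    // H_new
};

// Element-wise sketch of the ONNX Adam(v1) pseudo-code for a single tensor.
// Assumption: all arithmetic is done in float.
AdamStep adam_step(float R, std::int64_t T,
                   const std::vector<float>& X, const std::vector<float>& G,
                   const std::vector<float>& V, const std::vector<float>& H,
                   float alpha = 0.9f, float beta = 0.999f, float epsilon = 1e-6f,
                   float norm_coefficient = 0.0f, float norm_coefficient_post = 0.0f) {
    const std::size_t n = X.size();
    assert(G.size() == n && V.size() == n && H.size() == n);

    // Bias-corrected learning rate, applied only when T > 0.
    float R_adjusted = R;
    if (T > 0) {
        const float t = static_cast<float>(T);
        R_adjusted = R * std::sqrt(1.0f - std::pow(beta, t)) / (1.0f - std::pow(alpha, t));
    }

    AdamStep out{std::vector<float>(n), std::vector<float>(n), std::vector<float>(n)};
    for (std::size_t i = 0; i < n; ++i) {
        const float g_reg = norm_coefficient * X[i] + G[i];          // G_reg
        const float v = alpha * V[i] + (1.0f - alpha) * g_reg;        // V_new
        const float h = beta * H[i] + (1.0f - beta) * g_reg * g_reg;  // H_new
        const float h_sqrt = std::sqrt(h) + epsilon;                  // H_sqrt
        const float x_new = X[i] - R_adjusted * v / h_sqrt;           // X_new
        out.x_final[i] = (1.0f - norm_coefficient_post) * x_new;      // X_final
        out.v_new[i] = v;
        out.h_new[i] = h;
    }
    return out;
}
```

For example, one step with R = 0.1, T = 1, X = 1, G = 0.5 and zero V and H gives X_final ≈ 0.9. On the first bias-corrected step V_new / H_sqrt is close to sign(G), so X moves by roughly R.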
-
void operator()(const Tensor &R, const Tensor &T, const std::vector<Tensor> &Xs, const std::vector<Tensor> &Gs, const std::vector<Tensor> &Vs, const std::vector<Tensor> &Hs, std::vector<Tensor> &outputs, float alpha = 0.9f, float beta = 0.999f, float epsilon = 1e-6f, float norm_coefficient = 0.0f, float norm_coefficient_post = 0.0f) const
Overload that writes into caller-allocated outputs instead of allocating them. The outputs vector must contain exactly 3 * Xs.size() tensors in the layout {X_final_1..N, V_new_1..N, H_new_1..N}, and each output must already match the FLOAT data type, shape and buffer size of the corresponding optimized tensor.
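The output layout can be made concrete with a small index helper. This is a sketch under assumptions: Shape, x_final_slot, v_new_slot, h_new_slot and check_output_layout are hypothetical names, shapes are modeled as std::vector<int64_t>, and only the count and shape preconditions are checked (not data type or buffer size).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for the shape of a FLOAT Tensor.
using Shape = std::vector<std::int64_t>;

// Slot of each result in the {X_final_1..N, V_new_1..N, H_new_1..N} layout,
// for a 0-based tensor index i in [0, n).
inline std::size_t x_final_slot(std::size_t i, std::size_t n) { assert(i < n); return i; }
inline std::size_t v_new_slot(std::size_t i, std::size_t n) { assert(i < n); return n + i; }
inline std::size_t h_new_slot(std::size_t i, std::size_t n) { assert(i < n); return 2 * n + i; }

// Checks two documented preconditions of the caller-allocated overload:
// exactly 3 * N outputs, each shaped like its optimized tensor X_i.
// Data type and buffer size checks are omitted in this sketch.
void check_output_layout(const std::vector<Shape>& x_shapes,
                         const std::vector<Shape>& output_shapes) {
    const std::size_t n = x_shapes.size();
    if (output_shapes.size() != 3 * n) {
        throw std::invalid_argument("outputs must contain exactly 3 * N tensors");
    }
    for (std::size_t i = 0; i < n; ++i) {
        for (std::size_t slot : {x_final_slot(i, n), v_new_slot(i, n), h_new_slot(i, n)}) {
            if (output_shapes[slot] != x_shapes[i]) {
                throw std::invalid_argument("output shape does not match its optimized tensor");
            }
        }
    }
}
```

For N = 2, X_final occupies outputs[0..1], V_new occupies outputs[2..3] and H_new occupies outputs[4..5].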
Private Members
-
KernelContext ctx_
-
class Adam
-
namespace kernel
-
namespace onnx_backend_test