shape_training.h
Shape-inference functions for ONNX operators in the ai.onnx.preview.training (training) family.
-
namespace ONNX_LIGHT_NAMESPACE
-
namespace onnx_optim
-
namespace shapes
-
namespace training
Functions
-
void ComputeShapeAdam(ShapesContext &ctx, const NodeProto &node)
Computes an OptimTensor for every output of an Adam node and stores the results in ctx.
Adam (ai.onnx.preview.training) updates N optimised tensors together with their accumulated first and second moments. The inputs are laid out as [R, T, X_1, ..., X_N, G_1, ..., G_N, V_1, ..., V_N, H_1, ..., H_N] (so input_size == 2 + 4 * N) and the outputs as [X_1_new, ..., X_N_new, V_1_new, ..., V_N_new, H_1_new, ..., H_N_new] (so output_size == 3 * N).
Each *_new output has the same dtype and shape as the corresponding X_i, V_i or H_i input.
- Parameters:
ctx – In/out context. Must already contain entries for every X_i, V_i and H_i input read from node; on return it also contains an entry for every output of node.
node – The Adam NodeProto whose outputs should be described. node.op_type() must be "Adam", node.input_size() must be 2 + 4 * N for some N >= 1, and node.output_size() must be 3 * N.
- Throws:
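The index arithmetic implied by the Adam input/output layout can be sketched as follows. This is a minimal sketch, not this library's implementation: AdamTensorCount and MirroredInputIndex are hypothetical helpers that operate on plain counts rather than on the real ShapesContext and NodeProto types, and std::optional assumes C++17. For output index k, the X outputs (k < N) mirror input 2 + k, while the V and H outputs (N <= k < 3N) both reduce to input 2 + N + k.

```cpp
#include <cstddef>
#include <optional>
#include <stdexcept>

// Hypothetical helper: derives N from the input/output counts of an Adam
// node, following the layout
//   inputs  [R, T, X_1..X_N, G_1..G_N, V_1..V_N, H_1..H_N]
//   outputs [X_1_new..X_N_new, V_1_new..V_N_new, H_1_new..H_N_new]
// Returns std::nullopt when the counts do not fit that layout.
std::optional<std::size_t> AdamTensorCount(std::size_t input_size,
                                           std::size_t output_size) {
  // input_size must equal 2 + 4 * N with N >= 1, i.e. at least 6.
  if (input_size < 6 || (input_size - 2) % 4 != 0) return std::nullopt;
  const std::size_t n = (input_size - 2) / 4;
  // output_size must equal 3 * N.
  if (output_size != 3 * n) return std::nullopt;
  return n;
}

// Hypothetical helper: index of the input whose dtype and shape the output
// at position `out` mirrors.
//   out in [0, N)  -> X_{out+1}, stored at input 2 + out
//   out in [N, 2N) -> V_i, stored at input 2 + 2N + (out - N) == 2 + N + out
//   out in [2N,3N) -> H_i, stored at input 2 + 3N + (out - 2N) == 2 + N + out
std::size_t MirroredInputIndex(std::size_t out, std::size_t n) {
  if (out >= 3 * n) throw std::out_of_range("Adam output index out of range");
  return out < n ? 2 + out : 2 + n + out;
}
```

For example, with N = 2 the inputs are indexed R = 0, T = 1, X = 2..3, G = 4..5, V = 6..7 and H = 8..9, so output 2 (V_1_new) maps to input 6 and output 5 (H_2_new) maps to input 9. The G inputs never appear as a mirror source, since Adam emits no updated gradients.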
Variables
-
constexpr const char *kOnnxPreviewTrainingDomain = "ai.onnx.preview.training"
Canonical domain string for the ai.onnx.preview.training operator set.
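One way such a constant might be used is to recognise training-domain nodes before dispatching to a shape function. The sketch below is illustrative only: IsPreviewTrainingAdam is a hypothetical helper, the constant is redeclared locally so the block compiles on its own, and it assumes the node's domain and op type are available as strings (ONNX's NodeProto carries them in its domain and op_type fields).

```cpp
#include <string>

// Local copy of the documented constant so this sketch is self-contained.
constexpr const char *kOnnxPreviewTrainingDomain = "ai.onnx.preview.training";

// Hypothetical helper: true only for an Adam node from the training domain.
// An op_type of "Adam" registered under any other domain is not matched.
bool IsPreviewTrainingAdam(const std::string &domain,
                           const std::string &op_type) {
  return domain == kOnnxPreviewTrainingDomain && op_type == "Adam";
}
```

Comparing the domain as well as the op type keeps a same-named operator from a custom domain from being routed to the Adam shape logic.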