elementwise_helpers.h

namespace ONNX_LIGHT_NAMESPACE
namespace onnx_backend_test
namespace kernel
namespace detail

Functions

BroadcastInfo CheckBinaryBroadcast(const char *op_name, const char *dtype_name, int32_t expected_dtype, const Tensor &x, const Tensor &y)

Verifies that both inputs have expected_dtype and that their shapes are multidirectional-broadcastable under the standard NumPy/ONNX rules, and returns the resulting BroadcastInfo. Throws std::invalid_argument otherwise.
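The shape rule that CheckBinaryBroadcast() enforces can be sketched as follows. This is a minimal illustration, not the library's code: BroadcastShapes is a hypothetical name, and it works on bare shape vectors rather than Tensor objects. Shapes are right-aligned, missing leading dimensions count as 1, and each dimension pair must either be equal or contain a 1.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical helper (not part of elementwise_helpers.h): computes the
// multidirectional broadcast shape of two input shapes using the NumPy/ONNX
// rules. Shapes are right-aligned; each pair of dimensions must be equal or
// one of them must be 1.
std::vector<int64_t> BroadcastShapes(const std::vector<int64_t>& a,
                                     const std::vector<int64_t>& b) {
  const size_t rank = std::max(a.size(), b.size());
  std::vector<int64_t> out(rank);
  for (size_t i = 0; i < rank; ++i) {
    // Walk from the trailing dimension; missing leading dims count as 1.
    const int64_t da = i < a.size() ? a[a.size() - 1 - i] : 1;
    const int64_t db = i < b.size() ? b[b.size() - 1 - i] : 1;
    if (da != db && da != 1 && db != 1) {
      throw std::invalid_argument("shapes are not broadcastable: dimension " +
                                  std::to_string(da) + " vs " +
                                  std::to_string(db));
    }
    // A size-1 dimension stretches to match the other input.
    out[rank - 1 - i] = (da == 1) ? db : da;
  }
  return out;
}
```

For example, {4, 1, 5} and {3, 1} broadcast to {4, 3, 5}, while {2, 3} and {4} are rejected because 3 and 4 differ and neither is 1.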

BroadcastInfo CheckBinaryBroadcastInOut(const char *op_name, const char *in_dtype_name, int32_t expected_in_dtype, const Tensor &x, const Tensor &y)

Variant of CheckBinaryBroadcast() for kernels whose input and output dtypes differ (e.g. Greater and Less take numeric inputs and produce BOOL outputs). Validates that both inputs have expected_in_dtype and computes the broadcast info; validating the output against its own dtype is left to the caller.

void CheckPreallocatedOutput(const char *op_name, const char *dtype_name, int32_t expected_dtype, const std::vector<int64_t> &expected_shape, size_t expected_bytes, const Tensor &output)

Verifies that the caller-supplied preallocated output tensor has the expected dtype and shape, and that its byte buffer size equals expected_bytes.
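A sketch of these three checks, under stated assumptions: SketchTensor is a hypothetical stand-in for Tensor (the real field names are not documented here), CheckOutputSketch is an illustrative name, and errors are reported with std::invalid_argument on the assumption that this helper follows the same convention as CheckBinaryBroadcast().

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-in for the library's Tensor type.
struct SketchTensor {
  int32_t dtype = 0;
  std::vector<int64_t> shape;
  std::vector<uint8_t> data;  // raw byte buffer
};

// Sketch of the checks described for CheckPreallocatedOutput: dtype, shape,
// and byte size must all match what the kernel is about to write.
// Assumption: mismatches are reported via std::invalid_argument.
void CheckOutputSketch(const char* op_name, const char* dtype_name,
                       int32_t expected_dtype,
                       const std::vector<int64_t>& expected_shape,
                       size_t expected_bytes, const SketchTensor& output) {
  const std::string op(op_name);
  if (output.dtype != expected_dtype) {
    throw std::invalid_argument(op + ": output must have dtype " + dtype_name);
  }
  if (output.shape != expected_shape) {
    throw std::invalid_argument(op + ": output shape does not match the broadcast shape");
  }
  if (output.data.size() != expected_bytes) {
    throw std::invalid_argument(op + ": output buffer has " +
                                std::to_string(output.data.size()) +
                                " bytes, expected " +
                                std::to_string(expected_bytes));
  }
}
```

Checking the byte size separately from the shape guards against a buffer that was sized for a different dtype even when the shape is right.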

template<typename TIn, typename TOut, typename Op>
void BinaryElementwise(const char *op_name, const char *dtype_name, int32_t expected_dtype, const Tensor &x, const Tensor &y, Tensor &output, Op op)

Non-allocating element-wise binary kernel driver that writes into the caller-supplied output. Validates the inputs and the output, then invokes op(a, b) -> TOut for each element pair, with full multidirectional broadcasting. TIn and TOut must both match the byte layout of expected_dtype.
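One way to realize the broadcasting loop with the stride convention of BroadcastInfo (stride 0 on broadcast dimensions) is an odometer-style counter that keeps running offsets into x and y. This is a hypothetical sketch: BroadcastSketch and BroadcastLoopSketch are illustrative names, validation is omitted, and the actual driver may use a different iteration scheme.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical mirror of the documented BroadcastInfo fields used by the loop.
struct BroadcastSketch {
  std::vector<int64_t> shape;      // output shape
  std::vector<int64_t> strides_x;  // element strides of x, 0 on broadcast dims
  std::vector<int64_t> strides_y;  // element strides of y, 0 on broadcast dims
  int64_t element_count = 0;
};

// Walks every output index with a multi-dimensional counter, keeping running
// offsets into x and y so no index arithmetic is redone per element.
template <typename TIn, typename TOut, typename Op>
void BroadcastLoopSketch(const BroadcastSketch& info, const TIn* x,
                         const TIn* y, TOut* out, Op op) {
  const size_t rank = info.shape.size();
  std::vector<int64_t> index(rank, 0);
  int64_t off_x = 0, off_y = 0;
  for (int64_t i = 0; i < info.element_count; ++i) {
    out[i] = op(x[off_x], y[off_y]);
    // Increment the counter from the innermost dimension outward.
    for (size_t d = rank; d-- > 0;) {
      if (++index[d] < info.shape[d]) {
        off_x += info.strides_x[d];
        off_y += info.strides_y[d];
        break;
      }
      // Dimension wrapped: undo its accumulated offset and carry to d - 1.
      off_x -= info.strides_x[d] * (info.shape[d] - 1);
      off_y -= info.strides_y[d] * (info.shape[d] - 1);
      index[d] = 0;
    }
  }
}
```

With x of shape {2, 3} (strides {3, 1}) and y of shape {3} aligned to {1, 3} (strides {0, 1}), each row of x is paired with the same three elements of y. A rank-0 (scalar) output runs the body exactly once.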

template<typename TIn, typename TOut, typename Op>
Tensor BinaryElementwiseAlloc(const char *op_name, const char *dtype_name, int32_t expected_dtype, const Tensor &x, const Tensor &y, Op op)

Allocating element-wise binary kernel driver. Builds the output tensor with the broadcasted shape and expected_dtype, then delegates to BinaryElementwise() to fill it in.
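The allocation half of this pattern amounts to sizing a buffer from the broadcast shape. A minimal sketch, assuming a row-major byte buffer; OwnedOutput and AllocateForShape are hypothetical names, and the real Tensor allocation API is not shown in this reference. The resulting byte count is the kind of value that CheckPreallocatedOutput() compares against expected_bytes.

```cpp
#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical owned output: a shape plus raw byte storage.
struct OwnedOutput {
  std::vector<int64_t> shape;
  std::vector<unsigned char> bytes;
};

// Sizes storage for `shape` holding elements of type TOut.
template <typename TOut>
OwnedOutput AllocateForShape(const std::vector<int64_t>& shape) {
  // Element count is the product of the dimensions (1 for a rank-0 scalar,
  // 0 if any dimension is 0).
  const int64_t count = std::accumulate(shape.begin(), shape.end(),
                                        int64_t{1}, std::multiplies<int64_t>());
  OwnedOutput out;
  out.shape = shape;
  out.bytes.resize(static_cast<size_t>(count) * sizeof(TOut));
  return out;
}
```

Once the buffer exists, filling it is exactly the non-allocating case, which is why the allocating driver can simply delegate.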

template<typename TIn, typename TOut, typename Op>
void BinaryElementwiseInOut(const char *op_name, const char *in_dtype_name, int32_t in_dtype, const char *out_dtype_name, int32_t out_dtype, const Tensor &x, const Tensor &y, Tensor &output, Op op)

Variant of BinaryElementwise() for kernels whose input and output dtypes differ (e.g. Greater/Less). Validates that both inputs have in_dtype and that the preallocated output has out_dtype and the broadcasted shape, then invokes op(a, b) -> TOut for each element pair with full multidirectional broadcasting.
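For kernels such as Greater, the main difference from the same-dtype case is the op's return type. The sketch below is hypothetical (InOutFastPath is not a library function): it covers only the contiguous fast paths that nx/ny in BroadcastInfo are described as enabling, and it stores BOOL results as uint8_t, assuming one byte per boolean element.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical InOut-style kernel: op maps two TIn values to a TOut,
// e.g. Greater: float x float -> bool stored as uint8_t.
// Only the fast paths are handled. Each input's element count is compared
// against the output element count, because equal input counts alone are not
// enough: shapes {2, 1} and {1, 2} each hold 2 elements but broadcast to 4.
template <typename TIn, typename TOut, typename Op>
std::vector<TOut> InOutFastPath(const std::vector<TIn>& x,
                                const std::vector<TIn>& y,
                                size_t element_count, Op op) {
  std::vector<TOut> out(element_count);
  if (x.size() == element_count && y.size() == element_count) {
    for (size_t i = 0; i < element_count; ++i) out[i] = op(x[i], y[i]);
  } else if (x.size() == 1 && y.size() == element_count) {
    for (size_t i = 0; i < element_count; ++i) out[i] = op(x[0], y[i]);
  } else if (y.size() == 1 && x.size() == element_count) {
    for (size_t i = 0; i < element_count; ++i) out[i] = op(x[i], y[0]);
  } else {
    throw std::invalid_argument("general case needs the strided broadcast loop");
  }
  return out;
}
```

Here TOut must be given explicitly, e.g. InOutFastPath<float, uint8_t>(x, y, n, [](float a, float b) { return a > b; }), since it cannot be deduced from the arguments.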

template<typename TIn, typename TOut, typename Op>
Tensor BinaryElementwiseAllocInOut(const char *op_name, const char *in_dtype_name, int32_t in_dtype, const char *out_dtype_name, int32_t out_dtype, const Tensor &x, const Tensor &y, Op op)

Allocating variant of BinaryElementwiseInOut(). Builds the output tensor with the broadcasted shape and out_dtype, then delegates to BinaryElementwiseInOut() to fill it in.

struct BroadcastInfo
#include <elementwise_helpers.h>

Information about a validated binary broadcast: the output shape, total element count, the individual input element counts, and per-input element-strides aligned to the output rank (a stride of 0 marks a broadcast dimension). The rank-aligned shape_x/shape_y are also reported for diagnostics. nx/ny are kept for fast-path detection (equal-shape or scalar broadcasting).
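The rank-aligned shapes and strides can be derived from an input shape and the output rank. A minimal sketch, assuming row-major layout; AlignedInput and AlignToOutput are hypothetical names. A dimension of size 1 gets stride 0, so the same element is reused along that output axis.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical result: an input shape left-padded to the output rank, plus
// element strides aligned to the output dimensions.
struct AlignedInput {
  std::vector<int64_t> shape;
  std::vector<int64_t> strides;
};

// Precondition: out_rank >= in_shape.size() (guaranteed after broadcasting).
AlignedInput AlignToOutput(const std::vector<int64_t>& in_shape, size_t out_rank) {
  AlignedInput a;
  // Left-pad with 1s so the input's trailing dims line up with the output's.
  a.shape.assign(out_rank - in_shape.size(), 1);
  a.shape.insert(a.shape.end(), in_shape.begin(), in_shape.end());
  a.strides.assign(out_rank, 0);
  int64_t stride = 1;  // row-major: the innermost dimension is contiguous
  for (size_t d = out_rank; d-- > 0;) {
    // A size-1 dimension is repeated across the output, so it never advances.
    a.strides[d] = (a.shape[d] == 1) ? 0 : stride;
    stride *= a.shape[d];
  }
  return a;
}
```

For example, an input of shape {3} aligned to rank 2 becomes shape {1, 3} with strides {0, 1}, and {4, 1, 5} keeps its rank with strides {5, 0, 1}.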

Public Members

std::vector<int64_t> shape
std::vector<int64_t> shape_x
std::vector<int64_t> shape_y
std::vector<int64_t> strides_x
std::vector<int64_t> strides_y
int64_t element_count = 0
int64_t nx = 0
int64_t ny = 0