shape_inference.h

Core declarations for operator type-and-shape inference, all in namespace ONNX_LIGHT_NAMESPACE. They include ShapeInferenceOptions, the InferenceContext interface and helper routines such as propagateElemTypeFromInputToOutput(). Each is listed below.

Defines

fail_type_inference(...)#
fail_shape_inference(...)#
namespace ONNX_LIGHT_NAMESPACE

Typedefs

using Dim = TensorShapeProto::Dimension#
using InferenceFunction = std::function<void(InferenceContext&)>#
using DataPropagationFunction = std::function<void(DataPropagationContext&)>#

Functions

inline void dummyInferenceFunction(InferenceContext&)#
inline void dummyDataPropagationFunction(DataPropagationContext&)#
template<typename T>
inline bool getRepeatedAttribute(InferenceContext &ctx, const std::string &attr_name, std::vector<T> &values)#
inline const AttributeProto &getRequiredAttribute(const InferenceContext &ctx, const std::string &name)#
inline int64_t getRequiredAttributeInt(const InferenceContext &ctx, const std::string &name)#
inline int64_t getAttribute(const InferenceContext &ctx, const std::string &attributeName, int64_t defaultValue)#
inline int64_t getAttribute(const DataPropagationContext &ctx, const std::string &attributeName, int64_t defaultValue)#
inline std::string getAttribute(const InferenceContext &ctx, const std::string &attributeName, const std::string &defaultValue)#
inline TensorShapeProto::Dimension operator*(const TensorShapeProto::Dimension &dim1, const TensorShapeProto::Dimension &dim2)#
template<typename Container>
std::string stringify(const Container &elements)#
std::pair<int, int> getAttributeProtoElemTypeAndLength(const AttributeProto *attr_proto)#
std::pair<int, int> getAttributeElementTypeAndLength(const InferenceContext &ctx, const std::initializer_list<std::string> &attribute_names)#
inline TensorShapeProto::Dimension operator*(const TensorShapeProto::Dimension &dim1, int64_t dim2)#
inline TensorShapeProto::Dimension operator/(const TensorShapeProto::Dimension &dim1, int64_t dim2)#
inline TensorShapeProto::Dimension multiplyDims(const TensorShapeProto &shape, int from, int upto_exclusive)#
inline int32_t getTensorElementType(const TypeProto &type)#
inline void setTensorElementType(int32_t elem_type, TypeProto::ValueCase value_case, TypeProto &type)#
void propagateElemTypeWithValidation(const TypeProto *input_type, TypeProto *output_type)#
void propagateElemTypeFromInputToOutput(InferenceContext &ctx, size_t inputIndex, size_t outputIndex)#
void propagateElemTypeFromTensorInputToOutput(InferenceContext &ctx, size_t inputIndex, size_t outputIndex)#
inline void propagateElemTypeFromDtypeToOutput(InferenceContext &ctx, const int data_type, size_t outputIndex, TypeProto::ValueCase expected_value_case)#
inline void propagateElemTypeFromDtypeToOutput(InferenceContext &ctx, const int data_type, size_t outputIndex)#
inline void propagateElemTypeFromDtypeToOutput(InferenceContext &ctx, const AttributeProto *attr, size_t outputIndex)#
inline bool hasShape(const TypeProto &type)#
template<typename Context>
inline bool hasInputShape(const Context &ctx, size_t n)#
template<typename Context>
inline bool hasNInputShapes(const Context &ctx, size_t n)#
inline const TensorShapeProto &getInputShape(const InferenceContext &ctx, size_t n)#
inline const TensorShapeProto *getOptionalInputShape(const InferenceContext &ctx, size_t n)#
inline void appendSingleDimCopiedFromInputTypeToOutputType(InferenceContext &ctx, size_t inputIndex, size_t outputIndex, size_t fromDimIndex)#
inline void propagateShape(const TypeProto *from_type, TypeProto *to_type)#
inline void propagateShapeFromInputToOutput(InferenceContext &ctx, size_t inputIndex, size_t outputIndex)#
inline void propagateShapeAndTypeFromFirstInput(InferenceContext &ctx)#
inline void updateOutputElemType(InferenceContext &ctx, size_t outputIndex, int32_t elemType, TypeProto::ValueCase expected_type)#
inline void updateOutputElemType(InferenceContext &ctx, size_t outputIndex, int32_t elemType)#
inline void propagateElemTypeFromAttributeToOutput(InferenceContext &ctx, const std::string &attributeName, size_t outputIndex, TypeProto::ValueCase expected_type, TensorProto::DataType default_value = TensorProto::UNDEFINED)#
inline void propagateElemTypeFromAttributeToOutput(InferenceContext &ctx, const std::string &attributeName, size_t outputIndex, TensorProto::DataType default_value = TensorProto::UNDEFINED)#
inline TensorShapeProto *getTensorMutableShape(TypeProto::ValueCase value_case, TypeProto &type)#
inline TensorShapeProto *getOutputShape(InferenceContext &ctx, size_t n, TypeProto::ValueCase default_type = TypeProto::kTensorType)#
inline void appendDim(TensorShapeProto *shape, int64_t dim_value)#
inline void updateOutputShape(InferenceContext &ctx, size_t outputIndex, const TensorShapeProto &shape, TypeProto::ValueCase default_type = TypeProto::kTensorType)#
inline void updateOutputShape(InferenceContext &ctx, size_t outputIndex, const TensorProto &tensorProto, TypeProto::ValueCase default_type = TypeProto::kTensorType)#
inline void updateOutputShape(InferenceContext &ctx, size_t outputIndex, std::initializer_list<TensorShapeProto::Dimension> dims, TypeProto::ValueCase default_type = TypeProto::kTensorType)#
TensorShapeProto getShapeInput(const InferenceContext &ctx, size_t input_index, bool &found)#
TensorShapeProto getShapeInput(const InferenceContext &ctx, size_t input_index, bool fail_if_negative_value, bool &found)#
inline void propagateShapeFromAttributeToOutput(InferenceContext &ctx, const std::string &attributeName, size_t outputIndex, TypeProto::ValueCase default_type = TypeProto::kTensorType)#
inline void multidirectionalBroadcastShapeInference(const std::vector<const TensorShapeProto*> &shapes, TensorShapeProto &resultShape)#
inline void bidirectionalBroadcastShapeInference(const TensorShapeProto &shapeL, const TensorShapeProto &shapeR, TensorShapeProto &resultShape)#
inline void mergeInDimensionInfo(const TensorShapeProto::Dimension &source_dim, TensorShapeProto::Dimension &target_dim, int dim_index)#
void mergeInShapeInfo(const TensorShapeProto &source_shape, TypeProto::Tensor &target_type)#
void mergeInShapeInfo(const TensorShapeProto &source_shape, TypeProto::SparseTensor &target_type)#
void mergeInShapeInfo(const TypeProto::Tensor &source, TypeProto::Tensor &target)#
void mergeInShapeInfo(const TypeProto::SparseTensor &source, TypeProto::SparseTensor &target)#
inline TypeProto RemoveIthDimensionFromShape(const TypeProto &proto, int removed_dim)#
inline TypeProto RemoveDimensionsFromShape(const TypeProto &proto, int num_dimensions)#
template<class T, class U>
static constexpr T narrow_cast(U &&u) noexcept#
inline void checkInputRank(const InferenceContext &ctx, size_t input_index, int expected_rank)#
inline void checkDimEquality(int64_t value1, int64_t value2)#
inline void unifyDim(const Dim &dim1, const Dim &dim2)#
inline void unifyDim(const Dim &source_dim, Dim &target_dim)#
inline void unifyInputDim(const InferenceContext &ctx, size_t input_index, int dim_index, Dim &dim)#
inline void unifyDim(Dim &dim, int64_t value)#
void UnionShapeInfo(const TensorShapeProto &source_shape, TypeProto::Tensor &target_type)#
void UnionShapeInfo(const TensorShapeProto &source_shape, TypeProto::SparseTensor &target_type)#
void UnionTypeInfo(const TypeProto &source_type, TypeProto &target_type)#
template<typename Axes>
void adjustNegativeAxes(Axes &axes, int rank)#
template<typename Axes>
void checkAxesRange(Axes &axes, int rank)#
template<typename Axes>
void checkDuplicateAxes(Axes &axes, int rank)#
void RNNShapeInference(InferenceContext &ctx)#
void convPoolShapeInference(InferenceContext &ctx, bool use_dilation, bool require_kernel_shape, int input1Idx, int input2Idx)#
void convTransposeShapeInference(InferenceContext &ctx)#
void globalPoolTypeShapeInference(InferenceContext &ctx)#
struct DataPropagationContext#
#include <shape_inference.h>

Subclassed by ONNX_LIGHT_NAMESPACE::shape_inference::DataPropagationContextImpl

Public Functions

virtual const AttributeProto *getAttribute(const std::string &name) const = 0#
virtual size_t getNumInputs() const = 0#
virtual const TypeProto *getInputType(size_t index) const = 0#
virtual size_t getNumOutputs() const = 0#
virtual const TypeProto *getOutputType(size_t index) const = 0#
virtual ~DataPropagationContext() = default#
virtual const TensorShapeProto *getInputData(size_t index) = 0#
virtual void addOutputData(size_t index, TensorShapeProto &&tp) = 0#
class GraphInferencer#
#include <shape_inference.h>

Subclassed by ONNX_LIGHT_NAMESPACE::shape_inference::GraphInferencerImpl

Public Functions

virtual std::vector<const TypeProto*> doInferencing(const std::vector<const TypeProto*> &inputTypes, const std::vector<const TensorProto*> &inputData) = 0#
virtual ~GraphInferencer() = default#
struct InferenceContext#
#include <shape_inference.h>

Subclassed by ONNX_LIGHT_NAMESPACE::shape_inference::InferenceContextImpl

Public Functions

virtual const AttributeProto *getAttribute(const std::string &name) const = 0#
virtual size_t getNumInputs() const = 0#
virtual const TypeProto *getInputType(size_t index) const = 0#
inline virtual bool hasInput(size_t index) const#
virtual const TensorProto *getInputData(size_t index) const = 0#
virtual size_t getNumOutputs() const = 0#
virtual TypeProto *getOutputType(size_t index) = 0#
inline virtual bool hasOutput(size_t index)#
virtual GraphInferencer *getGraphAttributeInferencer(const std::string &attribute_name) = 0#
virtual ~InferenceContext() = default#
virtual const SparseTensorProto *getInputSparseData(size_t index) const = 0#
virtual const TensorShapeProto *getSymbolicInput(size_t index) const = 0#
inline virtual std::string getDisplayName() const#
class InferenceError : public std::runtime_error#
#include <shape_inference.h>

Public Functions

inline explicit InferenceError(const std::string &message)#
inline const char *what() const noexcept override#
inline void AppendContext(const std::string &context)#

Private Members

std::string expanded_message_#
struct ShapeInferenceOptions#
#include <shape_inference.h>

Public Functions

inline explicit ShapeInferenceOptions(bool check_type_val = false, int strict_mode_val = 0, bool data_prop_val = false)#

Public Members

bool check_type#
int error_mode#
bool enable_data_propagation#
class SymbolTable#
#include <shape_inference.h>

Subclassed by ONNX_LIGHT_NAMESPACE::shape_inference::SymbolTableImpl

Public Functions

virtual void addFromGraph(const GraphProto &g) = 0#
inline std::string createNew()#
virtual std::string createNew(const std::string &symbol_prefix) = 0#
virtual ~SymbolTable() = default#