shape_inference.h#
Core declarations for operator type-and-shape inference, including
ONNX_LIGHT_NAMESPACE::ShapeInferenceOptions,
ONNX_LIGHT_NAMESPACE::InferenceContext, and helper routines such as
ONNX_LIGHT_NAMESPACE::propagateElemTypeFromInputToOutput().
-
namespace ONNX_LIGHT_NAMESPACE
Typedefs
-
using Dim = TensorShapeProto::Dimension#
-
using InferenceFunction = std::function<void(InferenceContext&)>#
-
using DataPropagationFunction = std::function<void(DataPropagationContext&)>#
Functions
-
inline void dummyInferenceFunction(InferenceContext&)#
-
inline void dummyDataPropagationFunction(DataPropagationContext&)#
-
template<typename T>
inline bool getRepeatedAttribute(InferenceContext &ctx, const std::string &attr_name, std::vector<T> &values)#
-
inline const AttributeProto &getRequiredAttribute(const InferenceContext &ctx, const std::string &name)#
-
inline int64_t getRequiredAttributeInt(const InferenceContext &ctx, const std::string &name)#
-
inline int64_t getAttribute(const InferenceContext &ctx, const std::string &attributeName, int64_t defaultValue)#
-
inline int64_t getAttribute(const DataPropagationContext &ctx, const std::string &attributeName, int64_t defaultValue)#
-
inline std::string getAttribute(const InferenceContext &ctx, const std::string &attributeName, const std::string &defaultValue)#
-
inline TensorShapeProto::Dimension operator*(const TensorShapeProto::Dimension &dim1, const TensorShapeProto::Dimension &dim2)#
-
std::pair<int, int> getAttributeProtoElemTypeAndLength(const AttributeProto *attr_proto)#
-
std::pair<int, int> getAttributeElementTypeAndLength(const InferenceContext &ctx, const std::initializer_list<std::string> &attribute_names)#
-
inline TensorShapeProto::Dimension operator*(const TensorShapeProto::Dimension &dim1, int64_t dim2)#
-
inline TensorShapeProto::Dimension operator/(const TensorShapeProto::Dimension &dim1, int64_t dim2)#
-
inline TensorShapeProto::Dimension multiplyDims(const TensorShapeProto &shape, int from, int upto_exclusive)#
-
inline void setTensorElementType(int32_t elem_type, TypeProto::ValueCase value_case, TypeProto &type)#
-
void propagateElemTypeFromInputToOutput(InferenceContext &ctx, size_t inputIndex, size_t outputIndex)#
-
void propagateElemTypeFromTensorInputToOutput(InferenceContext &ctx, size_t inputIndex, size_t outputIndex)#
-
inline void propagateElemTypeFromDtypeToOutput(InferenceContext &ctx, const int data_type, size_t outputIndex, TypeProto::ValueCase expected_value_case)#
-
inline void propagateElemTypeFromDtypeToOutput(InferenceContext &ctx, const int data_type, size_t outputIndex)#
-
inline void propagateElemTypeFromDtypeToOutput(InferenceContext &ctx, const AttributeProto *attr, size_t outputIndex)#
-
inline const TensorShapeProto &getInputShape(const InferenceContext &ctx, size_t n)#
-
inline const TensorShapeProto *getOptionalInputShape(const InferenceContext &ctx, size_t n)#
-
inline void appendSingleDimCopiedFromInputTypeToOutputType(InferenceContext &ctx, size_t inputIndex, size_t outputIndex, size_t fromDimIndex)#
-
inline void propagateShapeFromInputToOutput(InferenceContext &ctx, size_t inputIndex, size_t outputIndex)#
-
inline void propagateShapeAndTypeFromFirstInput(InferenceContext &ctx)#
-
inline void updateOutputElemType(InferenceContext &ctx, size_t outputIndex, int32_t elemType, TypeProto::ValueCase expected_type)#
-
inline void updateOutputElemType(InferenceContext &ctx, size_t outputIndex, int32_t elemType)#
-
inline void propagateElemTypeFromAttributeToOutput(InferenceContext &ctx, const std::string &attributeName, size_t outputIndex, TypeProto::ValueCase expected_type, TensorProto::DataType default_value = TensorProto::UNDEFINED)#
-
inline void propagateElemTypeFromAttributeToOutput(InferenceContext &ctx, const std::string &attributeName, size_t outputIndex, TensorProto::DataType default_value = TensorProto::UNDEFINED)#
-
inline TensorShapeProto *getTensorMutableShape(TypeProto::ValueCase value_case, TypeProto &type)#
-
inline TensorShapeProto *getOutputShape(InferenceContext &ctx, size_t n, TypeProto::ValueCase default_type = TypeProto::kTensorType)#
-
inline void appendDim(TensorShapeProto *shape, int64_t dim_value)#
-
inline void updateOutputShape(InferenceContext &ctx, size_t outputIndex, const TensorShapeProto &shape, TypeProto::ValueCase default_type = TypeProto::kTensorType)#
-
inline void updateOutputShape(InferenceContext &ctx, size_t outputIndex, const TensorProto &tensorProto, TypeProto::ValueCase default_type = TypeProto::kTensorType)#
-
inline void updateOutputShape(InferenceContext &ctx, size_t outputIndex, std::initializer_list<TensorShapeProto::Dimension> dims, TypeProto::ValueCase default_type = TypeProto::kTensorType)#
-
TensorShapeProto getShapeInput(const InferenceContext &ctx, size_t input_index, bool &found)#
-
TensorShapeProto getShapeInput(const InferenceContext &ctx, size_t input_index, bool fail_if_negative_value, bool &found)#
-
inline void propagateShapeFromAttributeToOutput(InferenceContext &ctx, const std::string &attributeName, size_t outputIndex, TypeProto::ValueCase default_type = TypeProto::kTensorType)#
-
inline void multidirectionalBroadcastShapeInference(const std::vector<const TensorShapeProto*> &shapes, TensorShapeProto &resultShape)#
-
inline void bidirectionalBroadcastShapeInference(const TensorShapeProto &shapeL, const TensorShapeProto &shapeR, TensorShapeProto &resultShape)#
-
inline void mergeInDimensionInfo(const TensorShapeProto::Dimension &source_dim, TensorShapeProto::Dimension &target_dim, int dim_index)#
-
void mergeInShapeInfo(const TensorShapeProto &source_shape, TypeProto::Tensor &target_type)#
-
void mergeInShapeInfo(const TensorShapeProto &source_shape, TypeProto::SparseTensor &target_type)#
-
void mergeInShapeInfo(const TypeProto::SparseTensor &source, TypeProto::SparseTensor &target)#
-
inline void checkInputRank(const InferenceContext &ctx, size_t input_index, int expected_rank)#
-
inline void checkDimEquality(int64_t value1, int64_t value2)#
-
inline void unifyInputDim(const InferenceContext &ctx, size_t input_index, int dim_index, Dim &dim)#
-
void UnionShapeInfo(const TensorShapeProto &source_shape, TypeProto::Tensor &target_type)#
-
void UnionShapeInfo(const TensorShapeProto &source_shape, TypeProto::SparseTensor &target_type)#
-
void RNNShapeInference(InferenceContext &ctx)#
-
void convPoolShapeInference(InferenceContext &ctx, bool use_dilation, bool require_kernel_shape, int input1Idx, int input2Idx)#
-
void convTransposeShapeInference(InferenceContext &ctx)#
-
void globalPoolTypeShapeInference(InferenceContext &ctx)#
-
struct DataPropagationContext#
- #include <shape_inference.h>
Subclassed by ONNX_LIGHT_NAMESPACE::shape_inference::DataPropagationContextImpl
Public Functions
-
virtual const AttributeProto *getAttribute(const std::string &name) const = 0#
-
virtual size_t getNumInputs() const = 0#
-
virtual size_t getNumOutputs() const = 0#
-
virtual ~DataPropagationContext() = default#
-
virtual const TensorShapeProto *getInputData(size_t index) = 0#
-
virtual void addOutputData(size_t index, TensorShapeProto &&tp) = 0#
-
class GraphInferencer#
- #include <shape_inference.h>
Subclassed by ONNX_LIGHT_NAMESPACE::shape_inference::GraphInferencerImpl
-
struct InferenceContext#
- #include <shape_inference.h>
Subclassed by ONNX_LIGHT_NAMESPACE::shape_inference::InferenceContextImpl
Public Functions
-
virtual const AttributeProto *getAttribute(const std::string &name) const = 0#
-
virtual size_t getNumInputs() const = 0#
-
inline virtual bool hasInput(size_t index) const#
-
virtual const TensorProto *getInputData(size_t index) const = 0#
-
virtual size_t getNumOutputs() const = 0#
-
inline virtual bool hasOutput(size_t index)#
-
virtual GraphInferencer *getGraphAttributeInferencer(const std::string &attribute_name) = 0#
-
virtual ~InferenceContext() = default#
-
virtual const SparseTensorProto *getInputSparseData(size_t index) const = 0#
-
virtual const TensorShapeProto *getSymbolicInput(size_t index) const = 0#
-
class InferenceError : public std::runtime_error#
- #include <shape_inference.h>
Public Functions
-
inline const char *what() const noexcept override#
-
struct ShapeInferenceOptions#
- #include <shape_inference.h>
Public Functions
-
inline explicit ShapeInferenceOptions(bool check_type_val = false, int strict_mode_val = 0, bool data_prop_val = false)#
-
class SymbolTable#
- #include <shape_inference.h>
Subclassed by ONNX_LIGHT_NAMESPACE::shape_inference::SymbolTableImpl
-
using Dim = TensorShapeProto::Dimension#