implementation.h#

This header exposes the core C++ shape inference implementation, including context objects consumed by schema inferencers and helper APIs such as InferShapes and InferFunctionOutputTypes.

namespace ONNX_LIGHT_NAMESPACE
namespace shape_inference#

Typedefs

using ModelLocalFunctionsMap = std::unordered_map<std::string, const FunctionProto*>#
using DataValueMap = std::unordered_map<std::string, TensorShapeProto>#

Functions

std::string GetValueCaseString(const TypeProto &type)#
void checkShapesAndTypes(const TypeProto_Sequence &inferredType, const TypeProto_Sequence &existingType)#
void checkShapesAndTypes(const TypeProto &inferred_type, const TypeProto &existing_type)#
template<typename TensorTypeProto>
void GenerateSymbolicShape(TensorTypeProto *inferred_type, SymbolTable &symbol_table)#
void MaterializeSymbolicShape(TypeProto *inferred_type, SymbolTable &symbol_table)#
void mergeShapesAndTypes(const TypeProto_Tensor &inferred_type, TypeProto_Tensor *existing_type)#
void mergeShapesAndTypes(const TypeProto_SparseTensor &inferred_type, TypeProto_SparseTensor *existing_type)#
void mergeShapesAndTypes(const TypeProto_Sequence &inferredType, TypeProto_Tensor *existingType)#
void mergeShapesAndTypes(const TypeProto &inferred_type, TypeProto *existing_type)#
void InferShapes(GraphProto *g, const std::unordered_map<std::string, int> &opset_imports, const ISchemaRegistry *schema_registry = OpSchemaRegistry::Instance(), const ShapeInferenceOptions &options = ShapeInferenceOptions(), const ModelLocalFunctionsMap &in_model_functions = {})#

ModelLocalFunctionsMap is a map of function id -> model-local function proto. All the ONNX helper utilities expect the function id == <function_proto.domain>:<function_proto.name>

void InferShapes(ModelProto &m, const ISchemaRegistry *schema_registry = OpSchemaRegistry::Instance(), const ShapeInferenceOptions &options = ShapeInferenceOptions(), DataValueMap *generated_shape_data_by_name = nullptr)#
void InferShapes(const std::string &model_path, const std::string &save_path = "", const ISchemaRegistry *schema_registry = OpSchemaRegistry::Instance(), const ShapeInferenceOptions &options = ShapeInferenceOptions(), DataValueMap *generated_shape_data_by_name = nullptr)#
void InferShapeForFunctionNode(const FunctionProto &func, const ISchemaRegistry *schema_registry, InferenceContext &ctx, const ShapeInferenceOptions &options = ShapeInferenceOptions(), const ModelLocalFunctionsMap &model_local_functions_map = {}, SymbolTable *symbol_table = nullptr, DataValueMap *generated_shape_data_by_name = nullptr)#

ModelLocalFunctionsMap is a map of function id -> model-local function proto. All the ONNX helper utilities expect the function id == <function_proto.domain>:<function_proto.name>

void InferShapeForFunctionNode(const FunctionProto &func_proto, const std::unordered_map<std::string, int> &func_opset_imports, const ISchemaRegistry *schema_registry, InferenceContext &ctx, const ShapeInferenceOptions &options = ShapeInferenceOptions(), const ModelLocalFunctionsMap &model_local_functions_map = {}, SymbolTable *symbol_table = nullptr, DataValueMap *generated_shape_data_by_name = nullptr)#

ModelLocalFunctionsMap is a map of function id -> model-local function proto. All the ONNX helper utilities expect the function id == <function_proto.domain>:<function_proto.name>

std::vector<TypeProto> InferFunctionOutputTypes(const FunctionProto &function_proto, const std::vector<TypeProto> &input_types, const std::vector<AttributeProto> &attributes)#

Apply type-and-shape-inference based checks to a Function body. Returns the inferred types of the outputs of the function. Inference depends on the types of the inputs of the function as well as the attribute values supplied. A TypeProto with value_case() == TypeProto::ValueCase::VALUE_NOT_SET is used for missing optional parameters.

std::string GetErrorWithNodeInfo(const NodeProto &n, const std::runtime_error &err)#
void TraverseGraphsToAddExistingSymbols(const GraphProto &g, SymbolTable &symbol_table)#
struct DataPropagationContextImpl : public ONNX_LIGHT_NAMESPACE::DataPropagationContext#
#include <implementation.h>

Public Functions

inline DataPropagationContextImpl(NodeProto &n, const std::unordered_map<std::string, TypeProto*> &valueTypesByName, const std::unordered_map<std::string, const TensorProto*> &inputDataByName, DataValueMap &generatedShapeData)#
inline virtual const AttributeProto *getAttribute(const std::string &name) const override#
inline virtual size_t getNumInputs() const override#
inline virtual const TypeProto *getInputType(size_t index) const override#
inline virtual size_t getNumOutputs() const override#
inline virtual const TypeProto *getOutputType(size_t index) const override#
template<typename INTEGER>
inline void vectorToTensorShapeProto(const std::vector<INTEGER> &input_vals, TensorShapeProto &converted_tsp) const#
inline virtual const TensorShapeProto *getInputData(size_t index) override#
inline virtual void addOutputData(size_t index, TensorShapeProto &&tsp) override#

Public Members

std::vector<const TensorProto*> allInputData_#
std::unordered_map<size_t, std::string> inputIndexToNameMap_#
std::unordered_map<size_t, std::string> outputIndexToNameMap_#
std::vector<const TypeProto*> allInputTypes_#
std::vector<TypeProto> allOutputTypes_#
DataValueMap &generatedShapeData_#
std::unordered_map<std::string, const AttributeProto*> attributesByName_#
struct GraphInferenceContext#
#include <implementation.h>

Public Functions

inline GraphInferenceContext(const std::unordered_map<std::string, TypeProto*> &outer_scope_value_types_by_name_in, std::unordered_map<std::string, int> opset_imports_in, SymbolTable *symbol_table_in = nullptr, const ModelLocalFunctionsMap &model_local_functions_in = {}, const ISchemaRegistry *schema_registry_in = OpSchemaRegistry::Instance(), DataValueMap *generated_shape_data_by_name_in = nullptr, const int ir_version_in = IR_VERSION)#

Public Members

const std::unordered_map<std::string, TypeProto*> *outer_scope_value_types_by_name#
const std::unordered_map<std::string, int> opset_imports#
SymbolTable *symbol_table#
const ModelLocalFunctionsMap &model_local_functions#
const ISchemaRegistry *schema_registry#
DataValueMap *generated_shape_data_by_name#
const int ir_version#
class GraphInferencerImpl : public ONNX_LIGHT_NAMESPACE::GraphInferencer#
#include <implementation.h>

Public Functions

inline GraphInferencerImpl(GraphProto &g, GraphInferenceContext &context)#
inline GraphInferencerImpl(GraphProto &g, GraphInferenceContext &context, const ShapeInferenceOptions &options)#
virtual std::vector<const TypeProto*> doInferencing(const std::vector<const TypeProto*> &input_types, const std::vector<const TensorProto*> &input_data) override#

Private Members

GraphProto *g_#
GraphInferenceContext *context_#
ShapeInferenceOptions options_#
struct InferenceContextImpl : public ONNX_LIGHT_NAMESPACE::InferenceContext#
#include <implementation.h>

Public Functions

inline InferenceContextImpl(NodeProto &n, const std::unordered_map<std::string, TypeProto*> &valueTypesByName, const std::unordered_map<std::string, const TensorProto*> &inputDataByName, const std::unordered_map<std::string, const SparseTensorProto*> &inputSparseDataByName, const ShapeInferenceOptions &options, DataValueMap *generatedShapeData = nullptr, GraphInferenceContext *graphInferenceContext = nullptr)#
inline virtual const AttributeProto *getAttribute(const std::string &name) const override#
inline virtual size_t getNumInputs() const override#
inline virtual const TypeProto *getInputType(size_t index) const override#
inline virtual const TensorProto *getInputData(size_t index) const override#
inline virtual const TensorShapeProto *getSymbolicInput(size_t index) const override#
inline virtual const SparseTensorProto *getInputSparseData(size_t index) const override#
inline virtual size_t getNumOutputs() const override#
inline virtual TypeProto *getOutputType(size_t index) override#
inline virtual GraphInferencer *getGraphAttributeInferencer(const std::string &attr_name) override#
inline virtual std::string getDisplayName() const override#

Public Members

std::vector<const TensorProto*> allInputData_#
std::vector<const SparseTensorProto*> allInputSparseData_#
std::vector<const TensorShapeProto*> allShapeInputData_#
std::unordered_map<std::string, const AttributeProto*> attributesByName_#
std::unordered_map<std::string, GraphProto*> graphProtoAttributesByName_#
std::vector<const TypeProto*> allInputTypes_#
std::vector<TypeProto> allOutputTypes_#
GraphInferenceContext *graphInferenceContext_#
mutable std::unordered_map<std::string, std::unique_ptr<GraphInferencer>> graphAttributeInferencers_#
ShapeInferenceOptions options_#
NodeProto *node_#
class SymbolTableImpl : public ONNX_LIGHT_NAMESPACE::SymbolTable#
#include <implementation.h>

Public Functions

SymbolTableImpl() = default#
inline virtual void addFromGraph(const GraphProto &g) override#
inline virtual std::string createNew(const std::string &symbol_prefix) override#

Private Functions

template<typename TensorTypeProto>
inline void AddExistingSymbolicDims(const TensorTypeProto &tensorType)#
inline void AddExistingSymbolicDims(const TypeProto &typeProto)#
inline void AddExistingSymbolicDims(const utils::RepeatedProtoField<ValueInfoProto> &protos)#

Private Members

unsigned int index_ = {0}#
std::unordered_set<std::string> existing_symbols#