implementation.h#
This header exposes the core C++ shape inference implementation, including
context objects consumed by schema inferencers and helper APIs such as
InferShapes and InferFunctionOutputTypes.
-
namespace ONNX_LIGHT_NAMESPACE
-
namespace shape_inference#
Typedefs
-
using ModelLocalFunctionsMap = std::unordered_map<std::string, const FunctionProto*>#
-
using DataValueMap = std::unordered_map<std::string, TensorShapeProto>#
Functions
-
void checkShapesAndTypes(const TypeProto_Sequence &inferredType, const TypeProto_Sequence &existingType)#
-
template<typename TensorTypeProto>
void GenerateSymbolicShape(TensorTypeProto *inferred_type, SymbolTable &symbol_table)#
-
void MaterializeSymbolicShape(TypeProto *inferred_type, SymbolTable &symbol_table)#
-
void mergeShapesAndTypes(const TypeProto_Tensor &inferred_type, TypeProto_Tensor *existing_type)#
-
void mergeShapesAndTypes(const TypeProto_SparseTensor &inferred_type, TypeProto_SparseTensor *existing_type)#
-
void mergeShapesAndTypes(const TypeProto_Sequence &inferredType, TypeProto_Tensor *existingType)#
-
void InferShapes(GraphProto *g, const std::unordered_map<std::string, int> &opset_imports, const ISchemaRegistry *schema_registry = OpSchemaRegistry::Instance(), const ShapeInferenceOptions &options = ShapeInferenceOptions(), const ModelLocalFunctionsMap &in_model_functions = {})#
ModelLocalFunctionsMap maps each function id to its model-local FunctionProto. All ONNX helper utilities expect the function id to have the form <function_proto.domain>:<function_proto.name>.
-
void InferShapes(ModelProto &m, const ISchemaRegistry *schema_registry = OpSchemaRegistry::Instance(), const ShapeInferenceOptions &options = ShapeInferenceOptions(), DataValueMap *generated_shape_data_by_name = nullptr)#
-
void InferShapes(const std::string &model_path, const std::string &save_path = "", const ISchemaRegistry *schema_registry = OpSchemaRegistry::Instance(), const ShapeInferenceOptions &options = ShapeInferenceOptions(), DataValueMap *generated_shape_data_by_name = nullptr)#
-
void InferShapeForFunctionNode(const FunctionProto &func, const ISchemaRegistry *schema_registry, InferenceContext &ctx, const ShapeInferenceOptions &options = ShapeInferenceOptions(), const ModelLocalFunctionsMap &model_local_functions_map = {}, SymbolTable *symbol_table = nullptr, DataValueMap *generated_shape_data_by_name = nullptr)#
ModelLocalFunctionsMap maps each function id to its model-local FunctionProto. All ONNX helper utilities expect the function id to have the form <function_proto.domain>:<function_proto.name>.
-
void InferShapeForFunctionNode(const FunctionProto &func_proto, const std::unordered_map<std::string, int> &func_opset_imports, const ISchemaRegistry *schema_registry, InferenceContext &ctx, const ShapeInferenceOptions &options = ShapeInferenceOptions(), const ModelLocalFunctionsMap &model_local_functions_map = {}, SymbolTable *symbol_table = nullptr, DataValueMap *generated_shape_data_by_name = nullptr)#
ModelLocalFunctionsMap maps each function id to its model-local FunctionProto. All ONNX helper utilities expect the function id to have the form <function_proto.domain>:<function_proto.name>.
-
std::vector<TypeProto> InferFunctionOutputTypes(const FunctionProto &function_proto, const std::vector<TypeProto> &input_types, const std::vector<AttributeProto> &attributes)#
Apply type-and-shape-inference based checks to a Function body. Returns the inferred types of the outputs of the function. Inference depends on the types of the inputs of the function as well as the attribute values supplied. A TypeProto with value_case() == TypeProto::ValueCase::VALUE_NOT_SET is used for missing optional parameters.
-
void TraverseGraphsToAddExistingSymbols(const GraphProto &g, SymbolTable &symbol_table)#
-
struct DataPropagationContextImpl : public ONNX_LIGHT_NAMESPACE::DataPropagationContext#
- #include <implementation.h>
Public Functions
-
inline DataPropagationContextImpl(NodeProto &n, const std::unordered_map<std::string, TypeProto*> &valueTypesByName, const std::unordered_map<std::string, const TensorProto*> &inputDataByName, DataValueMap &generatedShapeData)#
-
inline virtual const AttributeProto *getAttribute(const std::string &name) const override#
-
inline virtual size_t getNumInputs() const override#
-
inline virtual size_t getNumOutputs() const override#
-
template<typename INTEGER>
inline void vectorToTensorShapeProto(const std::vector<INTEGER> &input_vals, TensorShapeProto &converted_tsp) const#
-
inline virtual const TensorShapeProto *getInputData(size_t index) override#
-
inline virtual void addOutputData(size_t index, TensorShapeProto &&tsp) override#
Public Members
-
std::vector<const TensorProto*> allInputData_#
-
DataValueMap &generatedShapeData_#
-
std::unordered_map<std::string, const AttributeProto*> attributesByName_#
-
inline DataPropagationContextImpl(NodeProto &n, const std::unordered_map<std::string, TypeProto*> &valueTypesByName, const std::unordered_map<std::string, const TensorProto*> &inputDataByName, DataValueMap &generatedShapeData)#
-
struct GraphInferenceContext#
- #include <implementation.h>
Public Functions
-
inline GraphInferenceContext(const std::unordered_map<std::string, TypeProto*> &outer_scope_value_types_by_name_in, std::unordered_map<std::string, int> opset_imports_in, SymbolTable *symbol_table_in = nullptr, const ModelLocalFunctionsMap &model_local_functions_in = {}, const ISchemaRegistry *schema_registry_in = OpSchemaRegistry::Instance(), DataValueMap *generated_shape_data_by_name_in = nullptr, const int ir_version_in = IR_VERSION)#
Public Members
-
SymbolTable *symbol_table#
-
const ModelLocalFunctionsMap &model_local_functions#
-
const ISchemaRegistry *schema_registry#
-
DataValueMap *generated_shape_data_by_name#
-
const int ir_version#
-
inline GraphInferenceContext(const std::unordered_map<std::string, TypeProto*> &outer_scope_value_types_by_name_in, std::unordered_map<std::string, int> opset_imports_in, SymbolTable *symbol_table_in = nullptr, const ModelLocalFunctionsMap &model_local_functions_in = {}, const ISchemaRegistry *schema_registry_in = OpSchemaRegistry::Instance(), DataValueMap *generated_shape_data_by_name_in = nullptr, const int ir_version_in = IR_VERSION)#
-
class GraphInferencerImpl : public ONNX_LIGHT_NAMESPACE::GraphInferencer#
- #include <implementation.h>
Public Functions
-
inline GraphInferencerImpl(GraphProto &g, GraphInferenceContext &context)#
-
inline GraphInferencerImpl(GraphProto &g, GraphInferenceContext &context, const ShapeInferenceOptions &options)#
-
inline GraphInferencerImpl(GraphProto &g, GraphInferenceContext &context)#
-
struct InferenceContextImpl : public ONNX_LIGHT_NAMESPACE::InferenceContext#
- #include <implementation.h>
Public Functions
-
inline InferenceContextImpl(NodeProto &n, const std::unordered_map<std::string, TypeProto*> &valueTypesByName, const std::unordered_map<std::string, const TensorProto*> &inputDataByName, const std::unordered_map<std::string, const SparseTensorProto*> &inputSparseDataByName, const ShapeInferenceOptions &options, DataValueMap *generatedShapeData = nullptr, GraphInferenceContext *graphInferenceContext = nullptr)#
-
inline virtual const AttributeProto *getAttribute(const std::string &name) const override#
-
inline virtual size_t getNumInputs() const override#
-
inline virtual const TensorProto *getInputData(size_t index) const override#
-
inline virtual const TensorShapeProto *getSymbolicInput(size_t index) const override#
-
inline virtual const SparseTensorProto *getInputSparseData(size_t index) const override#
-
inline virtual size_t getNumOutputs() const override#
-
inline virtual GraphInferencer *getGraphAttributeInferencer(const std::string &attr_name) override#
Public Members
-
std::vector<const TensorProto*> allInputData_#
-
std::vector<const SparseTensorProto*> allInputSparseData_#
-
std::vector<const TensorShapeProto*> allShapeInputData_#
-
std::unordered_map<std::string, const AttributeProto*> attributesByName_#
-
std::unordered_map<std::string, GraphProto*> graphProtoAttributesByName_#
-
GraphInferenceContext *graphInferenceContext_#
-
mutable std::unordered_map<std::string, std::unique_ptr<GraphInferencer>> graphAttributeInferencers_#
-
ShapeInferenceOptions options_#
-
inline InferenceContextImpl(NodeProto &n, const std::unordered_map<std::string, TypeProto*> &valueTypesByName, const std::unordered_map<std::string, const TensorProto*> &inputDataByName, const std::unordered_map<std::string, const SparseTensorProto*> &inputSparseDataByName, const ShapeInferenceOptions &options, DataValueMap *generatedShapeData = nullptr, GraphInferenceContext *graphInferenceContext = nullptr)#
-
class SymbolTableImpl : public ONNX_LIGHT_NAMESPACE::SymbolTable#
- #include <implementation.h>
Public Functions
-
SymbolTableImpl() = default#
-
inline virtual void addFromGraph(const GraphProto &g) override#
Private Functions
-
template<typename TensorTypeProto>
inline void AddExistingSymbolicDims(const TensorTypeProto &tensorType)#
-
inline void AddExistingSymbolicDims(const utils::RepeatedProtoField<ValueInfoProto> &protos)#
-
SymbolTableImpl() = default#
-
using ModelLocalFunctionsMap = std::unordered_map<std::string, const FunctionProto*>#
-
namespace shape_inference#