schema.h#

Operator schema definitions and registration utilities, including onnx::OpSchema, onnx::OpSchemaRegistry, and onnx::OpSchemaRegistry::DomainToVersionRange.

Defines

fail_schema(...)#
ATTR_SETTER_WITH_DEFAULT_VALUE(TypeName)#
ONNX_OPERATOR_SET_SCHEMA(name, ver, impl)#
ONNX_ML_OPERATOR_SET_SCHEMA(name, ver, impl)#
ONNX_TRAINING_OPERATOR_SET_SCHEMA(name, ver, impl)#
ONNX_PREVIEW_OPERATOR_SET_SCHEMA(name, ver, impl)#
ONNX_PREVIEW_TRAINING_OPERATOR_SET_SCHEMA(name, ver, impl)#
ONNX_DBG_INCREMENT_COUNT_IN_OPSETS()#
ONNX_OPERATOR_SET_SCHEMA_DEBUG_VARIABLE(name, domain, ver, dbg_included_in_static_opset)#
ONNX_OPERATOR_SET_SCHEMA_EX(name, domain, domain_str, ver, dbg_included_in_static_opset, impl)#
ONNX_DBG_GET_COUNT_IN_OPSETS()#
ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(domain, ver, name)#
ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(ver, name)#
ONNX_OPERATOR_SCHEMA(name)#
ONNX_OPERATOR_SCHEMA_UNIQ_HELPER(Counter, name)#
ONNX_OPERATOR_SCHEMA_UNIQ(Counter, name)#
GET_OP_DOC_STR(doc_str)#
POPULATE_OP_DOC_STR(DocPopulatorCode)#
namespace ONNX_LIGHT_NAMESPACE

Typedefs

using FunctionBodyQueryFunction = std::function<bool(FunctionBodyBuildContext&)>#
using ContextDependentFunctionBodyBuilder = std::function<bool(const FunctionBodyBuildContext&, const OpSchema&, FunctionProto&)>#
using OperatorSetVersion = int#
using DataTypeSet = std::unordered_set<DataType>#
using TypeConstraintMap = std::unordered_map<std::string, std::pair<DataTypeSet, std::string>>#
using OpName_Domain_Version_Schema_Map = std::unordered_map<std::string, std::unordered_map<std::string, std::map<OperatorSetVersion, OpSchema>>>#

Functions

void RegisterAllOnnxOperatorSchemas()#
void RegisterSchema(const OpSchema &schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true, bool fail_with_exception = false)#
void RegisterSchema(OpSchema &&schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true, bool fail_with_exception = false)#
void DeregisterSchema(const std::string &op_type, int version, const std::string &domain)#
template<class T>
void RegisterOpSetSchema(int opset_version_to_load = 0, bool fail_duplicate_schema = true)#
template<typename T>
OpSchema GetOpSchema()#
size_t ReplaceAll(std::string &s, const char *from, const char *to)#
inline std::string GenerateOptionalArgumentsDoc()#
inline std::string GenerateBroadcastingDocMul()#
inline std::string GenerateBroadcastingDocUni(const char *from, const char *to)#
class DbgOperatorSetTracker#
#include <schema.h>

Public Functions

inline size_t IncrementCount()#
inline size_t GetCount() const#

Public Static Functions

static DbgOperatorSetTracker &Instance()#

Private Members

size_t count_ = 0#
struct FunctionBodyBuildContext#
#include <schema.h>

Subclassed by ONNX_LIGHT_NAMESPACE::FunctionBodyBuildContextImpl

Public Functions

virtual const AttributeProto *getAttribute(const std::string &name) const = 0#
virtual bool hasInput(int inputIndex) const = 0#
virtual bool hasOutput(int inputIndex) const = 0#
virtual const TypeProto *getInputType(int inputIndex) const = 0#
virtual ~FunctionBodyBuildContext() = default#
struct FunctionBodyBuildContextImpl : public ONNX_LIGHT_NAMESPACE::FunctionBodyBuildContext#
#include <schema.h>

Public Functions

inline explicit FunctionBodyBuildContextImpl(const NodeProto &node_proto, const std::vector<TypeProto> &input_types = {})#
inline virtual const AttributeProto *getAttribute(const std::string &name) const override#
inline virtual bool hasInput(int inputIndex) const override#
inline virtual bool hasOutput(int inputIndex) const override#
inline virtual const TypeProto *getInputType(int inputIndex) const override#

Public Members

std::unordered_map<std::string, const AttributeProto*> attributesByName_#
NodeProto node_proto_#
std::vector<TypeProto> input_types_#
class ISchemaRegistry#
#include <schema.h>

Subclassed by ONNX_LIGHT_NAMESPACE::OpSchemaRegistry

Public Functions

virtual ~ISchemaRegistry() = default#
virtual const OpSchema *GetSchema(const std::string &key, const int maxInclusiveVersion, const std::string &domain = ONNX_DOMAIN) const = 0#
inline virtual const OpSchema *GetSchema(const utils::String &key, const int maxInclusiveVersion, const utils::String &domain) const#
inline virtual const OpSchema *GetSchema(const utils::RefString &key, const int maxInclusiveVersion, const utils::RefString &domain) const#
class OpSchema#
#include <schema.h>

A class to record the schema of an op.

OpSchema records the common interface of an op specified by its name.

To register an OpSchema, use the macro ONNX_OPERATOR_SCHEMA(name) and then chain calls to the class's member functions. For example, an op that takes two inputs and produces one output, where the first input and the output may be computed in place, can be written as:

ONNX_OPERATOR_SCHEMA(name)
    .NumInputs(2).NumOutputs(1).AllowConsumed({{0, 0}});
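To generate functions that can be used to register an OpSchema non-statically, use the following form instead:
ONNX_OPERATOR_SET_SCHEMA(name, version, OpSchema()
    .NumInputs(2).NumOutputs(1).AllowConsumed({{0, 0}}));
Note: AllowConsumed does not appear among the OpSchema public functions listed on this page, so these two examples may be outdated.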

Public Types

enum FormalParameterOption#

Values:

enumerator Single#
enumerator Optional#
enumerator Variadic#
enum DifferentiationCategory#

Values:

enumerator Unknown#
enumerator Differentiable#
enumerator NonDifferentiable#
enum class NodeDeterminism : uint8_t#

Values:

enumerator Unknown#
enumerator NonDeterministic#
enumerator Deterministic#
enum class SupportType : uint8_t#

Values:

enumerator COMMON#
enumerator EXPERIMENTAL#

Public Functions

inline OpSchema()#
inline OpSchema(std::string name, std::string file, int line)#
inline const std::string &file() const#

Returns the file that the op schema is registered from.

inline int line() const#

Returns the line number in file() at which the op schema is registered.

inline SupportType support_level() const#

Returns the support level of the op schema.

inline const char *doc() const#

Returns the docstring of the op schema.

void CheckInputOutputType(struct InferenceContext&) const#
void Verify(const NodeProto &node) const#

Verifies whether a NodeProto matches the pattern specified in the schema.

OpSchema &SinceVersion(OperatorSetVersion n)#

The earliest operator set version in which this operator was present. If an operator has had no backward-compatibility (BC)-breaking changes, this is simply the first operator set the operator was a member of. If it has had BC-breaking changes, then, for the semantics as described in this OpSchema entry, this version identifies the operator set that introduced the BC-breaking change.

For example, suppose op Foo was added in v3, and had a BC-breaking change in v6. Then there will be an op schema entry for Foo with SinceVersion(3), and another, updated op schema entry for Foo with SinceVersion(6).

OpSchema &Deprecate()#

Marks this op as deprecated as of its since_version. This causes the Schema() lookup functions to return nullptr when the requested version is in the deprecated range.

inline bool Deprecated() const#
OpSchema &NumInputs(std::unordered_set<int> allowed_input_nums)#

Restricts the number of inputs to one of the values in allowed_input_nums.

OpSchema &NumOutputs(std::unordered_set<int> allowed_output_nums)#

Restricts the number of outputs to one of the values in allowed_output_nums.

OpSchema &TypeAndShapeInferenceFunction(InferenceFunction inferenceFunction)#
InferenceFunction GetTypeAndShapeInferenceFunction() const#
OpSchema &PartialDataPropagationFunction(DataPropagationFunction dataPropagationFunction)#
inline DataPropagationFunction GetDataPropagationFunction() const#
OpSchema &SetSupportLevel(SupportType supportType)#
OpSchema &SetDoc(const char *doc)#
OpSchema &SetDoc(const std::string &doc)#
OpSchema &SetName(const char *name)#
OpSchema &SetName(std::string name)#
OpSchema &SetLocation(const char *file, int line)#
OpSchema &SetLocation(std::string file, int line)#
OpSchema &SetDomain(const char *domain)#
OpSchema &SetDomain(std::string domain)#
OpSchema &Attr(Attribute attr)#
OpSchema &Attr(const char *name, const char *description, AttributeProto::AttributeType type, const char *defaultValue)#
OpSchema &Attr(std::string name, std::string description, std::string conditionExplanation, AttributeProto::AttributeType attr_type)#
OpSchema &Attr(std::string name, std::string description, AttributeProto::AttributeType type, bool required = true)#
OpSchema &Attr(const char *name, const char *description, AttributeProto::AttributeType type, bool required = true)#
OpSchema &AllowUncheckedAttributes()#
OpSchema &Input(int n, FormalParameter formal_parameter)#
OpSchema &Input(int n, std::string name, const std::string &description, std::string type_str, FormalParameterOption param_option = Single, bool is_homogeneous = true, int min_arity = 1, DifferentiationCategory differentiation_category = Unknown)#
OpSchema &Input(int n, const char *name, const char *description, const char *type_str, FormalParameterOption param_option = Single, bool is_homogeneous = true, int min_arity = 1, DifferentiationCategory differentiation_category = Unknown)#
OpSchema &Output(int n, FormalParameter formal_parameter)#
OpSchema &Output(int n, std::string name, const std::string &description, std::string type_str, FormalParameterOption param_option = Single, bool is_homogeneous = true, int min_arity = 1, DifferentiationCategory differentiation_category = Unknown)#
OpSchema &Output(int n, const char *name, const char *description, const char *type_str, FormalParameterOption param_option = Single, bool is_homogeneous = true, int min_arity = 1, DifferentiationCategory differentiation_category = Unknown)#
OpSchema &TypeConstraint(std::string type_str, std::vector<std::string> constraints, std::string description)#
OpSchema &TypeConstraint(const char *type_str, std::initializer_list<const char*> constraints, const char *description)#
OpSchema &FillUsing(const std::function<void(OpSchema&)> &populator)#
inline const std::string &domain() const#
inline const std::unordered_map<std::string, Attribute> &attributes() const#
inline const std::vector<FormalParameter> &inputs() const#
inline const std::vector<FormalParameter> &outputs() const#
inline const std::vector<TypeConstraintParam> &typeConstraintParams() const#
inline const TypeConstraintMap &typeConstraintMap() const#
inline const std::string &Name() const#
inline OperatorSetVersion SinceVersion() const#
inline int since_version() const#
inline bool deprecated() const#
inline int min_input() const#
inline int max_input() const#
inline int min_output() const#
inline int max_output() const#
inline bool has_type_and_shape_inference_function() const#
inline bool has_data_propagation_function() const#
std::vector<int> function_opset_versions() const#
inline bool HasFunction() const#
OpSchema &FunctionBody(const std::vector<NodeProto> &func_nodes, int opset_version = kUninitializedSinceVersion)#
OpSchema &FunctionBody(const std::vector<NodeProto> &func_nodes, const std::vector<OperatorSetIdProto> &opsets, int opset_version = kUninitializedSinceVersion)#
OpSchema &FunctionBody(const char *func_body, int opset_version = kUninitializedSinceVersion)#
const FunctionProto *GetFunction(int requested_opset_version = OpSchema::kUninitializedSinceVersion, bool validate = false) const#
std::vector<int> context_dependent_function_opset_versions() const#
bool HasContextDependentFunction() const#
bool HasContextDependentFunctionWithOpsetVersion(int opset_version) const#
OpSchema &SetContextDependentFunctionBodyBuilder(ContextDependentFunctionBodyBuilder, int opset_version = kUninitializedSinceVersion)#
bool BuildContextDependentFunction(const FunctionBodyBuildContext &ctx, FunctionProto &function_proto, int requested_opset_version = OpSchema::kUninitializedSinceVersion) const#
void Finalize()#
void BuildFunction(FunctionProto &function_body) const#
NodeDeterminism GetNodeDeterminism() const#
OpSchema &SetNodeDeterminism(NodeDeterminism node_determinism)#

Public Static Functions

static inline const std::vector<std::string> &numeric_types_for_math_reduction_ir10()#
static const std::vector<std::string> &numeric_types_for_math_reduction_ir9()#
static const std::vector<std::string> &numeric_types_for_math_reduction_ir4()#
static const std::vector<std::string> &numeric_types_for_math_reduction()#
static const std::vector<std::string> &all_numeric_types_ir13()#
static const std::vector<std::string> &all_numeric_types_ir12()#
static const std::vector<std::string> &all_numeric_types_ir11()#
static const std::vector<std::string> &all_numeric_types_ir10()#
static const std::vector<std::string> &all_numeric_types_ir9()#
static const std::vector<std::string> &all_numeric_types_ir4()#
static const std::vector<std::string> &all_numeric_types()#
static const std::vector<std::string> &all_numeric_sequence_types()#
static const std::vector<std::string> &all_tensor_types()#
static const std::vector<std::string> &all_tensor_types_ir4()#
static const std::vector<std::string> &all_non_complex_numeric_types_plus_bool_ir4()#
static const std::vector<std::string> &all_float_types_ir4()#
static const std::vector<std::string> &all_float_types_plus_Xint8_ir4()#
static const std::vector<std::string> &all_float_types_ir9()#
static inline const std::vector<std::string> &all_float_types_ir10()#
static const std::vector<std::string> &all_tensor_types_ir9()#
static const std::vector<std::string> &all_tensor_types_ir10()#
static const std::vector<std::string> &all_non_complex_tensor_types_ir10()#
static const std::vector<std::string> &all_tensor_types_ir11()#
static const std::vector<std::string> &all_non_complex_tensor_types_ir11()#
static const std::vector<std::string> &all_tensor_types_ir12()#
static const std::vector<std::string> &all_non_complex_tensor_types_ir12()#
static const std::vector<std::string> &all_tensor_types_ir13()#
static const std::vector<std::string> &all_non_complex_tensor_types_ir13()#
static const std::vector<std::string> &all_non_string_tensor_types_ir13()#
static const std::vector<std::string> &all_tensor_sequence_types()#
static const std::vector<std::string> &all_tensor_sequence_types_ir4()#
static const std::vector<std::string> &all_tensor_sequence_types_ir9()#
static const std::vector<std::string> &all_tensor_sequence_types_ir10()#
static const std::vector<std::string> &all_tensor_sequence_types_ir11()#
static const std::vector<std::string> &all_tensor_sequence_types_ir12()#
static const std::vector<std::string> &all_tensor_sequence_types_ir13()#
static const std::vector<std::string> &all_optional_types()#
static const std::vector<std::string> &all_optional_types_ir4()#
static const std::vector<std::string> &all_optional_types_ir9()#
static const std::vector<std::string> &all_optional_types_ir10()#
static const std::vector<std::string> &all_optional_types_ir11()#
static const std::vector<std::string> &all_optional_types_ir12()#
static const std::vector<std::string> &all_optional_types_ir13()#

Public Static Attributes

static constexpr int kUninitializedSinceVersion = -1#

Private Functions

void ParseAndSetTypes(std::vector<OpSchema::FormalParameter> *formal_parameters)#
bool ValidateReferencedOpsInFunction(const FunctionProto *function, int requested_opset_version, int function_since_version, std::unordered_set<std::string> *updated_ops = nullptr) const#
void UpdateFunctionProtoOpsetImportVersion(FunctionProto &function_proto, int opset_version) const#
std::string VerifyFailPrefix(std::string_view node_name) const#

Generates a common prefix string for error messages passed to fail_check during verification.

Parameters:

node_name – If empty, the returned string will not include the node name.

Returns:

std::string The prefix string.

inline std::string VerifyFailPrefix(const utils::String &node_name) const#
void VerifyInputNum(int input_num, std::string_view node_name = "") const#

Verifies that the number of inputs matches the pattern specified in the schema.

Parameters:
  • input_num – The number of inputs to be verified against the schema.

  • node_name – Name of the node being verified; used to build the error-message prefix if the check fails.

inline void VerifyInputNum(int input_num, const utils::String &node_name) const#
void VerifyOutputNum(int output_num, std::string_view node_name = "") const#

Verifies that the number of outputs matches the pattern specified in the schema.

Parameters:
  • output_num – The number of outputs to be verified against the schema.

  • node_name – Name of the node being verified; used to build the error-message prefix if the check fails.

inline void VerifyOutputNum(int output_num, const utils::String &node_name) const#

Private Members

std::string name_#
std::string file_#
std::string doc_#
std::string domain_ = ONNX_DOMAIN#
std::unordered_map<std::string, Attribute> attributes_#
bool allows_unchecked_attributes_ = false#
std::vector<FormalParameter> inputs_#
std::vector<FormalParameter> outputs_#
std::vector<TypeConstraintParam> type_constraint_params_#
TypeConstraintMap type_constraints_#
int line_ = 0#
SupportType support_#
int min_input_ = 0#
int max_input_ = 0#
int min_output_ = 0#
int max_output_ = 0#
OperatorSetVersion since_version_ = kUninitializedSinceVersion#
bool deprecated_ = {}#
std::function<bool(int)> num_inputs_allowed_#
std::function<bool(int)> num_outputs_allowed_#
InferenceFunction tensor_inference_function_#
DataPropagationFunction data_propagation_function_#
std::map<int, std::shared_ptr<FunctionProto>> opset_version_to_function_body_#
std::map<int, ContextDependentFunctionBodyBuilder> opset_version_to_function_builder_#
NodeDeterminism node_determinism_ = NodeDeterminism::Unknown#

Friends

friend std::ostream &operator<<(std::ostream &out, const OpSchema &schema)#
struct Attribute#
#include <schema.h>

Public Functions

inline Attribute(std::string name_, std::string description_, AttributeProto::AttributeType type_, bool required_)#
inline Attribute(std::string name_, std::string description_, AttributeProto default_value_)#

Public Members

const std::string name#
const std::string description#
AttributeProto::AttributeType type#
bool required#
AttributeProto default_value#
class FormalParameter#
#include <schema.h>

Public Functions

FormalParameter() = default#
inline explicit FormalParameter(std::string name, DataTypeSet allowed_type_set, std::string type_str, std::string description, FormalParameterOption param_option = Single, bool is_homogeneous = true, int min_arity = 1, DifferentiationCategory differentiation_category = Unknown)#
inline explicit FormalParameter(std::string name, std::string description, std::string type_str, FormalParameterOption param_option = Single, bool is_homogeneous = true, int min_arity = 1, DifferentiationCategory differentiation_category = Unknown)#
const std::string &GetName() const#
const DataTypeSet &GetTypes() const#
const std::string &GetTypeStr() const#
const std::string &GetDescription() const#
FormalParameterOption GetOption() const#
bool GetIsHomogeneous() const#
int GetMinArity() const#
DifferentiationCategory GetDifferentiationCategory() const#

Private Functions

DataTypeSet &MutableTypes()#

Private Members

std::string name_#
DataTypeSet type_set_#
std::string type_str_#
std::string description_#
FormalParameterOption param_option_ = {}#
bool is_homogeneous_ = {}#
int min_arity_ = {}#
DifferentiationCategory differentiation_category_ = {}#

Friends

friend class OpSchema
struct TypeConstraintParam#
#include <schema.h>

Public Functions

inline TypeConstraintParam(std::string type_param_str_, std::vector<std::string> allowed_type_strs_, std::string description_)#

Public Members

std::string type_param_str#
std::vector<std::string> allowed_type_strs#
std::string description#
class OpSchemaRegistry : public ONNX_LIGHT_NAMESPACE::ISchemaRegistry#
#include <schema.h>

A registry to hold all the operator schemas.

Public Functions

virtual const OpSchema *GetSchema(const std::string &key, const int maxInclusiveVersion, const std::string &domain = ONNX_DOMAIN) const override#
virtual const OpSchema *GetSchema(const utils::String &key, const int maxInclusiveVersion, const utils::String &domain) const override#

Public Static Functions

static void OpSchemaDeregister(const std::string &op_type, const int version, const std::string &domain = ONNX_DOMAIN)#
static void OpSchemaDeregisterAll(const std::string &domain = ONNX_DOMAIN)#
static const OpSchema *Schema(const std::string &key, const std::string &domain = ONNX_DOMAIN)#
static const OpSchema *Schema(const utils::String &key, const utils::String &domain)#
static const OpSchema *Schema(const std::string &key, const int maxInclusiveVersion, const std::string &domain = ONNX_DOMAIN)#
static const OpSchema *Schema(const utils::String &key, const int maxInclusiveVersion, const utils::String &domain)#
static OpSchemaRegistry *Instance()#
static void SetLoadedSchemaVersion(int target_version)#
static int GetLoadedSchemaVersion()#
static std::vector<OpSchema> get_all_schemas_with_history()#
static std::vector<OpSchema> get_all_schemas()#

Private Functions

OpSchemaRegistry() = default#

Private Static Functions

static OpName_Domain_Version_Schema_Map &GetMapWithoutEnsuringRegistration()#

Returns the underlying map from operator name to domain to operator set version to OpSchema (see OpName_Domain_Version_Schema_Map).

Do not modify the returned map directly. Instead, register operator schemas with the provided macros, such as ONNX_OPERATOR_SET_SCHEMA.

The map is wrapped in a function to avoid the static initialization order fiasco.

Because of the change in function visibility, GetMapWithoutEnsuringRegistration() and map() can no longer be used to access the schema map directly from outside the OpSchemaRegistry class. The ONNX_API macro is therefore applied to these methods so that they remain accessible from other translation units, preserving backward compatibility.

static OpName_Domain_Version_Schema_Map &map()#

Private Static Attributes

static int loaded_schema_version#
class DomainToVersionRange#
#include <schema.h>

Public Functions

DomainToVersionRange()#
const std::unordered_map<std::string, std::pair<int, int>> &Map() const#
const std::unordered_map<std::string, int> &LastReleaseVersionMap() const#
void AddDomainToVersion(const std::string &domain, int min_version, int max_version, int last_release_version = -1)#
void UpdateDomainToVersion(const std::string &domain, int min_version, int max_version, int last_release_version = -1)#

Public Static Functions

static DomainToVersionRange &Instance()#

Private Members

std::unordered_map<std::string, std::pair<int, int>> map_#
std::unordered_map<std::string, int> last_release_version_map_#
std::mutex mutex_#
class OpSchemaRegisterOnce#
#include <schema.h>

Public Functions

inline OpSchemaRegisterOnce(OpSchema op_schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true)#

Public Static Functions

static void OpSchemaRegisterNoExcept(OpSchema &&op_schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true)#
static void OpSchemaRegisterImpl(OpSchema &&op_schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true)#

Private Static Functions

static int GetMaxRegisteredVerWithinTarget(const std::map<OperatorSetVersion, OpSchema> &m, int target_ver)#
static void CheckDomainAndVersionToRegister(const OpSchema &op_schema, const std::string &op_name, const std::string &op_domain)#
class SchemaError : public std::runtime_error#
#include <schema.h>

Public Functions

inline explicit SchemaError(const std::string &message)#
inline const char *what() const noexcept override#
inline void AppendContext(const std::string &context)#

Private Members

std::string expanded_message_#