schema.h#
Operator schema definitions and registration utilities, including
onnx::OpSchema, onnx::OpSchemaRegistry, and
onnx::OpSchemaRegistry::DomainToVersionRange.
Defines
-
fail_schema(...)#
-
ATTR_SETTER_WITH_DEFAULT_VALUE(TypeName)#
-
ONNX_OPERATOR_SET_SCHEMA(name, ver, impl)#
-
ONNX_ML_OPERATOR_SET_SCHEMA(name, ver, impl)#
-
ONNX_TRAINING_OPERATOR_SET_SCHEMA(name, ver, impl)#
-
ONNX_PREVIEW_OPERATOR_SET_SCHEMA(name, ver, impl)#
-
ONNX_PREVIEW_TRAINING_OPERATOR_SET_SCHEMA(name, ver, impl)#
-
ONNX_DBG_INCREMENT_COUNT_IN_OPSETS()#
-
ONNX_OPERATOR_SET_SCHEMA_DEBUG_VARIABLE(name, domain, ver, dbg_included_in_static_opset)#
-
ONNX_OPERATOR_SET_SCHEMA_EX(name, domain, domain_str, ver, dbg_included_in_static_opset, impl)#
-
ONNX_DBG_GET_COUNT_IN_OPSETS()#
-
ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(domain, ver, name)#
-
ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(ver, name)#
-
ONNX_OPERATOR_SCHEMA(name)#
-
ONNX_OPERATOR_SCHEMA_UNIQ_HELPER(Counter, name)#
-
ONNX_OPERATOR_SCHEMA_UNIQ(Counter, name)#
-
GET_OP_DOC_STR(doc_str)#
-
POPULATE_OP_DOC_STR(DocPopulatorCode)#
-
namespace ONNX_LIGHT_NAMESPACE
Typedefs
-
using FunctionBodyQueryFunction = std::function<bool(FunctionBodyBuildContext&)>#
-
using ContextDependentFunctionBodyBuilder = std::function<bool(const FunctionBodyBuildContext&, const OpSchema&, FunctionProto&)>#
-
using OperatorSetVersion = int#
Functions
-
void RegisterAllOnnxOperatorSchemas()#
-
void RegisterSchema(const OpSchema &schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true, bool fail_with_exception = false)#
-
void RegisterSchema(OpSchema &&schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true, bool fail_with_exception = false)#
-
template<class T>
void RegisterOpSetSchema(int opset_version_to_load = 0, bool fail_duplicate_schema = true)#
-
class DbgOperatorSetTracker#
- #include <schema.h>
Public Static Functions
-
static DbgOperatorSetTracker &Instance()#
Private Members
-
size_t count_ = 0#
-
static DbgOperatorSetTracker &Instance()#
-
struct FunctionBodyBuildContext#
- #include <schema.h>
Subclassed by ONNX_LIGHT_NAMESPACE::FunctionBodyBuildContextImpl
-
struct FunctionBodyBuildContextImpl : public ONNX_LIGHT_NAMESPACE::FunctionBodyBuildContext#
- #include <schema.h>
Public Functions
-
inline explicit FunctionBodyBuildContextImpl(const NodeProto &node_proto, const std::vector<TypeProto> &input_types = {})#
-
inline virtual const AttributeProto *getAttribute(const std::string &name) const override#
-
inline virtual bool hasInput(int inputIndex) const override#
-
inline virtual bool hasOutput(int inputIndex) const override#
-
inline explicit FunctionBodyBuildContextImpl(const NodeProto &node_proto, const std::vector<TypeProto> &input_types = {})#
-
class ISchemaRegistry#
- #include <schema.h>
Subclassed by ONNX_LIGHT_NAMESPACE::OpSchemaRegistry
Public Functions
-
virtual ~ISchemaRegistry() = default#
-
virtual const OpSchema *GetSchema(const std::string &key, const int maxInclusiveVersion, const std::string &domain = ONNX_DOMAIN) const = 0#
-
virtual ~ISchemaRegistry() = default#
-
class OpSchema#
- #include <schema.h>
A class to record the schema of an op.
OpSchema records the common interface of an op specified by its name.
To register an OpSchema, one can use the macro ONNX_OPERATOR_SCHEMA(name) and then append the various functions in the class. For example, an op that takes two inputs and produces one output, where the first input and the output may be in-place, can be written as:
ONNX_OPERATOR_SCHEMA(name).NumInputs(2).NumOutputs(1).AllowConsumed({{0, 0}});
To manufacture methods that may be used to register an OpSchema non-statically, the following may be used:
ONNX_OPERATOR_SET_SCHEMA(name, version, OpSchema().NumInputs(2).NumOutputs(1).AllowConsumed({{0, 0}}));
Public Types
-
enum DifferentiationCategory#
Values:
-
enumerator Unknown#
-
enumerator Differentiable#
-
enumerator NonDifferentiable#
-
enumerator Unknown#
Public Functions
-
inline OpSchema()#
-
inline int line() const#
Returns the line number in the source file from which the op schema is registered.
-
inline SupportType support_level() const#
Returns the support level of the op schema.
-
inline const char *doc() const#
Returns the docstring of the op schema.
-
void CheckInputOutputType(struct InferenceContext&) const#
-
void Verify(const NodeProto &node) const#
Verifies if a NodeProto matches the pattern specified in the schema.
-
OpSchema &SinceVersion(OperatorSetVersion n)#
The earliest operator set version in which this operator was present. If an operator has had no BC-breaking changes, this is simply the first operator set the operator was a member of; if it has had BC-breaking changes, then for the semantics *as described* in the OpSchema entry, this version is the operator set that introduced the BC-breaking change.
For example, suppose op Foo was added in v3, and had a BC-breaking change in v6. Then there will be an op schema entry for Foo with SinceVersion(3), and another, updated op schema entry for Foo with SinceVersion(6).
-
OpSchema &Deprecate()#
Marks this op as deprecated as of its since_version. This will cause the Schema() lookup functions to return nullptr when the version is in the deprecated range.
-
inline bool Deprecated() const#
-
OpSchema &NumInputs(std::unordered_set<int> allowed_input_nums)#
The number of inputs must be one of the values specified in allowed_input_nums.
-
OpSchema &NumOutputs(std::unordered_set<int> allowed_output_nums)#
The number of outputs must be one of the values specified in allowed_output_nums.
-
OpSchema &TypeAndShapeInferenceFunction(InferenceFunction inferenceFunction)#
-
InferenceFunction GetTypeAndShapeInferenceFunction() const#
-
OpSchema &PartialDataPropagationFunction(DataPropagationFunction dataPropagationFunction)#
-
inline DataPropagationFunction GetDataPropagationFunction() const#
-
OpSchema &SetSupportLevel(SupportType supportType)#
-
OpSchema &Attr(const char *name, const char *description, AttributeProto::AttributeType type, const char *defaultValue)#
-
OpSchema &Attr(std::string name, std::string description, std::string conditionExplanation, AttributeProto::AttributeType attr_type)#
-
OpSchema &Attr(std::string name, std::string description, AttributeProto::AttributeType type, bool required = true)#
-
OpSchema &Attr(const char *name, const char *description, AttributeProto::AttributeType type, bool required = true)#
-
OpSchema &Input(int n, FormalParameter formal_parameter)#
-
OpSchema &Input(int n, std::string name, const std::string &description, std::string type_str, FormalParameterOption param_option = Single, bool is_homogeneous = true, int min_arity = 1, DifferentiationCategory differentiation_category = Unknown)#
-
OpSchema &Input(int n, const char *name, const char *description, const char *type_str, FormalParameterOption param_option = Single, bool is_homogeneous = true, int min_arity = 1, DifferentiationCategory differentiation_category = Unknown)#
-
OpSchema &Output(int n, FormalParameter formal_parameter)#
-
OpSchema &Output(int n, std::string name, const std::string &description, std::string type_str, FormalParameterOption param_option = Single, bool is_homogeneous = true, int min_arity = 1, DifferentiationCategory differentiation_category = Unknown)#
-
OpSchema &Output(int n, const char *name, const char *description, const char *type_str, FormalParameterOption param_option = Single, bool is_homogeneous = true, int min_arity = 1, DifferentiationCategory differentiation_category = Unknown)#
-
OpSchema &TypeConstraint(std::string type_str, std::vector<std::string> constraints, std::string description)#
-
OpSchema &TypeConstraint(const char *type_str, std::initializer_list<const char*> constraints, const char *description)#
-
inline const std::vector<FormalParameter> &inputs() const#
-
inline const std::vector<FormalParameter> &outputs() const#
-
inline const std::vector<TypeConstraintParam> &typeConstraintParams() const#
-
inline const TypeConstraintMap &typeConstraintMap() const#
-
inline OperatorSetVersion SinceVersion() const#
-
inline int since_version() const#
-
inline bool deprecated() const#
-
inline int min_input() const#
-
inline int max_input() const#
-
inline int min_output() const#
-
inline int max_output() const#
-
inline bool has_type_and_shape_inference_function() const#
-
inline bool has_data_propagation_function() const#
-
inline bool HasFunction() const#
-
OpSchema &FunctionBody(const std::vector<NodeProto> &func_nodes, int opset_version = kUninitializedSinceVersion)#
-
OpSchema &FunctionBody(const std::vector<NodeProto> &func_nodes, const std::vector<OperatorSetIdProto> &opsets, int opset_version = kUninitializedSinceVersion)#
-
OpSchema &FunctionBody(const char *func_body, int opset_version = kUninitializedSinceVersion)#
-
const FunctionProto *GetFunction(int requested_opset_version = OpSchema::kUninitializedSinceVersion, bool validate = false) const#
-
bool HasContextDependentFunction() const#
-
bool HasContextDependentFunctionWithOpsetVersion(int opset_version) const#
-
OpSchema &SetContextDependentFunctionBodyBuilder(ContextDependentFunctionBodyBuilder, int opset_version = kUninitializedSinceVersion)#
-
bool BuildContextDependentFunction(const FunctionBodyBuildContext &ctx, FunctionProto &function_proto, int requested_opset_version = OpSchema::kUninitializedSinceVersion) const#
-
void Finalize()#
-
void BuildFunction(FunctionProto &function_body) const#
-
NodeDeterminism GetNodeDeterminism() const#
-
OpSchema &SetNodeDeterminism(NodeDeterminism node_determinism)#
Public Static Functions
Public Static Attributes
-
static constexpr int kUninitializedSinceVersion = -1#
Private Functions
-
void ParseAndSetTypes(std::vector<OpSchema::FormalParameter> *formal_parameters)#
-
bool ValidateReferencedOpsInFunction(const FunctionProto *function, int requested_opset_version, int function_since_version, std::unordered_set<std::string> *updated_ops = nullptr) const#
-
void UpdateFunctionProtoOpsetImportVersion(FunctionProto &function_proto, int opset_version) const#
-
std::string VerifyFailPrefix(std::string_view node_name) const#
A helper that generates the prefix string used in fail_check messages raised by Verify().
- Parameters:
node_name – If empty, the returned string will not include the node name.
- Returns:
std::string The prefix string.
-
void VerifyInputNum(int input_num, std::string_view node_name = "") const#
Verifies if the input number matches the pattern specified in the schema.
- Parameters:
input_num – The number of inputs to be verified against the schema.
node_name – The node name to include in the error-message prefix if the check fails; if empty, the prefix omits the node name.
Private Members
-
std::string domain_ = ONNX_DOMAIN#
-
bool allows_unchecked_attributes_ = false#
-
std::vector<FormalParameter> inputs_#
-
std::vector<FormalParameter> outputs_#
-
std::vector<TypeConstraintParam> type_constraint_params_#
-
TypeConstraintMap type_constraints_#
-
int line_ = 0#
-
SupportType support_#
-
int min_input_ = 0#
-
int max_input_ = 0#
-
int min_output_ = 0#
-
int max_output_ = 0#
-
OperatorSetVersion since_version_ = kUninitializedSinceVersion#
-
bool deprecated_ = {}#
-
InferenceFunction tensor_inference_function_#
-
DataPropagationFunction data_propagation_function_#
-
std::map<int, std::shared_ptr<FunctionProto>> opset_version_to_function_body_#
-
std::map<int, ContextDependentFunctionBodyBuilder> opset_version_to_function_builder_#
-
NodeDeterminism node_determinism_ = NodeDeterminism::Unknown#
-
struct Attribute#
- #include <schema.h>
Public Functions
-
inline Attribute(std::string name_, std::string description_, AttributeProto::AttributeType type_, bool required_)#
-
inline Attribute(std::string name_, std::string description_, AttributeProto default_value_)#
-
inline Attribute(std::string name_, std::string description_, AttributeProto::AttributeType type_, bool required_)#
-
class FormalParameter#
- #include <schema.h>
Public Functions
-
FormalParameter() = default#
-
inline explicit FormalParameter(std::string name, DataTypeSet allowed_type_set, std::string type_str, std::string description, FormalParameterOption param_option = Single, bool is_homogeneous = true, int min_arity = 1, DifferentiationCategory differentiation_category = Unknown)#
-
inline explicit FormalParameter(std::string name, std::string description, std::string type_str, FormalParameterOption param_option = Single, bool is_homogeneous = true, int min_arity = 1, DifferentiationCategory differentiation_category = Unknown)#
-
const DataTypeSet &GetTypes() const#
-
FormalParameterOption GetOption() const#
-
bool GetIsHomogeneous() const#
-
int GetMinArity() const#
-
DifferentiationCategory GetDifferentiationCategory() const#
Private Functions
-
DataTypeSet &MutableTypes()#
Private Members
-
DataTypeSet type_set_#
-
FormalParameterOption param_option_ = {}#
-
bool is_homogeneous_ = {}#
-
int min_arity_ = {}#
-
DifferentiationCategory differentiation_category_ = {}#
Friends
- friend class OpSchema
-
FormalParameter() = default#
-
enum DifferentiationCategory#
-
class OpSchemaRegistry : public ONNX_LIGHT_NAMESPACE::ISchemaRegistry#
- #include <schema.h>
A registry to hold all the operator schemas.
Public Functions
-
virtual const OpSchema *GetSchema(const std::string &key, const int maxInclusiveVersion, const std::string &domain = ONNX_DOMAIN) const override#
Public Static Functions
-
static void OpSchemaDeregister(const std::string &op_type, const int version, const std::string &domain = ONNX_DOMAIN)#
-
static void OpSchemaDeregisterAll(const std::string &domain = ONNX_DOMAIN)#
-
static const OpSchema *Schema(const std::string &key, const std::string &domain = ONNX_DOMAIN)#
-
static const OpSchema *Schema(const std::string &key, const int maxInclusiveVersion, const std::string &domain = ONNX_DOMAIN)#
-
static const OpSchema *Schema(const utils::String &key, const int maxInclusiveVersion, const utils::String &domain)#
-
static OpSchemaRegistry *Instance()#
-
static void SetLoadedSchemaVersion(int target_version)#
-
static int GetLoadedSchemaVersion()#
Private Functions
-
OpSchemaRegistry() = default#
Private Static Functions
-
static OpName_Domain_Version_Schema_Map &GetMapWithoutEnsuringRegistration()#
Returns the underlying map from operator name, domain, and version to OpSchema.
You should not manually manipulate the map object returned. Instead, use the macros defined such as ONNX_OPERATOR_SET_SCHEMA to register your operator schema.
We wrap it inside a function to avoid the static initialization order fiasco.
Because of the change in function visibility, the GetMapWithoutEnsuringRegistration() and map() methods cannot be used to access the schema map directly from outside the OpSchemaRegistry class. The ONNX_API macro is therefore applied so that these methods remain accessible from other translation units, preserving backward compatibility.
-
static OpName_Domain_Version_Schema_Map &map()#
Private Static Attributes
-
static int loaded_schema_version#
-
class DomainToVersionRange#
- #include <schema.h>
Public Functions
-
DomainToVersionRange()#
Public Static Functions
-
static DomainToVersionRange &Instance()#
-
DomainToVersionRange()#
-
class OpSchemaRegisterOnce#
- #include <schema.h>
Public Functions
-
virtual const OpSchema *GetSchema(const std::string &key, const int maxInclusiveVersion, const std::string &domain = ONNX_DOMAIN) const override#
-
using FunctionBodyQueryFunction = std::function<bool(FunctionBodyBuildContext&)>#