# Source code for onnx.checker
# SPDX-License-Identifier: Apache-2.0
"""onnx checker
This module implements functionality that allows us to check whether a
serialized proto is legal.
"""
import functools
import inspect
import os
import sys
from typing import Any, Callable, Type, TypeVar, Union, cast

from google.protobuf.message import Message

import onnx.defs
import onnx.onnx_cpp2py_export.checker as C
import onnx.shape_inference
from onnx import (
    IR_VERSION,
    AttributeProto,
    GraphProto,
    ModelProto,
    NodeProto,
    SparseTensorProto,
    TensorProto,
    ValueInfoProto,
)
# A single serialized protobuf message is limited to 2GB, so in-memory models
# whose serialized size exceeds this many bytes (2 * 10**9) are rejected.
MAXIMUM_PROTOBUF = 2_000_000_000
# TODO: This thing where we reserialize the protobuf back into the
# string, only to deserialize it at the call site, is really goofy.
# Stop doing that.


def _make_default_context() -> C.CheckerContext:
    """Build the checker context used when a check_* helper gets no ``ctx``."""
    context = C.CheckerContext()
    context.ir_version = IR_VERSION
    # Register the opset version reported by onnx.defs for the "" domain.
    # TODO: Maybe ONNX-ML should also be defaulted?
    context.opset_imports = {"": onnx.defs.onnx_opset_version()}
    return context


# NB: Please don't edit this context! It is the shared default ``ctx`` of the
# check_* helpers in this module, so mutating it changes their behavior for
# every caller that relies on the default.
DEFAULT_CONTEXT = _make_default_context()
FuncType = TypeVar("FuncType", bound=Callable[..., Any])


# TODO: This really doesn't seem worth the metaprogramming...
def _create_checker(proto_type: Type[Message]) -> Callable[[FuncType], FuncType]:
    """Return a decorator that turns a stub into a call to the C++ checker.

    The decorated function's body is never executed. The returned wrapper:

    1. binds its arguments against the stub's own signature, so the proto can
       be passed positionally or by the stub's parameter name (for example
       ``check_node(node=n)``);
    2. raises RuntimeError unless the proto is an instance of ``proto_type``;
    3. serializes the proto and forwards it, together with ``ctx``, to the
       function of the same name in ``onnx.onnx_cpp2py_export.checker``.

    Arguments:
        proto_type: protobuf message class that the first argument must be an
            instance of.

    Returns:
        A decorator that produces the type-checked wrapper.
    """

    def decorator(py_func: FuncType) -> FuncType:
        signature = inspect.signature(py_func)
        # The stub's name for the proto argument (e.g. "node" for check_node).
        # functools.wraps makes inspect.signature() report the stub's
        # parameters, so the wrapper must actually accept them.
        proto_param = next(iter(signature.parameters))

        @functools.wraps(py_func)
        def checker(*args: Any, **kwargs: Any) -> Any:
            # Earlier versions of this wrapper named the first parameter
            # ``proto``; keep accepting that keyword for backward compatibility.
            if "proto" in kwargs and proto_param != "proto":
                if args or proto_param in kwargs:
                    raise TypeError(
                        f"{py_func.__name__}() got multiple values for argument '{proto_param}'"
                    )
                kwargs[proto_param] = kwargs.pop("proto")
            bound = signature.bind(*args, **kwargs)
            # Fill in ``ctx`` from the stub's default (DEFAULT_CONTEXT for
            # every stub in this module) when the caller omitted it.
            bound.apply_defaults()
            proto, ctx = bound.args
            if not isinstance(proto, proto_type):
                raise RuntimeError(
                    "You cannot pass an object that is not of type {}".format(
                        proto_type.__name__
                    )
                )
            return getattr(C, py_func.__name__)(proto.SerializeToString(), ctx)

        return cast(FuncType, checker)

    return decorator
@_create_checker(ValueInfoProto)
def check_value_info(
    value_info: ValueInfoProto,
    ctx: C.CheckerContext = DEFAULT_CONTEXT,
) -> None:
    """Validate a ValueInfoProto with the C++ checker.

    This body never runs. ``_create_checker`` replaces the function with a
    wrapper that raises RuntimeError for anything that is not a
    ValueInfoProto. Otherwise it passes the serialized proto and ``ctx`` to
    ``C.check_value_info``.
    """
@_create_checker(TensorProto)
def check_tensor(
    tensor: TensorProto,
    ctx: C.CheckerContext = DEFAULT_CONTEXT,
) -> None:
    """Validate a TensorProto with the C++ checker.

    This body never runs. ``_create_checker`` replaces the function with a
    wrapper that raises RuntimeError for anything that is not a TensorProto.
    Otherwise it passes the serialized proto and ``ctx`` to
    ``C.check_tensor``.
    """
@_create_checker(AttributeProto)
def check_attribute(
    attr: AttributeProto,
    ctx: C.CheckerContext = DEFAULT_CONTEXT,
) -> None:
    """Validate an AttributeProto with the C++ checker.

    This body never runs. ``_create_checker`` replaces the function with a
    wrapper that raises RuntimeError for anything that is not an
    AttributeProto. Otherwise it passes the serialized proto and ``ctx`` to
    ``C.check_attribute``.
    """
@_create_checker(NodeProto)
def check_node(
    node: NodeProto,
    ctx: C.CheckerContext = DEFAULT_CONTEXT,
) -> None:
    """Validate a NodeProto with the C++ checker.

    This body never runs. ``_create_checker`` replaces the function with a
    wrapper that raises RuntimeError for anything that is not a NodeProto.
    Otherwise it passes the serialized proto and ``ctx`` to ``C.check_node``.
    """
@_create_checker(GraphProto)
def check_graph(
    graph: GraphProto,
    ctx: C.CheckerContext = DEFAULT_CONTEXT,
) -> None:
    """Validate a GraphProto with the C++ checker.

    This body never runs. ``_create_checker`` replaces the function with a
    wrapper that raises RuntimeError for anything that is not a GraphProto.
    Otherwise it passes the serialized proto and ``ctx`` to ``C.check_graph``.
    """
def check_sparse_tensor(
    sparse: SparseTensorProto, ctx: C.CheckerContext = DEFAULT_CONTEXT
) -> None:
    """Validate a SparseTensorProto with the C++ checker.

    Arguments:
        sparse: the sparse tensor to check.
        ctx: checker context forwarded to the C++ checker; defaults to the
            shared DEFAULT_CONTEXT.

    Raises:
        RuntimeError: if ``sparse`` is not a SparseTensorProto. This is the
            same guard that the ``_create_checker``-based helpers apply.
    """
    # This helper is written out by hand rather than via _create_checker, so
    # apply the same type guard explicitly. A message of the wrong type is then
    # rejected before its bytes are handed to C.check_sparse_tensor.
    if not isinstance(sparse, SparseTensorProto):
        raise RuntimeError(
            "You cannot pass an object that is not of type {}".format(
                SparseTensorProto.__name__
            )
        )
    C.check_sparse_tensor(sparse.SerializeToString(), ctx)
def check_model(
    model: Union[ModelProto, str, bytes, os.PathLike], full_check: bool = False
) -> None:
    """Check the consistency of a model. An exception is raised if the test fails.

    Arguments:
        model (ModelProto, bytes, str or os.PathLike): model to check. It may be
            a ModelProto, its serialized bytes, or a path to a model file.
        full_check (bool): if True, the function checks shapes can be inferred

    Raises:
        ValueError: if an in-memory model serializes to more than
            MAXIMUM_PROTOBUF bytes; pass the model path instead in that case.
    """
    # A path (str or any os.PathLike) goes to the path-based C++ checker.
    # The size error below directs callers of very large models to this route.
    if isinstance(model, (str, os.PathLike)):
        C.check_model_path(os.fspath(model), full_check)
        return

    protobuf_string = model if isinstance(model, bytes) else model.SerializeToString()
    # Measure the payload itself: sys.getsizeof would also count the Python
    # object header and reject models that are just under the limit.
    if len(protobuf_string) > MAXIMUM_PROTOBUF:
        raise ValueError(
            "This protobuf of onnx model is too large (>2GB). Call check_model with model path instead."
        )
    C.check_model(protobuf_string, full_check)
# Re-export the C++ checker's ValidationError under onnx.checker. Callers can
# then refer to it without importing onnx.onnx_cpp2py_export.checker directly.
# Presumably this is what the C++ check_* functions raise on an invalid
# proto; confirm against the extension module.
ValidationError = C.ValidationError