# SPDX-License-Identifier: Apache-2.0
"""onnx checker
This implements functionalities that allow us to check whether a serialized
proto is legal.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import functools
from onnx import (ValueInfoProto,
AttributeProto,
TensorProto,
SparseTensorProto,
NodeProto,
ModelProto,
GraphProto,
IR_VERSION)
import onnx.onnx_cpp2py_export.checker as C
import onnx.defs
from google.protobuf.message import Message
from typing import TypeVar, Callable, Any, Type, cast, Union, Text
import onnx.shape_inference
import sys
# Upper bound, in bytes, on the size of a serialized model that check_model
# accepts in memory. A single protobuf file is limited to 2GB; larger models
# must be checked by passing their file path to check_model instead.
MAXIMUM_PROTOBUF = 2000000000
# TODO: This thing where we reserialize the protobuf back into the
# string, only to deserialize it at the call site, is really goofy.
# Stop doing that.
# NB: Please don't edit this context! It is the shared default `ctx`
# argument of every checker in this module, so mutating it would change
# the behavior of all calls that rely on the default.
DEFAULT_CONTEXT = C.CheckerContext()
DEFAULT_CONTEXT.ir_version = IR_VERSION
# Only the '' domain is registered by default, pinned to the opset version
# reported by onnx.defs.
# TODO: Maybe ONNX-ML should also be defaulted?
DEFAULT_CONTEXT.opset_imports = {'': onnx.defs.onnx_opset_version()}
# Type variable used by _create_checker so that a decorated function keeps
# its own static signature.
FuncType = TypeVar('FuncType', bound=Callable[..., Any])
# TODO: This really doesn't seem worth the metaprogramming...
def _create_checker(proto_type: Type[Message]) -> Callable[[FuncType], FuncType]:
    """Build a decorator that turns a stub into a type-checked native checker.

    The decorated function's own body is never executed. It is replaced by a
    wrapper that verifies the argument is an instance of ``proto_type``,
    serializes it, and forwards the bytes together with ``ctx`` to the
    function of the same name in the ``C`` extension module.

    Arguments:
        proto_type: protobuf message class the resulting checker accepts.

    Returns:
        A decorator that maps the stub to the wrapper described above. The
        wrapper raises ``RuntimeError`` when given an object that is not an
        instance of ``proto_type``.
    """
    def decorator(py_func: FuncType) -> FuncType:
        @functools.wraps(py_func)
        def checker(proto: Message, ctx: C.CheckerContext = DEFAULT_CONTEXT) -> Any:
            if isinstance(proto, proto_type):
                # The native function is looked up by the stub's name on
                # every call, before the proto is serialized.
                native_check = getattr(C, py_func.__name__)
                serialized = proto.SerializeToString()
                return native_check(serialized, ctx)
            template = 'You cannot pass an object that is not of type {}'
            raise RuntimeError(template.format(proto_type.__name__))
        return cast(FuncType, checker)
    return decorator
@_create_checker(ValueInfoProto)
def check_value_info(value_info: ValueInfoProto, ctx: C.CheckerContext = DEFAULT_CONTEXT) -> None:
    """Check a ValueInfoProto using the native checker.

    This body is a placeholder: the ``_create_checker`` decorator replaces it
    with a wrapper that raises ``RuntimeError`` unless ``value_info`` is a
    ``ValueInfoProto``, then passes its serialized form and ``ctx`` to
    ``C.check_value_info``.
    """
@_create_checker(TensorProto)
def check_tensor(tensor: TensorProto, ctx: C.CheckerContext = DEFAULT_CONTEXT) -> None:
    """Check a TensorProto using the native checker.

    This body is a placeholder: the ``_create_checker`` decorator replaces it
    with a wrapper that raises ``RuntimeError`` unless ``tensor`` is a
    ``TensorProto``, then passes its serialized form and ``ctx`` to
    ``C.check_tensor``.
    """
@_create_checker(AttributeProto)
def check_attribute(attr: AttributeProto, ctx: C.CheckerContext = DEFAULT_CONTEXT) -> None:
    """Check an AttributeProto using the native checker.

    This body is a placeholder: the ``_create_checker`` decorator replaces it
    with a wrapper that raises ``RuntimeError`` unless ``attr`` is an
    ``AttributeProto``, then passes its serialized form and ``ctx`` to
    ``C.check_attribute``.
    """
@_create_checker(NodeProto)
def check_node(node: NodeProto, ctx: C.CheckerContext = DEFAULT_CONTEXT) -> None:
    """Check a NodeProto using the native checker.

    This body is a placeholder: the ``_create_checker`` decorator replaces it
    with a wrapper that raises ``RuntimeError`` unless ``node`` is a
    ``NodeProto``, then passes its serialized form and ``ctx`` to
    ``C.check_node``.
    """
@_create_checker(GraphProto)
def check_graph(graph: GraphProto, ctx: C.CheckerContext = DEFAULT_CONTEXT) -> None:
    """Check a GraphProto using the native checker.

    This body is a placeholder: the ``_create_checker`` decorator replaces it
    with a wrapper that raises ``RuntimeError`` unless ``graph`` is a
    ``GraphProto``, then passes its serialized form and ``ctx`` to
    ``C.check_graph``.
    """
def check_sparse_tensor(sparse: SparseTensorProto, ctx: C.CheckerContext = DEFAULT_CONTEXT) -> None:
    """Check a SparseTensorProto using the native checker.

    Serializes ``sparse`` and passes the bytes and ``ctx`` to
    ``C.check_sparse_tensor``. Unlike the decorator-built checkers in this
    module, no ``isinstance`` validation is performed on ``sparse`` first.

    Arguments:
        sparse: the sparse tensor proto to check.
        ctx: checker context forwarded to the native checker.
    """
    native_check = C.check_sparse_tensor
    native_check(sparse.SerializeToString(), ctx)
def check_model(model: Union[ModelProto, Text, bytes], full_check: bool = False) -> None:
    """Check whether a model is legal.

    Arguments:
        model: the model to check. A ``str`` is treated as a path to a model
            file and checked via ``C.check_model_path``. ``bytes`` are
            treated as an already-serialized model. Anything else is expected
            to provide ``SerializeToString()`` (i.e. a ``ModelProto``).
        full_check: if True, additionally run shape inference with
            ``check_type=True`` and ``strict_mode=True`` on the model.

    Raises:
        ValueError: if an in-memory model serializes to more than
            ``MAXIMUM_PROTOBUF`` bytes. Pass the model path instead.
    """
    # A string is a path: the native code reads the file itself, so the
    # in-memory size limit below does not apply.
    if isinstance(model, str):
        C.check_model_path(model)
        if full_check:
            onnx.shape_inference.infer_shapes_path(model, check_type=True, strict_mode=True)
        return

    protobuf_string = model if isinstance(model, bytes) else model.SerializeToString()
    # If the protobuf is larger than 2GB, remind users to check via the
    # model path instead. Use len() for the payload size: sys.getsizeof()
    # reports the interpreter's memory footprint of the bytes object
    # (payload plus object header), which is implementation-specific.
    if len(protobuf_string) > MAXIMUM_PROTOBUF:
        raise ValueError('This protobuf of onnx model is too large (>2GB). Call check_model with model path instead.')
    C.check_model(protobuf_string)
    if full_check:
        onnx.shape_inference.infer_shapes(model, check_type=True, strict_mode=True)
# Re-export the extension module's ValidationError under this module's
# namespace so callers can refer to it without importing C directly.
ValidationError = C.ValidationError