# Source code for onnx.shape_inference

# SPDX-License-Identifier: Apache-2.0
"""onnx shape inference. Shape inference is not guaranteed to be
complete.

"""

import os
from typing import Dict, Optional, Union

import onnx
import onnx.onnx_cpp2py_export.shape_inference as C
from onnx import ModelProto


def infer_shapes(
    model: Union[ModelProto, bytes],
    check_type: bool = False,
    strict_mode: bool = False,
    data_prop: bool = False,
) -> ModelProto:
    """Apply shape inference to the provided ModelProto.

    Inferred shapes are added to the value_info field of the graph.

    If the inferred values conflict with values already provided in the
    graph, that means that the provided values are invalid (or there is a
    bug in shape inference), and the result is unspecified.

    The input ``model`` is not modified: it is serialized, processed by the
    C++ implementation, and a newly parsed ModelProto is returned.

    Arguments:
        model (Union[ModelProto, bytes]): The model to process, either as a
            ModelProto or as its serialized bytes.
        check_type (bool): Checks the type-equality for input and output.
        strict_mode (bool): Stricter shape inference; it will throw errors if
            any. Otherwise, inference simply stops on error.
        data_prop (bool): Enables data propagation for limited operators to
            perform shape computation.

    Returns:
        (ModelProto) model with inferred shape information

    Raises:
        TypeError: If ``model`` is neither a ModelProto nor bytes. For a
            ``str`` argument the message points to ``infer_shapes_path``.
    """
    if isinstance(model, (ModelProto, bytes)):
        # The C++ implementation consumes and produces serialized protobufs.
        model_str = model if isinstance(model, bytes) else model.SerializeToString()
        inferred_model_str = C.infer_shapes(
            model_str, check_type, strict_mode, data_prop
        )
        return onnx.load_from_string(inferred_model_str)
    if isinstance(model, str):
        # A string is most likely a file path; point the caller to the
        # path-based API instead.
        raise TypeError(
            "infer_shapes only accepts ModelProto or bytes, "
            "you can use infer_shapes_path for the model path (String)."
        )
    raise TypeError(
        "infer_shapes only accepts ModelProto or bytes, "
        f"incorrect type: {type(model)}"
    )
def infer_shapes_path(
    model_path: Union[str, os.PathLike],
    output_path: Union[str, os.PathLike] = "",
    check_type: bool = False,
    strict_mode: bool = False,
    data_prop: bool = False,
) -> None:
    """Run shape inference on a model given by path, like infer_shapes.

    Unlike infer_shapes, this takes a model path and supports models larger
    than 2GB. The inferred model is written directly to ``output_path``;
    nothing is returned.

    Arguments:
        model_path (Union[str, os.PathLike]): Path of the model to process.
        output_path (Union[str, os.PathLike]): Where to write the inferred
            model. Defaults to ``model_path``, i.e. the original model file
            is overwritten.
        check_type (bool): Same as in infer_shapes.
        strict_mode (bool): Same as in infer_shapes.
        data_prop (bool): Same as in infer_shapes.

    Raises:
        TypeError: If ``model_path`` is a ModelProto (use infer_shapes
            instead) or is not a string / os.PathLike path.
    """
    if isinstance(model_path, ModelProto):
        raise TypeError(
            "infer_shapes_path only accepts model Path (String), "
            "you can use infer_shapes for the ModelProto."
        )
    # Accept pathlib.Path and other os.PathLike objects; the C++ binding
    # receives plain string paths.
    if isinstance(model_path, os.PathLike):
        model_path = os.fspath(model_path)
    if not isinstance(model_path, str):
        raise TypeError(
            "infer_shapes_path only accepts model path (String), "
            f"incorrect type: {type(model_path)}"
        )
    if isinstance(output_path, os.PathLike):
        output_path = os.fspath(output_path)
    # If output_path is not defined, write back to the original model path.
    if output_path == "":
        output_path = model_path
    # Directly output the inferred model into the specified path, return nothing.
    C.infer_shapes_path(model_path, output_path, check_type, strict_mode, data_prop)
def infer_node_outputs(
    schema: onnx.defs.OpSchema,
    node: onnx.NodeProto,
    input_types: Dict[str, onnx.TypeProto],
    input_data: Optional[Dict[str, onnx.TensorProto]] = None,
    input_sparse_data: Optional[Dict[str, onnx.SparseTensorProto]] = None,
) -> Dict[str, onnx.TypeProto]:
    """Run the type and shape inference function of ``schema`` on one node.

    Arguments:
        schema: Operator schema whose inference function is invoked.
        node: The node whose outputs are inferred.
        input_types: TypeProto for each named input of ``node``, keyed by
            value name. Omitted optional inputs (empty names) need no entry.
        input_data: Optional TensorProto values for some inputs, keyed by
            value name. Only entries named in ``node.input`` are forwarded.
        input_sparse_data: Like ``input_data``, but SparseTensorProto values.

    Returns:
        Dict of inferred TypeProto objects as produced by the schema's
        inference function (presumably keyed by output name — verify against
        the C++ binding). Empty if ``schema`` has no inference function.

    Raises:
        KeyError: If a non-empty input name of ``node`` is missing from
            ``input_types``.
    """
    if not schema.has_type_and_shape_inference_function:  # type: ignore
        return {}
    if input_data is None:
        input_data = {}
    if input_sparse_data is None:
        input_sparse_data = {}

    # To avoid copying on C++ side, pass only what is needed for this inference
    # call. An empty input name marks an omitted optional input in ONNX; it has
    # no type, so skip it rather than raising KeyError (an explicit "" entry
    # supplied by the caller is still forwarded).
    passed_input_types = {
        key: input_types[key].SerializeToString()
        for key in node.input
        if key or key in input_types
    }
    passed_input_data = {
        key: input_data[key].SerializeToString()
        for key in node.input
        if key in input_data
    }
    passed_sparse_input_data = {
        key: input_sparse_data[key].SerializeToString()
        for key in node.input
        if key in input_sparse_data
    }

    outputs = schema._infer_node_outputs(
        node.SerializeToString(),
        passed_input_types,
        passed_input_data,
        passed_sparse_input_data,
    )
    return {key: onnx.TypeProto.FromString(out) for key, out in outputs.items()}


# Module-level alias for the C++ extension's InferenceError (presumably the
# error type raised when inference fails — verify against the binding).
InferenceError = C.InferenceError