# SPDX-License-Identifier: Apache-2.0
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import collections.abc # type: ignore
import numbers
import google.protobuf.message
from onnx import TensorProto, SparseTensorProto, AttributeProto, ValueInfoProto, \
TensorShapeProto, NodeProto, ModelProto, GraphProto, OperatorSetIdProto, \
TypeProto, SequenceProto, MapProto, IR_VERSION, TrainingInfoProto, OptionalProto, \
FunctionProto
from onnx import defs
from onnx import mapping
from onnx.mapping import STORAGE_TENSOR_TYPE_TO_FIELD
from typing import Text, Sequence, Any, Optional, Dict, Union, TypeVar, Callable, Tuple, List, cast
import numpy as np # type: ignore
import warnings
# One VERSION_TABLE row: (release version, IR version, ai.onnx opset version,
# ai.onnx.ml opset version) with an optional trailing ai.onnx.training opset version.
VersionRowType = Union[Tuple[Text, int, int, int], Tuple[Text, int, int, int, int]]
VersionTableType = List[VersionRowType]
# A list of (key, value) name pairs, as consumed by make_training_info.
AssignmentBindingType = List[Tuple[Text, Text]]
# This is a copy of the version table documented in
# https://github.com/onnx/onnx/blob/main/docs/Versioning.md#released-versions
# Both must be updated whenever a new version of ONNX is released.
VERSION_TABLE: VersionTableType = [
    # Release-version, IR version, ai.onnx version, ai.onnx.ml version, (optional) ai.onnx.training version
    ('1.0', 3, 1, 1),
    ('1.1', 3, 5, 1),
    ('1.1.2', 3, 6, 1),
    ('1.2', 3, 7, 1),
    ('1.3', 3, 8, 1),
    ('1.4.1', 4, 9, 1),
    ('1.5.0', 5, 10, 1),
    ('1.6.0', 6, 11, 2),
    ('1.7.0', 7, 12, 2, 1),
    ('1.8.0', 7, 13, 2, 1),
    ('1.8.1', 7, 13, 2, 1),
    ('1.9.0', 7, 14, 2, 1),
    ('1.10.0', 8, 15, 2, 1),
    ('1.10.1', 8, 15, 2, 1),
    ('1.10.2', 8, 15, 2, 1),
    ('1.11.0', 8, 16, 3, 1)
]
# Maps (opset domain, opset version) -> IR version; built from VERSION_TABLE
# by create_op_set_id_version_map.
VersionMapType = Dict[Tuple[Text, int], int]
# create a map from (opset-domain, opset-version) to ir-version from above table
def create_op_set_id_version_map(table: VersionTableType) -> VersionMapType:
    """Build a map from (opset domain, opset version) to IR version.

    Each row is (release, ir_version, ai.onnx, ai.onnx.ml[, ai.onnx.training]).
    Rows are visited in order and the first row mentioning a given
    (domain, version) pair determines its IR version; later rows never
    overwrite an existing entry.
    """
    domains = ('ai.onnx', 'ai.onnx.ml', 'ai.onnx.training')
    version_map: VersionMapType = {}
    for row in table:
        ir_version = row[1]
        # zip stops at the shorter side, so rows without a training opset
        # simply contribute no 'ai.onnx.training' entry.
        for domain_and_version in zip(domains, row[2:]):
            version_map.setdefault(domain_and_version, ir_version)
    return version_map


OP_SET_ID_VERSION_MAP = create_op_set_id_version_map(VERSION_TABLE)
# Given list of opset ids, determine minimum IR version required
def find_min_ir_version_for(opsetidlist: List[OperatorSetIdProto]) -> int:
    """Return the minimum IR version required by a list of opset imports.

    Arguments:
        opsetidlist: opset ids (domain, version) imported by a model. An empty
            domain is treated as the default 'ai.onnx' domain.

    Returns:
        The largest IR version required by any of the given opsets, or 3
        when the list is empty.

    Raises:
        ValueError: if a (domain, version) pair is not present in
            OP_SET_ID_VERSION_MAP (i.e. not listed in VERSION_TABLE).
    """
    default_min_version = 3

    def find_min(domain: Union[Text, None], version: int) -> int:
        key = (domain if domain else 'ai.onnx', version)
        if key in OP_SET_ID_VERSION_MAP:
            return OP_SET_ID_VERSION_MAP[key]
        # Name the offending opset so the caller can tell which import is unknown.
        raise ValueError(
            "Unsupported opset-version: domain {!r}, version {}.".format(key[0], version))

    if opsetidlist:
        return max(find_min(x.domain, x.version) for x in opsetidlist)
    return default_min_version  # if no opsets specified
def make_node(
    op_type: Text,
    inputs: Sequence[Text],
    outputs: Sequence[Text],
    name: Optional[Text] = None,
    doc_string: Optional[Text] = None,
    domain: Optional[Text] = None,
    **kwargs: Any
) -> NodeProto:
    """Construct a NodeProto.

    Arguments:
        op_type (string): The name of the operator to construct
        inputs (list of string): list of input names
        outputs (list of string): list of output names
        name (string, default None): optional unique identifier for NodeProto;
            an empty name is not stored
        doc_string (string, default None): optional documentation string for
            NodeProto; an empty string is not stored
        domain (string, default None): optional domain for NodeProto.
            If it's None, the default domain (which is empty) is kept
        **kwargs (dict): the attributes of the node, added in sorted key
            order; keywords whose value is None are skipped. The acceptable
            values are documented in :func:`make_attribute`.
    """
    node = NodeProto()
    node.op_type = op_type
    node.input.extend(inputs)
    node.output.extend(outputs)

    # Falsy values (None or '') leave these fields unset.
    for field_name, field_value in (('name', name), ('doc_string', doc_string)):
        if field_value:
            setattr(node, field_name, field_value)
    # Unlike name/doc_string, an explicitly empty domain is still assigned.
    if domain is not None:
        node.domain = domain

    if kwargs:
        node.attribute.extend(
            make_attribute(attr_name, attr_value)
            for attr_name, attr_value in sorted(kwargs.items())
            if attr_value is not None
        )
    return node
def make_operatorsetid(
    domain: Text,
    version: int,
) -> OperatorSetIdProto:
    """Construct an OperatorSetIdProto.

    Arguments:
        domain (string): The domain of the operator set id
        version (integer): Version of operator set id

    Returns:
        A new OperatorSetIdProto with both fields assigned.
    """
    opset_id = OperatorSetIdProto()
    # Assigned in the same order as the parameters: domain, then version.
    opset_id.domain = domain
    opset_id.version = version
    return opset_id
def make_graph(
    nodes: Sequence[NodeProto],
    name: Text,
    inputs: Sequence[ValueInfoProto],
    outputs: Sequence[ValueInfoProto],
    initializer: Optional[Sequence[TensorProto]] = None,
    doc_string: Optional[Text] = None,
    value_info: Optional[Sequence[ValueInfoProto]] = None,
    sparse_initializer: Optional[Sequence[SparseTensorProto]] = None,
) -> GraphProto:
    """Construct a GraphProto.

    Arguments:
        nodes: the graph's nodes
        name (string): graph name
        inputs: graph inputs
        outputs: graph outputs
        initializer: dense initializers; None means none
        doc_string (string): optional documentation; empty/None is not stored
        value_info: value infos for intermediate values; None means none
        sparse_initializer: sparse initializers; None means none

    Returns:
        The new GraphProto.
    """
    # None defaults (instead of a shared mutable [] default) mean "no entries".
    if initializer is None:
        initializer = []
    if sparse_initializer is None:
        sparse_initializer = []
    if value_info is None:
        value_info = []

    graph = GraphProto()
    graph.node.extend(nodes)
    graph.name = name
    graph.input.extend(inputs)
    graph.output.extend(outputs)
    graph.initializer.extend(initializer)
    graph.sparse_initializer.extend(sparse_initializer)
    graph.value_info.extend(value_info)
    if doc_string:
        graph.doc_string = doc_string
    return graph
def make_opsetid(domain: Text, version: int) -> OperatorSetIdProto:
    """Construct an OperatorSetIdProto for ``domain`` at ``version``.

    Identical to make_operatorsetid, which does the actual work.
    """
    return make_operatorsetid(domain, version)
def make_function(
    domain: Text,
    fname: Text,
    inputs: Sequence[Text],
    outputs: Sequence[Text],
    nodes: Sequence[NodeProto],
    opset_imports: Sequence[OperatorSetIdProto],
    attributes: Optional[Sequence[Text]] = None,
    doc_string: Optional[Text] = None
) -> FunctionProto:
    """Construct a FunctionProto.

    Arguments:
        domain (string): domain of the function
        fname (string): name of the function
        inputs (list of string): input names
        outputs (list of string): output names
        nodes: nodes forming the function body
        opset_imports: opset imports of the function
        attributes (list of string): attribute names; None means no attributes
        doc_string (string): optional documentation; empty/None is not stored

    Returns:
        The new FunctionProto.
    """
    f = FunctionProto()
    f.domain = domain
    f.name = fname
    f.input.extend(inputs)
    f.output.extend(outputs)
    f.node.extend(nodes)
    f.opset_import.extend(opset_imports)
    # The annotation allows None; extend(None) would raise TypeError.
    if attributes is not None:
        f.attribute.extend(attributes)
    if doc_string:
        f.doc_string = doc_string
    return f
def make_model(graph: GraphProto, **kwargs: Any) -> ModelProto:
    """Construct a ModelProto that wraps a copy of ``graph``.

    Keyword arguments:
        opset_imports: opset imports for the model. When omitted, a single
            import is added whose version is defs.onnx_opset_version() and
            whose domain is left unset.
        functions: FunctionProtos to add to the model.
        Any other keyword is assigned with setattr on the model (so e.g.
        ``ir_version=...`` overrides the default set below).
    """
    model = ModelProto()
    # Touch model.ir_version so it is stored as the version from which it is
    # generated.
    model.ir_version = IR_VERSION
    model.graph.CopyFrom(graph)

    opset_imports: Optional[Sequence[OperatorSetIdProto]] = kwargs.pop('opset_imports', None)
    if opset_imports is None:
        # Default import
        default_import = model.opset_import.add()
        default_import.version = defs.onnx_opset_version()
    else:
        model.opset_import.extend(opset_imports)

    functions: Optional[Sequence[FunctionProto]] = kwargs.pop('functions', None)
    if functions is not None:
        model.functions.extend(functions)

    for field_name, field_value in kwargs.items():
        # TODO: Does this work with repeated fields?
        setattr(model, field_name, field_value)
    return model
# An extension of make_model that infers an IR_VERSION for the model,
# if not specified, using a best-effort-basis.
def make_model_gen_version(graph: GraphProto, **kwargs: Any) -> ModelProto:
    """Like make_model, but fills in ``ir_version`` when it is not given.

    The inferred value is find_min_ir_version_for(opset_imports), using an
    empty import list when ``opset_imports`` is not passed either.
    """
    if 'ir_version' not in kwargs:
        kwargs['ir_version'] = find_min_ir_version_for(kwargs.get('opset_imports', []))
    return make_model(graph, **kwargs)
def set_model_props(model: ModelProto, dict_value: Dict[Text, Text]) -> None:
    """Replace the model's metadata_props with the entries of ``dict_value``.

    Existing entries are removed first; the new entries follow the dict's
    iteration order.
    """
    del model.metadata_props[:]
    for prop_key, prop_value in dict_value.items():
        prop = model.metadata_props.add()
        prop.key = prop_key
        prop.value = prop_value
def split_complex_to_pairs(ca: Sequence[np.complex64]) -> Sequence[float]:
    """Flatten complex numbers into ``[re0, im0, re1, im1, ...]``.

    make_tensor uses this interleaved layout to store COMPLEX64/COMPLEX128
    values. The parts are real numbers, so the result holds floats (the
    previous ``Sequence[int]`` annotation was wrong).
    """
    return [part for c in ca for part in (c.real, c.imag)]
def _float32_to_bfloat16_bits(vals: Any) -> List[int]:
    """Convert float values to bfloat16 bit patterns, returned as Python ints.

    bfloat16 keeps the upper 16 bits of an IEEE-754 float32. The dropped lower
    half is rounded to nearest, ties to even. NaN inputs map to the quiet NaN
    0x7FC0 so that rounding cannot turn a NaN into infinity.
    """
    f32 = np.asarray(vals, dtype=np.float32).ravel()
    # Widen to int64 so adding the rounding bias can never overflow.
    bits = f32.view(np.uint32).astype(np.int64)
    rounding_bias = ((bits >> 16) & 1) + 0x7FFF
    bf16 = (bits + rounding_bias) >> 16
    bf16 = np.where(np.isnan(f32), 0x7FC0, bf16)
    return bf16.tolist()


def make_tensor(
    name: Text,
    data_type: int,
    dims: Sequence[int],
    vals: Any,
    raw: bool = False
) -> TensorProto:
    '''
    Make a TensorProto with specified arguments. If raw is False, this
    function will choose the corresponding proto field to store the
    values based on data_type. If raw is True, use "raw_data" proto
    field to store the values, and values should be of type bytes in
    this case.

    Arguments:
        name: tensor name
        data_type: a TensorProto.DataType value
        dims: tensor shape
        vals: the values. A numpy array of rank > 1 is flattened first.
            With raw=True this must be the raw bytes.
        raw: store ``vals`` in ``raw_data`` instead of a typed field

    Raises:
        ValueError: if the number of values (or bytes, when raw) does not
            match the tensor size implied by ``dims``.
    '''
    tensor = TensorProto()
    tensor.data_type = data_type
    tensor.name = name

    if data_type == TensorProto.STRING:
        assert not raw, "Can not use raw_data to store string type"

    # Check number of vals specified equals tensor size. In raw mode each
    # element occupies itemsize bytes of the raw buffer.
    expected_size = 1 if (not raw) else (mapping.TENSOR_TYPE_TO_NP_TYPE[data_type].itemsize)
    # Flatten a numpy array if its rank > 1
    if type(vals) is np.ndarray and len(vals.shape) > 1:
        vals = vals.flatten()
    for d in dims:
        expected_size = expected_size * d

    if len(vals) != expected_size:
        raise ValueError("Number of values does not match tensor's size. Expected {}, but it is {}. "
                         .format(expected_size, len(vals)))

    if raw:
        tensor.raw_data = vals
    else:
        if (data_type == TensorProto.COMPLEX64
                or data_type == TensorProto.COMPLEX128):
            vals = split_complex_to_pairs(vals)
        elif data_type == TensorProto.FLOAT16:
            # float16 is stored as its uint16 bit pattern.
            vals = np.array(vals).astype(np.float16).view(dtype=np.uint16).flatten().tolist()
        elif data_type == TensorProto.BFLOAT16:
            # bfloat16 is also stored as uint16, but its bit layout differs
            # from float16 (8 exponent bits instead of 5), so it cannot reuse
            # the float16 cast above.
            vals = _float32_to_bfloat16_bits(vals)
        field = mapping.STORAGE_TENSOR_TYPE_TO_FIELD[
            mapping.TENSOR_TYPE_TO_STORAGE_TENSOR_TYPE[data_type]]
        getattr(tensor, field).extend(vals)
    tensor.dims.extend(dims)
    return tensor
def make_sparse_tensor(
    values: TensorProto,
    indices: TensorProto,
    dims: Sequence[int]
) -> SparseTensorProto:
    """Construct a SparseTensorProto.

    Copies ``values`` and ``indices`` into the new proto and records
    ``dims`` in its ``dims`` field.
    """
    sparse_tensor = SparseTensorProto()
    sparse_tensor.values.CopyFrom(values)
    sparse_tensor.indices.CopyFrom(indices)
    sparse_tensor.dims.extend(dims)
    return sparse_tensor
def make_sequence(
    name: Text,
    elem_type: SequenceProto.DataType,
    values: Sequence[Any],
) -> SequenceProto:
    '''
    Make a SequenceProto named ``name`` holding ``values``.

    The repeated field that receives ``values`` is selected from
    ``elem_type`` via mapping.STORAGE_ELEMENT_TYPE_TO_FIELD.
    '''
    sequence = SequenceProto()
    sequence.name = name
    sequence.elem_type = elem_type
    target = getattr(sequence, mapping.STORAGE_ELEMENT_TYPE_TO_FIELD[elem_type])
    target.extend(values)
    return sequence
def make_map(
    name: Text,
    key_type: int,
    keys: List[Any],
    values: SequenceProto
) -> MapProto:
    '''
    Make a Map with specified key-value pair arguments.
    Criteria for conversion:
    - Keys and Values must have the same number of elements
    - Every key in keys must be of the same type
    - Every value in values must be of the same type

    These criteria are the caller's responsibility; only ``key_type`` is
    checked here. STRING keys are stored in ``string_keys``, integer key
    types in ``keys``, and ``values`` is copied into the map.

    Raises:
        ValueError: if ``key_type`` is neither STRING nor an integer type.
    '''
    map_proto = MapProto()
    valid_key_int_types = [TensorProto.INT8, TensorProto.INT16, TensorProto.INT32,
                           TensorProto.INT64, TensorProto.UINT8, TensorProto.UINT16,
                           TensorProto.UINT32, TensorProto.UINT64]
    map_proto.name = name
    map_proto.key_type = key_type
    if key_type == TensorProto.STRING:
        map_proto.string_keys.extend(keys)
    elif key_type in valid_key_int_types:
        map_proto.keys.extend(keys)
    else:
        # There is no field that could hold such keys; refuse instead of
        # silently returning a map whose keys were dropped.
        raise ValueError(
            "Unsupported map key_type {}: expected STRING or an integer type.".format(key_type))
    map_proto.values.CopyFrom(values)
    return map_proto
def make_optional(
    name: Text,
    elem_type: OptionalProto.DataType,
    value: Optional[Any],
) -> OptionalProto:
    '''
    Make an OptionalProto named ``name`` with element type ``elem_type``.

    When ``elem_type`` is 0 nothing else is stored and ``value`` is ignored
    (0 presumably being the UNDEFINED element type; confirm against
    onnx.proto). Otherwise ``value`` is copied into the field that
    mapping.OPTIONAL_ELEMENT_TYPE_TO_FIELD selects for ``elem_type``.
    '''
    optional = OptionalProto()
    optional.name = name
    optional.elem_type = elem_type
    if elem_type == 0:
        return optional
    target_field = mapping.OPTIONAL_ELEMENT_TYPE_TO_FIELD[elem_type]
    getattr(optional, target_field).CopyFrom(value)
    return optional
def _to_bytes_or_false(val: Union[Text, bytes]) -> Union[bytes, bool]:
"""An internal graph to convert the input to a bytes or to False.
The criteria for conversion is as follows and should be python 2 and 3
compatible:
- If val is py2 str or py3 bytes: return bytes
- If val is py2 unicode or py3 str: return val.decode('utf-8')
- Otherwise, return False
"""
if isinstance(val, bytes):
return val
try:
return val.encode('utf-8')
except AttributeError:
return False
def make_attribute(
    key: Text,
    value: Any,
    doc_string: Optional[Text] = None
) -> AttributeProto:
    """Makes an AttributeProto based on the value type.

    Arguments:
        key: attribute name
        value: attribute value; its Python type selects the attribute type:
            - integral number (int, bool, numpy integer) -> INT
            - other real number (float, numpy floating scalar) -> FLOAT
            - str or bytes -> STRING (str is UTF-8 encoded)
            - TensorProto / SparseTensorProto / GraphProto / TypeProto ->
              TENSOR / SPARSE_TENSOR / GRAPH / TYPE_PROTO
            - any other iterable (including one-shot iterators) -> INTS,
              FLOATS (ints and floats may be mixed), STRINGS, TENSORS,
              SPARSE_TENSORS, GRAPHS or TYPE_PROTOS, tried in that order;
              an empty iterable becomes an empty INTS
        doc_string: optional documentation string; empty/None is not stored

    Raises:
        ValueError: iterable whose element type cannot be determined.
        TypeError: value of an unsupported type.
    """
    attr = AttributeProto()
    attr.name = key
    if doc_string:
        attr.doc_string = doc_string

    bytes_or_false = _to_bytes_or_false(value)
    # First, singular cases
    # integer. Checked before Real because every Integral is also Real.
    if isinstance(value, numbers.Integral):
        attr.i = int(value)
        attr.type = AttributeProto.INT
    # float. numbers.Real also covers numpy floating scalars such as
    # np.float32, which do not subclass the built-in float.
    elif isinstance(value, numbers.Real):
        attr.f = float(value)
        attr.type = AttributeProto.FLOAT
    # string
    elif bytes_or_false is not False:
        assert isinstance(bytes_or_false, bytes)
        attr.s = bytes_or_false
        attr.type = AttributeProto.STRING
    # second, proto cases
    elif isinstance(value, TensorProto):
        attr.t.CopyFrom(value)
        attr.type = AttributeProto.TENSOR
    elif isinstance(value, SparseTensorProto):
        attr.sparse_tensor.CopyFrom(value)
        attr.type = AttributeProto.SPARSE_TENSOR
    elif isinstance(value, GraphProto):
        attr.g.CopyFrom(value)
        attr.type = AttributeProto.GRAPH
    elif isinstance(value, TypeProto):
        attr.tp.CopyFrom(value)
        attr.type = AttributeProto.TYPE_PROTO
    # third, iterable cases
    elif isinstance(value, collections.abc.Iterable):
        # Materialize once: a one-shot iterator (generator, map object, ...)
        # would otherwise be exhausted by the first pass below and then be
        # misclassified as an empty INTS list.
        items = list(value)
        byte_array = [_to_bytes_or_false(v) for v in items]
        if all(isinstance(v, numbers.Integral) for v in items):
            # Turn np.int32/64 into Python built-in int.
            attr.ints.extend(int(v) for v in items)
            attr.type = AttributeProto.INTS
        elif all(isinstance(v, numbers.Real) for v in items):
            # Since ints and floats are members of Real, this allows a mix of ints and floats
            # (and converts the ints to floats).
            attr.floats.extend(float(v) for v in items)
            attr.type = AttributeProto.FLOATS
        elif all(b is not False for b in byte_array):
            attr.strings.extend(cast(List[bytes], byte_array))
            attr.type = AttributeProto.STRINGS
        elif all(isinstance(v, TensorProto) for v in items):
            attr.tensors.extend(items)
            attr.type = AttributeProto.TENSORS
        elif all(isinstance(v, SparseTensorProto) for v in items):
            attr.sparse_tensors.extend(items)
            attr.type = AttributeProto.SPARSE_TENSORS
        elif all(isinstance(v, GraphProto) for v in items):
            attr.graphs.extend(items)
            attr.type = AttributeProto.GRAPHS
        elif all(isinstance(tp, TypeProto) for tp in items):
            attr.type_protos.extend(items)
            attr.type = AttributeProto.TYPE_PROTOS
        else:
            raise ValueError(
                "You passed in an iterable attribute but I cannot figure out "
                "its applicable type.")
    else:
        raise TypeError(
            'value "{}" is not valid attribute data type.'.format(value))
    return attr
def get_attribute_value(attr: AttributeProto) -> Any:
    """Return the value stored in ``attr``, selected by ``attr.type``.

    Single-valued types return the proto field itself; repeated types are
    copied into a plain Python list.

    Raises:
        ValueError: if ``attr.type`` is not one of the handled types.
    """
    single_fields = {
        AttributeProto.FLOAT: 'f',
        AttributeProto.INT: 'i',
        AttributeProto.STRING: 's',
        AttributeProto.TENSOR: 't',
        AttributeProto.SPARSE_TENSOR: 'sparse_tensor',
        AttributeProto.GRAPH: 'g',
        AttributeProto.TYPE_PROTO: 'tp',
    }
    repeated_fields = {
        AttributeProto.FLOATS: 'floats',
        AttributeProto.INTS: 'ints',
        AttributeProto.STRINGS: 'strings',
        AttributeProto.TENSORS: 'tensors',
        AttributeProto.SPARSE_TENSORS: 'sparse_tensors',
        AttributeProto.GRAPHS: 'graphs',
        AttributeProto.TYPE_PROTOS: 'type_protos',
    }
    if attr.type in single_fields:
        return getattr(attr, single_fields[attr.type])
    if attr.type in repeated_fields:
        return list(getattr(attr, repeated_fields[attr.type]))
    raise ValueError("Unsupported ONNX attribute: {}".format(attr))
def make_empty_tensor_value_info(name: Text) -> ValueInfoProto:
    """Return a ValueInfoProto that carries only ``name`` (no type is set)."""
    info = ValueInfoProto()
    info.name = name
    return info
def make_tensor_type_proto(
    elem_type: int,
    shape: Optional[Sequence[Union[Text, int, None]]],
    shape_denotation: Optional[List[Text]] = None,
) -> TypeProto:
    """Makes a Tensor TypeProto based on the data type and shape.

    Arguments:
        elem_type: element data type (a TensorProto.DataType value)
        shape: None leaves the shape unset. Otherwise one entry per
            dimension: an integer (Python int or numpy integer such as
            np.int64) sets ``dim_value``, a str sets ``dim_param``, and None
            leaves that dimension unspecified.
        shape_denotation: optional per-dimension denotations; must have the
            same length as ``shape``.

    Raises:
        ValueError: on a shape_denotation length mismatch or an unsupported
            shape entry.
    """
    type_proto = TypeProto()
    tensor_type_proto = type_proto.tensor_type
    tensor_type_proto.elem_type = elem_type
    tensor_shape_proto = tensor_type_proto.shape

    if shape is not None:
        # You might think this is a no-op (extending a normal Python
        # list by [] certainly is), but protobuf lists work a little
        # differently; if a field is never set, it is omitted from the
        # resulting protobuf; a list that is explicitly set to be
        # empty will get an (empty) entry in the protobuf. This
        # difference is visible to our consumers, so make sure we emit
        # an empty shape!
        tensor_shape_proto.dim.extend([])

        if shape_denotation:
            if len(shape_denotation) != len(shape):
                raise ValueError(
                    'Invalid shape_denotation. '
                    'Must be of the same length as shape.')

        for i, d in enumerate(shape):
            dim = tensor_shape_proto.dim.add()
            if d is None:
                pass
            elif isinstance(d, numbers.Integral):
                # numbers.Integral also accepts numpy integer scalars, which
                # are not int subclasses; store a plain Python int.
                dim.dim_value = int(d)
            elif isinstance(d, str):
                dim.dim_param = d
            else:
                raise ValueError(
                    'Invalid item in shape: {}. '
                    'Needs to be of int or str.'.format(d))

            if shape_denotation:
                dim.denotation = shape_denotation[i]

    return type_proto
def make_tensor_value_info(
    name: Text,
    elem_type: int,
    shape: Optional[Sequence[Union[Text, int, None]]],
    doc_string: Text = "",
    shape_denotation: Optional[List[Text]] = None,
) -> ValueInfoProto:
    """Makes a ValueInfoProto based on the data type and shape.

    The tensor type is built by make_tensor_type_proto (see there for the
    accepted shape entries); an empty doc_string is not stored.
    """
    info = ValueInfoProto()
    info.name = name
    if doc_string:
        info.doc_string = doc_string
    tensor_type = make_tensor_type_proto(elem_type, shape, shape_denotation)
    info.type.CopyFrom(tensor_type)
    return info
def make_sparse_tensor_type_proto(
    elem_type: int,
    shape: Optional[Sequence[Union[Text, int, None]]],
    shape_denotation: Optional[List[Text]] = None,
) -> TypeProto:
    """Makes a SparseTensor TypeProto based on the data type and shape.

    Arguments:
        elem_type: element data type (a TensorProto.DataType value)
        shape: None leaves the shape unset. Otherwise one entry per
            dimension: an integer (Python int or numpy integer such as
            np.int64) sets ``dim_value``, a str sets ``dim_param``, and None
            leaves that dimension unspecified.
        shape_denotation: optional per-dimension denotations; must have the
            same length as ``shape``.

    Raises:
        ValueError: on a shape_denotation length mismatch or an unsupported
            shape entry.
    """
    type_proto = TypeProto()
    sparse_tensor_type_proto = type_proto.sparse_tensor_type
    sparse_tensor_type_proto.elem_type = elem_type
    sparse_tensor_shape_proto = sparse_tensor_type_proto.shape

    if shape is not None:
        # You might think this is a no-op (extending a normal Python
        # list by [] certainly is), but protobuf lists work a little
        # differently; if a field is never set, it is omitted from the
        # resulting protobuf; a list that is explicitly set to be
        # empty will get an (empty) entry in the protobuf. This
        # difference is visible to our consumers, so make sure we emit
        # an empty shape!
        sparse_tensor_shape_proto.dim.extend([])

        if shape_denotation:
            if len(shape_denotation) != len(shape):
                raise ValueError(
                    'Invalid shape_denotation. '
                    'Must be of the same length as shape.')

        for i, d in enumerate(shape):
            dim = sparse_tensor_shape_proto.dim.add()
            if d is None:
                pass
            elif isinstance(d, numbers.Integral):
                # numbers.Integral also accepts numpy integer scalars, which
                # are not int subclasses; store a plain Python int.
                dim.dim_value = int(d)
            elif isinstance(d, str):
                dim.dim_param = d
            else:
                raise ValueError(
                    'Invalid item in shape: {}. '
                    'Needs to be of int or text.'.format(d))

            if shape_denotation:
                dim.denotation = shape_denotation[i]

    return type_proto
def make_sparse_tensor_value_info(
    name: Text,
    elem_type: int,
    shape: Optional[Sequence[Union[Text, int, None]]],
    doc_string: Text = "",
    shape_denotation: Optional[List[Text]] = None,
) -> ValueInfoProto:
    """Makes a SparseTensor ValueInfoProto based on the data type and shape.

    Only the ``sparse_tensor_type`` part of the type built by
    make_sparse_tensor_type_proto is copied; an empty doc_string is not stored.
    """
    info = ValueInfoProto()
    info.name = name
    if doc_string:
        info.doc_string = doc_string
    built = make_sparse_tensor_type_proto(elem_type, shape, shape_denotation)
    info.type.sparse_tensor_type.CopyFrom(built.sparse_tensor_type)
    return info
def make_sequence_type_proto(
    inner_type_proto: TypeProto,
) -> TypeProto:
    """Makes a sequence TypeProto whose element type is a copy of ``inner_type_proto``."""
    sequence_proto = TypeProto()
    sequence_proto.sequence_type.elem_type.CopyFrom(inner_type_proto)
    return sequence_proto
def make_optional_type_proto(
    inner_type_proto: TypeProto,
) -> TypeProto:
    """Makes an optional TypeProto whose element type is a copy of ``inner_type_proto``."""
    optional_proto = TypeProto()
    optional_proto.optional_type.elem_type.CopyFrom(inner_type_proto)
    return optional_proto
def make_value_info(
    name: Text,
    type_proto: TypeProto,
    doc_string: Text = "",
) -> ValueInfoProto:
    """Makes a ValueInfoProto named ``name`` whose type is a copy of ``type_proto``.

    An empty doc_string is not stored.
    """
    info = ValueInfoProto()
    info.name = name
    if doc_string:
        info.doc_string = doc_string
    info.type.CopyFrom(type_proto)
    return info
def _sanitize_str(s: Union[Text, bytes]) -> Text:
if isinstance(s, str):
sanitized = s
elif isinstance(s, bytes):
sanitized = s.decode('utf-8', errors='ignore')
else:
sanitized = str(s)
if len(sanitized) < 64:
return sanitized
return sanitized[:64] + '...<+len=%d>' % (len(sanitized) - 64)
def make_tensor_sequence_value_info(
    name: Text,
    elem_type: int,
    shape: Optional[Sequence[Union[Text, int, None]]],
    doc_string: Text = "",
    elem_shape_denotation: Optional[List[Text]] = None,
) -> ValueInfoProto:
    """Makes a Sequence[Tensors] ValueInfoProto based on the data type and shape.

    ``elem_type``/``shape``/``elem_shape_denotation`` describe the tensor
    element type; an empty doc_string is not stored.
    """
    info = ValueInfoProto()
    info.name = name
    if doc_string:
        info.doc_string = doc_string
    element_type = make_tensor_type_proto(elem_type, shape, elem_shape_denotation)
    wrapped = make_sequence_type_proto(element_type)
    info.type.sequence_type.CopyFrom(wrapped.sequence_type)
    return info
def printable_attribute(attr: AttributeProto, subgraphs: bool = False) -> Union[Text, Tuple[Text, List[GraphProto]]]:
    """Render an attribute as ``name = value`` (used by printable_node).

    Arguments:
        attr: the attribute to render
        subgraphs: if True, also return the GraphProto values referenced by a
            GRAPH/GRAPHS attribute so the caller can print them afterwards

    Returns:
        The rendered string, or ``(string, graphs)`` when ``subgraphs`` is True.
    """
    content = []
    content.append(attr.name)
    content.append("=")

    def str_float(f: float) -> Text:
        # NB: Different Python versions print different numbers of trailing
        # decimals, specifying this explicitly keeps it consistent for all
        # versions
        return '{:.15g}'.format(f)

    def str_int(i: int) -> Text:
        # Explicit integer formatting keeps the output independent of repr().
        return '{:d}'.format(i)

    _T = TypeVar('_T')  # noqa

    def str_list(str_elem: Callable[[_T], Text], xs: Sequence[_T]) -> Text:
        return '[' + ', '.join(map(str_elem, xs)) + ']'

    # The populated field is found via field presence (HasField on scalar
    # fields), which relies on proto2 semantics; under proto3 attr.type
    # would have to be used instead.
    # To support printing subgraphs, if we find a graph attribute, print out
    # its name here and pass the graph itself up to the caller for later
    # printing.
    graphs = []
    if attr.HasField("f"):
        content.append(str_float(attr.f))
    elif attr.HasField("i"):
        content.append(str_int(attr.i))
    elif attr.HasField("s"):
        # TODO: Bit nervous about Python 2 / Python 3 determinism implications
        content.append(repr(_sanitize_str(attr.s)))
    elif attr.HasField("t"):
        if len(attr.t.dims) > 0:
            content.append("<Tensor>")
        else:
            # special case to print scalars. STORAGE_TENSOR_TYPE_TO_FIELD is
            # keyed by *storage* type, so map the data type to its storage
            # type first (the same two-step lookup make_tensor uses); using
            # data_type directly fails for types stored in a wider field.
            storage_type = mapping.TENSOR_TYPE_TO_STORAGE_TENSOR_TYPE[attr.t.data_type]
            field = STORAGE_TENSOR_TYPE_TO_FIELD[storage_type]
            content.append('<Scalar Tensor {}>'.format(str(getattr(attr.t, field))))
    elif attr.HasField("sparse_tensor"):
        content.append("<Sparse Tensor>")
    elif attr.HasField("g"):
        content.append("<graph {}>".format(attr.g.name))
        graphs.append(attr.g)
    elif attr.HasField("tp"):
        content.append("<Type Proto {}>".format(attr.tp))
    elif attr.floats:
        content.append(str_list(str_float, attr.floats))
    elif attr.ints:
        content.append(str_list(str_int, attr.ints))
    elif attr.strings:
        # TODO: Bit nervous about Python 2 / Python 3 determinism implications
        content.append(str(list(map(_sanitize_str, attr.strings))))
    elif attr.tensors:
        content.append("[<Tensor>, ...]")
    elif attr.sparse_tensors:
        content.append("[<Sparse Tensor>, ...]")
    elif attr.type_protos:
        content.append('[')
        for i, tp in enumerate(attr.type_protos):
            comma = ',' if i != len(attr.type_protos) - 1 else ''
            content.append('<Type Proto {}>{}'.format(tp, comma))
        content.append(']')
    elif attr.graphs:
        content.append('[')
        for i, g in enumerate(attr.graphs):
            comma = ',' if i != len(attr.graphs) - 1 else ''
            content.append('<graph {}>{}'.format(g.name, comma))
        content.append(']')
        graphs.extend(attr.graphs)
    else:
        content.append("<Unknown>")
    if subgraphs:
        return ' '.join(content), graphs
    else:
        return ' '.join(content)
def printable_dim(dim: TensorShapeProto.Dimension) -> Text:
    """Return a dimension's value or symbolic name as a string.

    A dimension with neither ``dim_value`` nor ``dim_param`` set (as
    make_tensor_type_proto creates for a ``None`` shape entry) prints as
    '?' instead of failing an assertion.
    """
    which = dim.WhichOneof('value')
    if which is None:
        return '?'
    return str(getattr(dim, which))
def printable_type(t: TypeProto) -> Text:
    """Render a TypeProto for printable_value_info.

    Tensor types print as 'ELEMTYPE' plus ', AxB...' (or ', scalar' for a
    present but empty shape); an unset type prints as ''; any other kind
    prints as 'Unknown type <kind>'.
    """
    kind = t.WhichOneof('value')
    if kind != "tensor_type":
        if kind is None:
            return ""
        return 'Unknown type {}'.format(kind)
    tensor_type = t.tensor_type
    text = TensorProto.DataType.Name(tensor_type.elem_type)
    if tensor_type.HasField('shape'):
        dims = tensor_type.shape.dim
        if len(dims):
            text += ', ' + 'x'.join(map(printable_dim, dims))
        else:
            text += ', scalar'
    return text
def printable_value_info(v: ValueInfoProto) -> Text:
    """Render a value info as '%name[type]' using printable_type.

    NOTE(review): ``v.type`` is a protobuf message, which is probably always
    truthy, so the bracketed part likely appears even when no type is set;
    confirm before relying on the bare '%name' form.
    """
    label = '%{}'.format(v.name)
    if not v.type:
        return label
    return '{}[{}]'.format(label, printable_type(v.type))
def printable_tensor_proto(t: TensorProto) -> Text:
    """Render a tensor's name, data type and shape as '%name[TYPE, AxB]'.

    The data values are not printed; a tensor with no dims shows 'scalar'.
    """
    parts = ['%{}['.format(t.name), TensorProto.DataType.Name(t.data_type)]
    if t.dims is not None:
        if len(t.dims):
            parts.append(', ' + 'x'.join(map(str, t.dims)))
        else:
            parts.append(', scalar')
    parts.append(']')
    return ''.join(parts)
def printable_node(node: NodeProto, prefix: Text = '', subgraphs: bool = False) -> Union[Text, Tuple[Text, List[GraphProto]]]:
    """Render a node as '%out1, %out2 = OpType[attrs](%in1, %in2)'.

    Attributes are rendered by printable_attribute and sorted as strings;
    the '[...]' part is omitted when the node has no attributes.

    Arguments:
        node: the node to render
        prefix: text prepended to the line (typically indentation)
        subgraphs: if True, also return the graphs referenced by the node's
            graph-valued attributes

    Returns:
        The line, or ``(line, graphs)`` when ``subgraphs`` is True.
    """
    content = []
    if len(node.output):
        content.append(', '.join('%{}'.format(out_name) for out_name in node.output))
        content.append('=')
    # Graphs found in attributes, handed back to the caller for nested printing.
    graphs: List[GraphProto] = []
    rendered_attrs = []
    for attr in node.attribute:
        if not subgraphs:
            text = printable_attribute(attr)
            assert isinstance(text, Text)
            rendered_attrs.append(text)
            continue
        text_and_graphs = printable_attribute(attr, subgraphs)
        assert isinstance(text_and_graphs[1], list)
        graphs.extend(text_and_graphs[1])
        rendered_attrs.append(text_and_graphs[0])

    attrs_text = ', '.join(sorted(rendered_attrs))
    inputs_text = ', '.join('%{}'.format(in_name) for in_name in node.input)
    if node.attribute:
        call = "{}[{}]({})".format(node.op_type, attrs_text, inputs_text)
    else:
        call = "{}({})".format(node.op_type, inputs_text)
    content.append(call)

    line = prefix + ' '.join(content)
    if subgraphs:
        return line, graphs
    return line
def printable_graph(graph: GraphProto, prefix: Text = '') -> Text:
    """Render a graph (and, after it, any subgraphs) as human-readable text.

    Layout: a 'graph <name> (' header listing required inputs, then inputs
    that have a matching initializer, then initializers without a matching
    input; then one line per node (via printable_node), a 'return' line with
    the output names, and a closing '}'. Graphs found in node attributes are
    rendered recursively and appended afterwards, each preceded by a blank
    line.

    Arguments:
        graph: the graph to print
        prefix: text prepended to every line of this graph

    Returns:
        The rendered text, lines joined with '\\n'.
    """
    content = []
    indent = prefix + ' '
    # header
    # ``header`` accumulates words of the line currently being built; it is
    # flushed into ``content`` and restarted before each listed section, so
    # a section's closing ')' ends up at the start of the next header line.
    header = ['graph', graph.name]
    initializers = {t.name for t in graph.initializer}
    if len(graph.input):
        header.append("(")
        in_strs = []  # required inputs
        in_with_init_strs = []  # optional inputs with initializer providing default value
        for inp in graph.input:
            if inp.name not in initializers:
                in_strs.append(printable_value_info(inp))
            else:
                in_with_init_strs.append(printable_value_info(inp))
        if in_strs:
            content.append(prefix + ' '.join(header))
            header = []
            for line in in_strs:
                content.append(prefix + ' ' + line)
            header.append(")")

        if in_with_init_strs:
            header.append("optional inputs with matching initializers (")
            content.append(prefix + ' '.join(header))
            header = []
            for line in in_with_init_strs:
                content.append(prefix + ' ' + line)
            header.append(")")

        # from IR 4 onwards an initializer is not required to have a matching graph input
        # so output the name, type and shape of those as well
        # NOTE(review): this section sits inside ``if len(graph.input)``, so a
        # graph with initializers but no inputs prints no initializers at all;
        # confirm whether that is intended.
        if len(in_with_init_strs) < len(initializers):
            graph_inputs = {i.name for i in graph.input}
            init_strs = [printable_tensor_proto(i) for i in graph.initializer
                         if i.name not in graph_inputs]
            header.append("initializers (")
            content.append(prefix + ' '.join(header))
            header = []
            for line in init_strs:
                content.append(prefix + ' ' + line)
            header.append(")")

    header.append('{')
    content.append(prefix + ' '.join(header))
    # Subgraphs referenced by node attributes, printed after this graph.
    graphs: List[GraphProto] = []
    # body
    for node in graph.node:
        contents_subgraphs = printable_node(node, indent, subgraphs=True)
        assert isinstance(contents_subgraphs[1], list)
        content.append(contents_subgraphs[0])
        graphs.extend(contents_subgraphs[1])
    # tail
    tail = ['return']
    if len(graph.output):
        tail.append(
            ', '.join(['%{}'.format(out.name) for out in graph.output]))
    content.append(indent + ' '.join(tail))
    # closing bracket
    content.append(prefix + '}')
    # Subgraphs are rendered without ``prefix``, i.e. starting at column 0.
    for g in graphs:
        content.append('\n' + printable_graph(g))
    return '\n'.join(content)
def strip_doc_string(proto: google.protobuf.message.Message) -> None:
    """
    Empties `doc_string` field on any nested protobuf messages

    Recurses into every repeated message field and every present singular
    message field of ``proto``, clearing each field named 'doc_string'.
    """
    assert isinstance(proto, google.protobuf.message.Message)
    for field in proto.DESCRIPTOR.fields:
        if field.name == 'doc_string':
            proto.ClearField(field.name)
            continue
        if field.type != field.TYPE_MESSAGE:
            continue
        if field.label == field.LABEL_REPEATED:
            for child in getattr(proto, field.name):
                strip_doc_string(child)
        elif proto.HasField(field.name):
            strip_doc_string(getattr(proto, field.name))
def make_training_info(algorithm: GraphProto, algorithm_bindings: AssignmentBindingType, initialization: Optional[GraphProto], initialization_bindings: Optional[AssignmentBindingType]) -> TrainingInfoProto:
    """Construct a TrainingInfoProto.

    Arguments:
        algorithm: graph copied into ``algorithm``
        algorithm_bindings: (key, value) pairs added to ``update_binding``
        initialization: optional graph copied into ``initialization`` when truthy
        initialization_bindings: optional (key, value) pairs added to
            ``initialization_binding`` when non-empty
    """
    training_info = TrainingInfoProto()
    training_info.algorithm.CopyFrom(algorithm)
    for bind_key, bind_value in algorithm_bindings:
        entry = training_info.update_binding.add()
        entry.key = bind_key
        entry.value = bind_value

    if initialization:
        training_info.initialization.CopyFrom(initialization)

    if not initialization_bindings:
        return training_info
    for bind_key, bind_value in initialization_bindings:
        entry = training_info.initialization_binding.add()
        entry.key = bind_key
        entry.value = bind_value
    return training_info
# For backwards compatibility
def make_sequence_value_info(
    name: Text,
    elem_type: int,
    shape: Optional[Sequence[Union[Text, int, None]]],
    doc_string: Text = "",
    elem_shape_denotation: Optional[List[Text]] = None,
) -> ValueInfoProto:
    """Deprecated alias of make_tensor_sequence_value_info.

    Makes a Sequence[Tensors] ValueInfoProto based on the data type and
    shape, after emitting a DeprecationWarning attributed to the caller.
    """
    message = (
        "`onnx.helper.make_sequence_value_info` is a deprecated alias for "
        "`onnx.helper.make_tensor_sequence_value_info`. To silence this warning, "
        "please use `make_tensor_sequence_value_info` for `TensorProto` sequences. "
        "Deprecated in ONNX v1.10.0, `onnx.helper.make_sequence_value_info alias` "
        "will be removed in an upcoming release."
    )
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return make_tensor_sequence_value_info(name, elem_type, shape, doc_string, elem_shape_denotation)