Source code for onnx.helper

# SPDX-License-Identifier: Apache-2.0
import collections.abc  # type: ignore
import numbers
import struct
from cmath import isnan
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Optional,
    Sequence,
    Tuple,
    TypeVar,
    Union,
    cast,
)

import google.protobuf.message
import numpy as np  # type: ignore

from onnx import (
    IR_VERSION,
    AttributeProto,
    FunctionProto,
    GraphProto,
    MapProto,
    ModelProto,
    NodeProto,
    OperatorSetIdProto,
    OptionalProto,
    SequenceProto,
    SparseTensorProto,
    TensorProto,
    TensorShapeProto,
    TrainingInfoProto,
    TypeProto,
    ValueInfoProto,
    defs,
    mapping,
)
from onnx.mapping import STORAGE_TENSOR_TYPE_TO_FIELD

VersionRowType = Union[Tuple[str, int, int, int], Tuple[str, int, int, int, int]]
VersionTableType = List[VersionRowType]
AssignmentBindingType = List[Tuple[str, str]]

# This is a copy of the documented version in https://github.com/onnx/onnx/blob/main/docs/Versioning.md#released-versions
# Both must be updated whenever a new version of ONNX is released.
VERSION_TABLE: VersionTableType = [
    # Release-version, IR version, ai.onnx version, ai.onnx.ml version, (optional) ai.onnx.training version
    ("1.0", 3, 1, 1),
    ("1.1", 3, 5, 1),
    ("1.1.2", 3, 6, 1),
    ("1.2", 3, 7, 1),
    ("1.3", 3, 8, 1),
    ("1.4.1", 4, 9, 1),
    ("1.5.0", 5, 10, 1),
    ("1.6.0", 6, 11, 2),
    ("1.7.0", 7, 12, 2, 1),
    ("1.8.0", 7, 13, 2, 1),
    ("1.8.1", 7, 13, 2, 1),
    ("1.9.0", 7, 14, 2, 1),
    ("1.10.0", 8, 15, 2, 1),
    ("1.10.1", 8, 15, 2, 1),
    ("1.10.2", 8, 15, 2, 1),
    ("1.11.0", 8, 16, 3, 1),
    ("1.12.0", 8, 17, 3, 1),
]

VersionMapType = Dict[Tuple[str, int], int]


def create_op_set_id_version_map(table: VersionTableType) -> VersionMapType:
    """create a map from (opset-domain, opset-version) to ir-version from above table"""
    result: VersionMapType = dict()

    def process(release_version: str, ir_version: int, *args: Any) -> None:
        for pair in zip(["ai.onnx", "ai.onnx.ml", "ai.onnx.training"], args):
            if pair not in result:
                result[pair] = ir_version
                if pair[0] == "ai.onnx.training":
                    result["ai.onnx.preview.training", pair[1]] = ir_version

    for row in table:
        process(*row)
    return result


OP_SET_ID_VERSION_MAP = create_op_set_id_version_map(VERSION_TABLE)


def find_min_ir_version_for(opsetidlist: List[OperatorSetIdProto]) -> int:
    """Given a list of opset ids, determine the minimum IR version required."""
    default_min_version = 3

    def find_min(domain: Union[str, None], version: int) -> int:
        key = (domain if domain else "ai.onnx", version)
        if key in OP_SET_ID_VERSION_MAP:
            return OP_SET_ID_VERSION_MAP[key]
        raise ValueError("Unsupported opset-version.")

    if opsetidlist:
        return max(find_min(x.domain, x.version) for x in opsetidlist)
    return default_min_version  # if no opsets specified


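# Example (illustrative): ai.onnx opset 15 first shipped in ONNX release
# 1.10.0, whose IR version is 8 in VERSION_TABLE, so that is the minimum
# IR version required.
#
#     find_min_ir_version_for([make_opsetid("", 15)])  # -> 8

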
def make_node(
    op_type: str,
    inputs: Sequence[str],
    outputs: Sequence[str],
    name: Optional[str] = None,
    doc_string: Optional[str] = None,
    domain: Optional[str] = None,
    **kwargs: Any,
) -> NodeProto:
    """Construct a NodeProto.

    Arguments:
        op_type (string): The name of the operator to construct
        inputs (list of string): list of input names
        outputs (list of string): list of output names
        name (string, default None): optional unique identifier for NodeProto
        doc_string (string, default None): optional documentation string for NodeProto
        domain (string, default None): optional domain for NodeProto.
            If it's None, we will just use default domain (which is empty)
        **kwargs (dict): the attributes of the node. The acceptable values
            are documented in :func:`make_attribute`.

    Returns:
        NodeProto
    """
    node = NodeProto()
    node.op_type = op_type
    node.input.extend(inputs)
    node.output.extend(outputs)
    if name:
        node.name = name
    if doc_string:
        node.doc_string = doc_string
    if domain is not None:
        node.domain = domain
    if kwargs:
        node.attribute.extend(
            make_attribute(key, value)
            for key, value in sorted(kwargs.items())
            if value is not None
        )
    return node


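# Example (illustrative): a Relu node with no attributes.
#
#     node = make_node("Relu", inputs=["x"], outputs=["y"], name="relu_0")
#     # node.op_type == "Relu"; list(node.input) == ["x"]

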
def make_operatorsetid(
    domain: str,
    version: int,
) -> OperatorSetIdProto:
    """Construct an OperatorSetIdProto.

    Arguments:
        domain (string): The domain of the operator set id
        version (integer): Version of operator set id

    Returns:
        OperatorSetIdProto
    """
    operatorsetid = OperatorSetIdProto()
    operatorsetid.domain = domain
    operatorsetid.version = version
    return operatorsetid


def make_graph(
    nodes: Sequence[NodeProto],
    name: str,
    inputs: Sequence[ValueInfoProto],
    outputs: Sequence[ValueInfoProto],
    initializer: Optional[Sequence[TensorProto]] = None,
    doc_string: Optional[str] = None,
    value_info: Sequence[ValueInfoProto] = [],
    sparse_initializer: Optional[Sequence[SparseTensorProto]] = None,
) -> GraphProto:
    """Construct a GraphProto.

    Arguments:
        nodes: list of NodeProto
        name (string): graph name
        inputs: list of ValueInfoProto
        outputs: list of ValueInfoProto
        initializer: list of TensorProto
        doc_string (string): graph documentation
        value_info: list of ValueInfoProto
        sparse_initializer: list of SparseTensorProto

    Returns:
        GraphProto
    """
    if initializer is None:
        initializer = []
    if sparse_initializer is None:
        sparse_initializer = []
    if value_info is None:
        value_info = []
    graph = GraphProto()
    graph.node.extend(nodes)
    graph.name = name
    graph.input.extend(inputs)
    graph.output.extend(outputs)
    graph.initializer.extend(initializer)
    graph.sparse_initializer.extend(sparse_initializer)
    graph.value_info.extend(value_info)
    if doc_string:
        graph.doc_string = doc_string
    return graph


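# Example (illustrative): a single-node graph.
#
#     x = make_tensor_value_info("x", TensorProto.FLOAT, [1])
#     y = make_tensor_value_info("y", TensorProto.FLOAT, [1])
#     g = make_graph([make_node("Relu", ["x"], ["y"])], "tiny", [x], [y])

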
def make_opsetid(domain: str, version: int) -> OperatorSetIdProto:
    """Construct an OperatorSetIdProto.

    Arguments:
        domain (string): The domain of the operator set id
        version (integer): Version of operator set id

    Returns:
        OperatorSetIdProto
    """
    opsetid = OperatorSetIdProto()
    opsetid.domain = domain
    opsetid.version = version
    return opsetid


def make_function(
    domain: str,
    fname: str,
    inputs: Sequence[str],
    outputs: Sequence[str],
    nodes: Sequence[NodeProto],
    opset_imports: Sequence[OperatorSetIdProto],
    attributes: Optional[Sequence[str]] = [],
    doc_string: Optional[str] = None,
) -> FunctionProto:
    """Construct a FunctionProto from the given nodes, formal parameter
    names, attribute names, and opset imports."""
    f = FunctionProto()
    f.domain = domain
    f.name = fname
    f.input.extend(inputs)
    f.output.extend(outputs)
    f.node.extend(nodes)
    f.opset_import.extend(opset_imports)
    f.attribute.extend(attributes)
    if doc_string:
        f.doc_string = doc_string
    return f


def make_model(graph: GraphProto, **kwargs: Any) -> ModelProto:
    """Construct a ModelProto.

    Arguments:
        graph (GraphProto): the return value of *make_graph*
        **kwargs: any attribute to add to the returned instance

    Returns:
        ModelProto
    """
    model = ModelProto()
    # Touch model.ir_version so it is stored as the version from which it is
    # generated.
    model.ir_version = IR_VERSION
    model.graph.CopyFrom(graph)

    opset_imports: Optional[Sequence[OperatorSetIdProto]] = None
    opset_imports = kwargs.pop("opset_imports", None)  # type: ignore
    if opset_imports is not None:
        model.opset_import.extend(opset_imports)
    else:
        # Default import
        imp = model.opset_import.add()
        imp.version = defs.onnx_opset_version()

    functions: Optional[Sequence[FunctionProto]] = None
    functions = kwargs.pop("functions", None)  # type: ignore
    if functions is not None:
        model.functions.extend(functions)

    for k, v in kwargs.items():
        # TODO: Does this work with repeated fields?
        setattr(model, k, v)
    return model


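# Example (illustrative; assumes a GraphProto `g` such as the one built in
# the make_graph example above):
#
#     model = make_model(g, producer_name="my-exporter",
#                        opset_imports=[make_opsetid("", 15)])
#     # model.ir_version is set to the current IR_VERSION

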
# An extension of make_model that infers an IR_VERSION for the model,
# if not specified, on a best-effort basis.
def make_model_gen_version(graph: GraphProto, **kwargs: Any) -> ModelProto:
    ir_version_field = "ir_version"
    if ir_version_field not in kwargs:
        opset_imports_field = "opset_imports"
        imports = kwargs[opset_imports_field] if opset_imports_field in kwargs else []
        kwargs[ir_version_field] = find_min_ir_version_for(imports)
    return make_model(graph, **kwargs)


def set_model_props(model: ModelProto, dict_value: Dict[str, str]) -> None:
    del model.metadata_props[:]
    for (k, v) in dict_value.items():
        entry = model.metadata_props.add()
        entry.key = k
        entry.value = v


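# Example (illustrative): replace any existing metadata_props entries.
#
#     set_model_props(model, {"author": "me", "license": "Apache-2.0"})

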
def split_complex_to_pairs(ca: Sequence[np.complex64]) -> Sequence[int]:
    return [
        (ca[i // 2].real if (i % 2 == 0) else ca[i // 2].imag)
        for i in range(len(ca) * 2)
    ]


# Convert a float32 value to a bfloat16 (as int).
# By default, this conversion rounds-to-nearest-even and supports NaN.
# Setting `truncate` to True enables a simpler conversion. In this mode the
# conversion is performed by simply dropping the 2 least significant bytes of
# the significand. In this mode an error of up to 1 bit may be introduced and
# preservation of NaN values is not guaranteed.
def float32_to_bfloat16(fval: float, truncate: bool = False) -> int:
    ival = int.from_bytes(struct.pack("<f", fval), "little")
    if truncate:
        return ival >> 16
    # NaN requires at least 1 significand bit set
    if isnan(fval):
        return 0x7FC0  # sign=0, exp=all-ones, sig=0b1000000
    # drop bottom 16-bits
    # round remaining bits using round-to-nearest-even
    round = ((ival >> 16) & 1) + 0x7FFF
    return (ival + round) >> 16


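# Worked example (illustrative): 1.0 encodes as 0x3F800000 in float32. Its
# low 16 bits are zero, so rounding adds nothing and both modes yield the
# bfloat16 bit pattern 0x3F80:
#
#     float32_to_bfloat16(1.0)                 # -> 0x3F80
#     float32_to_bfloat16(1.0, truncate=True)  # -> 0x3F80

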
def make_tensor(
    name: str, data_type: int, dims: Sequence[int], vals: Any, raw: bool = False
) -> TensorProto:
    """
    Make a TensorProto with specified arguments. If raw is False, this
    function will choose the corresponding proto field to store the
    values based on data_type. If raw is True, use "raw_data" proto
    field to store the values, and values should be of type bytes in
    this case.

    Arguments:
        name (string): tensor name
        data_type (int): a value such as onnx.TensorProto.FLOAT
        dims (List[int]): shape
        vals: values
        raw (bool): if True, vals contains the serialized content of the tensor,
            otherwise, vals should be a list of values of the type defined by *data_type*

    Returns:
        TensorProto
    """
    tensor = TensorProto()
    tensor.data_type = data_type
    tensor.name = name

    if data_type == TensorProto.STRING:
        assert not raw, "Can not use raw_data to store string type"

    np_dtype = mapping.TENSOR_TYPE_TO_NP_TYPE[data_type]

    # Check that the number of vals specified equals the tensor size
    expected_size = 1
    if raw:
        # NumPy doesn't have BFLOAT16. TENSOR_TYPE_TO_NP_TYPE maps it to float32,
        # which has the wrong itemsize.
        if data_type == TensorProto.BFLOAT16:
            expected_size = 2
        else:
            expected_size = np_dtype.itemsize

    if type(vals) is np.ndarray and len(vals.shape) > 1:
        vals = vals.flatten()
    for d in dims:
        expected_size *= d

    if len(vals) != expected_size:
        raise ValueError(
            "Number of values does not match tensor's size. Expected {}, but it is {}. ".format(
                expected_size, len(vals)
            )
        )

    if raw:
        tensor.raw_data = vals
    else:
        if data_type == TensorProto.COMPLEX64 or data_type == TensorProto.COMPLEX128:
            vals = split_complex_to_pairs(vals)
        elif data_type == TensorProto.FLOAT16:
            vals = (
                np.array(vals).astype(np_dtype).view(dtype=np.uint16).flatten().tolist()
            )
        elif data_type == TensorProto.BFLOAT16:
            vals = list(
                map(
                    float32_to_bfloat16,
                    np.array(vals).astype(np_dtype).flatten().tolist(),
                )
            )
        field = mapping.STORAGE_TENSOR_TYPE_TO_FIELD[
            mapping.TENSOR_TYPE_TO_STORAGE_TENSOR_TYPE[data_type]
        ]
        getattr(tensor, field).extend(vals)
    tensor.dims.extend(dims)
    return tensor


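# Example (illustrative): a 2x2 float tensor built from a flat list; the
# values land in the float_data field.
#
#     t = make_tensor("W", TensorProto.FLOAT, dims=[2, 2],
#                     vals=[1.0, 2.0, 3.0, 4.0])

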
def make_sparse_tensor(
    values: TensorProto, indices: TensorProto, dims: Sequence[int]
) -> SparseTensorProto:
    """Construct a SparseTensorProto.

    Arguments:
        values (TensorProto): the values
        indices (TensorProto): the indices
        dims: the shape

    Returns:
        SparseTensorProto
    """
    sparse = SparseTensorProto()
    sparse.values.CopyFrom(values)
    sparse.indices.CopyFrom(indices)
    sparse.dims.extend(dims)
    return sparse


def make_sequence(
    name: str,
    elem_type: SequenceProto.DataType,
    values: Sequence[Any],
) -> SequenceProto:
    """Make a Sequence with specified value arguments."""
    sequence = SequenceProto()
    sequence.name = name
    sequence.elem_type = elem_type
    values_field = mapping.STORAGE_ELEMENT_TYPE_TO_FIELD[elem_type]
    getattr(sequence, values_field).extend(values)
    return sequence


def make_map(
    name: str, key_type: int, keys: List[Any], values: SequenceProto
) -> MapProto:
    """Make a Map with specified key-value pair arguments.

    Criteria for conversion:
    - Keys and Values must have the same number of elements
    - Every key in keys must be of the same type
    - Every value in values must be of the same type
    """
    map = MapProto()
    valid_key_int_types = [
        TensorProto.INT8,
        TensorProto.INT16,
        TensorProto.INT32,
        TensorProto.INT64,
        TensorProto.UINT8,
        TensorProto.UINT16,
        TensorProto.UINT32,
        TensorProto.UINT64,
    ]
    map.name = name
    map.key_type = key_type
    if key_type == TensorProto.STRING:
        map.string_keys.extend(keys)
    elif key_type in valid_key_int_types:
        map.keys.extend(keys)
    map.values.CopyFrom(values)
    return map


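# Example (a sketch; t1 and t2 stand for hypothetical TensorProto values):
#
#     values = make_sequence("vals", SequenceProto.TENSOR, [t1, t2])
#     m = make_map("m", TensorProto.INT64, keys=[1, 2], values=values)

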
def make_optional(
    name: str,
    elem_type: OptionalProto.DataType,
    value: Optional[Any],
) -> OptionalProto:
    """Make an Optional with specified value arguments."""
    optional = OptionalProto()
    optional.name = name
    optional.elem_type = elem_type
    if elem_type != 0:
        values_field = mapping.OPTIONAL_ELEMENT_TYPE_TO_FIELD[elem_type]
        getattr(optional, values_field).CopyFrom(value)
    return optional


def _to_bytes_or_false(val: Union[str, bytes]) -> Union[bytes, bool]:
    """An internal helper that converts the input to bytes or to False.

    The criteria for conversion are as follows:

    - If val is bytes: return it unchanged
    - If val is str: return it encoded as utf-8 bytes
    - Otherwise: return False
    """
    if isinstance(val, bytes):
        return val
    try:
        return val.encode("utf-8")
    except AttributeError:
        return False


def make_attribute(
    key: str, value: Any, doc_string: Optional[str] = None
) -> AttributeProto:
    """Makes an AttributeProto based on the value type."""
    attr = AttributeProto()
    attr.name = key
    if doc_string:
        attr.doc_string = doc_string

    is_iterable = isinstance(value, collections.abc.Iterable)
    bytes_or_false = _to_bytes_or_false(value)
    # First, singular cases
    # float
    if isinstance(value, float):
        attr.f = value
        attr.type = AttributeProto.FLOAT
    # integer
    elif isinstance(value, numbers.Integral):
        attr.i = cast(int, value)
        attr.type = AttributeProto.INT
    # string
    elif bytes_or_false is not False:
        assert isinstance(bytes_or_false, bytes)
        attr.s = bytes_or_false
        attr.type = AttributeProto.STRING
    elif isinstance(value, TensorProto):
        attr.t.CopyFrom(value)
        attr.type = AttributeProto.TENSOR
    elif isinstance(value, SparseTensorProto):
        attr.sparse_tensor.CopyFrom(value)
        attr.type = AttributeProto.SPARSE_TENSOR
    elif isinstance(value, GraphProto):
        attr.g.CopyFrom(value)
        attr.type = AttributeProto.GRAPH
    elif isinstance(value, TypeProto):
        attr.tp.CopyFrom(value)
        attr.type = AttributeProto.TYPE_PROTO
    # Second, iterable cases
    elif is_iterable:
        byte_array = [_to_bytes_or_false(v) for v in value]
        if all(isinstance(v, numbers.Integral) for v in value):
            # Turn np.int32/64 into Python built-in int.
            attr.ints.extend(int(v) for v in value)
            attr.type = AttributeProto.INTS
        elif all(isinstance(v, numbers.Real) for v in value):
            # Since ints and floats are members of Real, this allows a mix of
            # ints and floats (and converts the ints to floats).
            attr.floats.extend(float(v) for v in value)
            attr.type = AttributeProto.FLOATS
        elif all(map(lambda bytes_or_false: bytes_or_false is not False, byte_array)):
            attr.strings.extend(cast(List[bytes], byte_array))
            attr.type = AttributeProto.STRINGS
        elif all(isinstance(v, TensorProto) for v in value):
            attr.tensors.extend(value)
            attr.type = AttributeProto.TENSORS
        elif all(isinstance(v, SparseTensorProto) for v in value):
            attr.sparse_tensors.extend(value)
            attr.type = AttributeProto.SPARSE_TENSORS
        elif all(isinstance(v, GraphProto) for v in value):
            attr.graphs.extend(value)
            attr.type = AttributeProto.GRAPHS
        elif all(isinstance(tp, TypeProto) for tp in value):
            attr.type_protos.extend(value)
            attr.type = AttributeProto.TYPE_PROTOS
        else:
            raise ValueError(
                "You passed in an iterable attribute but I cannot figure out "
                "its applicable type."
            )
    else:
        raise TypeError(f'value "{value}" is not valid attribute data type.')
    return attr


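# Example (illustrative): the attribute type is inferred from the Python
# value that is passed in.
#
#     make_attribute("alpha", 0.5).type        # AttributeProto.FLOAT
#     make_attribute("axes", [0, 1]).type      # AttributeProto.INTS
#     make_attribute("mode", "constant").type  # AttributeProto.STRING

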
def make_attribute_ref(
    name: str, attr_type: AttributeProto.AttributeType, doc_string: Optional[str] = None
) -> AttributeProto:
    """Make an AttributeProto holding a reference to the parent function's
    attribute of given name and type."""
    attr = AttributeProto()
    attr.name = name
    attr.type = attr_type
    if doc_string:
        attr.doc_string = doc_string
    return attr


def get_attribute_value(attr: AttributeProto) -> Any:
    if attr.ref_attr_name:
        raise ValueError(f"Cannot get value of reference attribute: {attr}")
    if attr.type == AttributeProto.FLOAT:
        return attr.f
    if attr.type == AttributeProto.INT:
        return attr.i
    if attr.type == AttributeProto.STRING:
        return attr.s
    if attr.type == AttributeProto.TENSOR:
        return attr.t
    if attr.type == AttributeProto.SPARSE_TENSOR:
        return attr.sparse_tensor
    if attr.type == AttributeProto.GRAPH:
        return attr.g
    if attr.type == AttributeProto.TYPE_PROTO:
        return attr.tp
    if attr.type == AttributeProto.FLOATS:
        return list(attr.floats)
    if attr.type == AttributeProto.INTS:
        return list(attr.ints)
    if attr.type == AttributeProto.STRINGS:
        return list(attr.strings)
    if attr.type == AttributeProto.TENSORS:
        return list(attr.tensors)
    if attr.type == AttributeProto.SPARSE_TENSORS:
        return list(attr.sparse_tensors)
    if attr.type == AttributeProto.GRAPHS:
        return list(attr.graphs)
    if attr.type == AttributeProto.TYPE_PROTOS:
        return list(attr.type_protos)
    raise ValueError(f"Unsupported ONNX attribute: {attr}")


def make_empty_tensor_value_info(name: str) -> ValueInfoProto:
    value_info_proto = ValueInfoProto()
    value_info_proto.name = name
    return value_info_proto


def make_tensor_type_proto(
    elem_type: int,
    shape: Optional[Sequence[Union[str, int, None]]],
    shape_denotation: Optional[List[str]] = None,
) -> TypeProto:
    """Makes a Tensor TypeProto based on the data type and shape."""
    type_proto = TypeProto()
    tensor_type_proto = type_proto.tensor_type
    tensor_type_proto.elem_type = elem_type
    tensor_shape_proto = tensor_type_proto.shape

    if shape is not None:
        # You might think this is a no-op (extending a normal Python
        # list by [] certainly is), but protobuf lists work a little
        # differently; if a field is never set, it is omitted from the
        # resulting protobuf; a list that is explicitly set to be
        # empty will get an (empty) entry in the protobuf. This
        # difference is visible to our consumers, so make sure we emit
        # an empty shape!
        tensor_shape_proto.dim.extend([])

        if shape_denotation:
            if len(shape_denotation) != len(shape):
                raise ValueError(
                    "Invalid shape_denotation. Must be of the same length as shape."
                )

        for i, d in enumerate(shape):
            dim = tensor_shape_proto.dim.add()
            if d is None:
                pass
            elif isinstance(d, int):
                dim.dim_value = d
            elif isinstance(d, str):
                dim.dim_param = d
            else:
                raise ValueError(
                    f"Invalid item in shape: {d}. Needs to be of int or str."
                )

            if shape_denotation:
                dim.denotation = shape_denotation[i]

    return type_proto


def make_tensor_value_info(
    name: str,
    elem_type: int,
    shape: Optional[Sequence[Union[str, int, None]]],
    doc_string: str = "",
    shape_denotation: Optional[List[str]] = None,
) -> ValueInfoProto:
    """Makes a ValueInfoProto based on the data type and shape."""
    value_info_proto = ValueInfoProto()
    value_info_proto.name = name
    if doc_string:
        value_info_proto.doc_string = doc_string
    tensor_type_proto = make_tensor_type_proto(elem_type, shape, shape_denotation)
    value_info_proto.type.CopyFrom(tensor_type_proto)
    return value_info_proto


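# Example (illustrative): a float tensor with one symbolic and one fixed
# dimension.
#
#     vi = make_tensor_value_info("X", TensorProto.FLOAT, ["N", 3])
#     # the resulting shape has dim_param == "N" and dim_value == 3

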
def make_sparse_tensor_type_proto(
    elem_type: int,
    shape: Optional[Sequence[Union[str, int, None]]],
    shape_denotation: Optional[List[str]] = None,
) -> TypeProto:
    """Makes a SparseTensor TypeProto based on the data type and shape."""
    type_proto = TypeProto()
    sparse_tensor_type_proto = type_proto.sparse_tensor_type
    sparse_tensor_type_proto.elem_type = elem_type
    sparse_tensor_shape_proto = sparse_tensor_type_proto.shape

    if shape is not None:
        # You might think this is a no-op (extending a normal Python
        # list by [] certainly is), but protobuf lists work a little
        # differently; if a field is never set, it is omitted from the
        # resulting protobuf; a list that is explicitly set to be
        # empty will get an (empty) entry in the protobuf. This
        # difference is visible to our consumers, so make sure we emit
        # an empty shape!
        sparse_tensor_shape_proto.dim.extend([])

        if shape_denotation:
            if len(shape_denotation) != len(shape):
                raise ValueError(
                    "Invalid shape_denotation. Must be of the same length as shape."
                )

        for i, d in enumerate(shape):
            dim = sparse_tensor_shape_proto.dim.add()
            if d is None:
                pass
            elif isinstance(d, int):
                dim.dim_value = d
            elif isinstance(d, str):
                dim.dim_param = d
            else:
                raise ValueError(
                    f"Invalid item in shape: {d}. Needs to be of int or str."
                )

            if shape_denotation:
                dim.denotation = shape_denotation[i]

    return type_proto


def make_sparse_tensor_value_info(
    name: str,
    elem_type: int,
    shape: Optional[Sequence[Union[str, int, None]]],
    doc_string: str = "",
    shape_denotation: Optional[List[str]] = None,
) -> ValueInfoProto:
    """Makes a SparseTensor ValueInfoProto based on the data type and shape."""
    value_info_proto = ValueInfoProto()
    value_info_proto.name = name
    if doc_string:
        value_info_proto.doc_string = doc_string
    sparse_tensor_type_proto = make_sparse_tensor_type_proto(
        elem_type, shape, shape_denotation
    )
    value_info_proto.type.sparse_tensor_type.CopyFrom(
        sparse_tensor_type_proto.sparse_tensor_type
    )
    return value_info_proto


def make_sequence_type_proto(
    inner_type_proto: TypeProto,
) -> TypeProto:
    """Makes a sequence TypeProto."""
    type_proto = TypeProto()
    type_proto.sequence_type.elem_type.CopyFrom(inner_type_proto)
    return type_proto


def make_optional_type_proto(
    inner_type_proto: TypeProto,
) -> TypeProto:
    """Makes an optional TypeProto."""
    type_proto = TypeProto()
    type_proto.optional_type.elem_type.CopyFrom(inner_type_proto)
    return type_proto


def make_value_info(
    name: str,
    type_proto: TypeProto,
    doc_string: str = "",
) -> ValueInfoProto:
    """Makes a ValueInfoProto with the given type_proto."""
    value_info_proto = ValueInfoProto()
    value_info_proto.name = name
    if doc_string:
        value_info_proto.doc_string = doc_string
    value_info_proto.type.CopyFrom(type_proto)
    return value_info_proto


def _sanitize_str(s: Union[str, bytes]) -> str:
    if isinstance(s, str):
        sanitized = s
    elif isinstance(s, bytes):
        sanitized = s.decode("utf-8", errors="ignore")
    else:
        sanitized = str(s)
    if len(sanitized) < 64:
        return sanitized
    return sanitized[:64] + "...<+len=%d>" % (len(sanitized) - 64)


def make_tensor_sequence_value_info(
    name: str,
    elem_type: int,
    shape: Optional[Sequence[Union[str, int, None]]],
    doc_string: str = "",
    elem_shape_denotation: Optional[List[str]] = None,
) -> ValueInfoProto:
    """Makes a Sequence[Tensors] ValueInfoProto based on the data type and shape."""
    value_info_proto = ValueInfoProto()
    value_info_proto.name = name
    if doc_string:
        value_info_proto.doc_string = doc_string
    tensor_type_proto = make_tensor_type_proto(elem_type, shape, elem_shape_denotation)
    sequence_type_proto = make_sequence_type_proto(tensor_type_proto)
    value_info_proto.type.sequence_type.CopyFrom(sequence_type_proto.sequence_type)
    return value_info_proto


def printable_attribute(
    attr: AttributeProto, subgraphs: bool = False
) -> Union[str, Tuple[str, List[GraphProto]]]:
    content = []
    content.append(attr.name)
    content.append("=")

    def str_float(f: float) -> str:
        # NB: Different Python versions print different numbers of trailing
        # decimals, specifying this explicitly keeps it consistent for all
        # versions
        return f"{f:.15g}"

    def str_int(i: int) -> str:
        return str(i)

    _T = TypeVar("_T")  # noqa

    def str_list(str_elem: Callable[[_T], str], xs: Sequence[_T]) -> str:
        return "[" + ", ".join(map(str_elem, xs)) + "]"

    # For now, this logic should continue to work as long as we are running
    # on a proto2 implementation. If/when we switch to proto3, we will need
    # to use attr.type.

    # To support printing subgraphs, if we find a graph attribute, print out
    # its name here and pass the graph itself up to the caller for later
    # printing.
    graphs = []
    if attr.HasField("f"):
        content.append(str_float(attr.f))
    elif attr.HasField("i"):
        content.append(str_int(attr.i))
    elif attr.HasField("s"):
        # TODO: Bit nervous about Python 2 / Python 3 determinism implications
        content.append(repr(_sanitize_str(attr.s)))
    elif attr.HasField("t"):
        if len(attr.t.dims) > 0:
            content.append("<Tensor>")
        else:
            # special case to print scalars
            field = STORAGE_TENSOR_TYPE_TO_FIELD[attr.t.data_type]
            content.append(f"<Scalar Tensor {str(getattr(attr.t, field))}>")
    elif attr.HasField("g"):
        content.append(f"<graph {attr.g.name}>")
        graphs.append(attr.g)
    elif attr.HasField("tp"):
        content.append(f"<Type Proto {attr.tp}>")
    elif attr.floats:
        content.append(str_list(str_float, attr.floats))
    elif attr.ints:
        content.append(str_list(str_int, attr.ints))
    elif attr.strings:
        # TODO: Bit nervous about Python 2 / Python 3 determinism implications
        content.append(str(list(map(_sanitize_str, attr.strings))))
    elif attr.tensors:
        content.append("[<Tensor>, ...]")
    elif attr.type_protos:
        content.append("[")
        for i, tp in enumerate(attr.type_protos):
            comma = "," if i != len(attr.type_protos) - 1 else ""
            content.append(f"<Type Proto {tp}>{comma}")
        content.append("]")
    elif attr.graphs:
        content.append("[")
        for i, g in enumerate(attr.graphs):
            comma = "," if i != len(attr.graphs) - 1 else ""
            content.append(f"<graph {g.name}>{comma}")
        content.append("]")
        graphs.extend(attr.graphs)
    else:
        content.append("<Unknown>")
    if subgraphs:
        return " ".join(content), graphs
    else:
        return " ".join(content)


def printable_dim(dim: TensorShapeProto.Dimension) -> str:
    which = dim.WhichOneof("value")
    assert which is not None
    return str(getattr(dim, which))


def printable_type(t: TypeProto) -> str:
    if t.WhichOneof("value") == "tensor_type":
        s = TensorProto.DataType.Name(t.tensor_type.elem_type)
        if t.tensor_type.HasField("shape"):
            if len(t.tensor_type.shape.dim):
                s += str(", " + "x".join(map(printable_dim, t.tensor_type.shape.dim)))
            else:
                s += ", scalar"
        return s
    if t.WhichOneof("value") is None:
        return ""
    return f"Unknown type {t.WhichOneof('value')}"


def printable_value_info(v: ValueInfoProto) -> str:
    s = f"%{v.name}"
    if v.type:
        s = f"{s}[{printable_type(v.type)}]"
    return s


def printable_tensor_proto(t: TensorProto) -> str:
    s = f"%{t.name}["
    s += TensorProto.DataType.Name(t.data_type)
    if t.dims is not None:
        if len(t.dims):
            s += str(", " + "x".join(map(str, t.dims)))
        else:
            s += ", scalar"
    s += "]"
    return s


def printable_node(
    node: NodeProto, prefix: str = "", subgraphs: bool = False
) -> Union[str, Tuple[str, List[GraphProto]]]:
    content = []
    if len(node.output):
        content.append(", ".join([f"%{name}" for name in node.output]))
        content.append("=")
    # To deal with nested graphs
    graphs: List[GraphProto] = []
    printed_attrs = []
    for attr in node.attribute:
        if subgraphs:
            printed_attr_subgraphs = printable_attribute(attr, subgraphs)
            assert isinstance(printed_attr_subgraphs[1], list)
            graphs.extend(printed_attr_subgraphs[1])
            printed_attrs.append(printed_attr_subgraphs[0])
        else:
            printed = printable_attribute(attr)
            assert isinstance(printed, str)
            printed_attrs.append(printed)
    printed_attributes = ", ".join(sorted(printed_attrs))
    printed_inputs = ", ".join([f"%{name}" for name in node.input])
    if node.attribute:
        content.append(f"{node.op_type}[{printed_attributes}]({printed_inputs})")
    else:
        content.append(f"{node.op_type}({printed_inputs})")
    if subgraphs:
        return prefix + " ".join(content), graphs
    else:
        return prefix + " ".join(content)


def printable_graph(graph: GraphProto, prefix: str = "") -> str:
    """
    Display a GraphProto as a string.

    Arguments:
        graph (GraphProto): the graph to display
        prefix (string): prefix of every line

    Returns:
        string
    """
    content = []
    indent = prefix + "  "
    # header
    header = ["graph", graph.name]
    initializers = {t.name for t in graph.initializer}
    if len(graph.input):
        header.append("(")
        in_strs = []  # required inputs
        # optional inputs with an initializer providing a default value
        in_with_init_strs = []
        for inp in graph.input:
            if inp.name not in initializers:
                in_strs.append(printable_value_info(inp))
            else:
                in_with_init_strs.append(printable_value_info(inp))
        if in_strs:
            content.append(prefix + " ".join(header))
            header = []
            for line in in_strs:
                content.append(prefix + "  " + line)
        header.append(")")

        if in_with_init_strs:
            header.append("optional inputs with matching initializers (")
            content.append(prefix + " ".join(header))
            header = []
            for line in in_with_init_strs:
                content.append(prefix + "  " + line)
            header.append(")")

        # from IR 4 onwards an initializer is not required to have a matching
        # graph input, so output the name, type and shape of those as well
        if len(in_with_init_strs) < len(initializers):
            graph_inputs = {i.name for i in graph.input}
            init_strs = [
                printable_tensor_proto(i)
                for i in graph.initializer
                if i.name not in graph_inputs
            ]
            header.append("initializers (")
            content.append(prefix + " ".join(header))
            header = []
            for line in init_strs:
                content.append(prefix + "  " + line)
            header.append(")")

    header.append("{")
    content.append(prefix + " ".join(header))
    graphs: List[GraphProto] = []
    # body
    for node in graph.node:
        contents_subgraphs = printable_node(node, indent, subgraphs=True)
        assert isinstance(contents_subgraphs[1], list)
        content.append(contents_subgraphs[0])
        graphs.extend(contents_subgraphs[1])
    # tail
    tail = ["return"]
    if len(graph.output):
        tail.append(", ".join([f"%{out.name}" for out in graph.output]))
    content.append(indent + " ".join(tail))
    # closing bracket
    content.append(prefix + "}")
    for g in graphs:
        content.append("\n" + printable_graph(g))
    return "\n".join(content)


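# Example (illustrative): pretty-print a one-node graph.
#
#     g = make_graph(
#         [make_node("Relu", ["x"], ["y"])], "tiny",
#         [make_tensor_value_info("x", TensorProto.FLOAT, [1])],
#         [make_tensor_value_info("y", TensorProto.FLOAT, [1])],
#     )
#     print(printable_graph(g))

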
def strip_doc_string(proto: google.protobuf.message.Message) -> None:
    """Empties `doc_string` field on any nested protobuf messages."""
    assert isinstance(proto, google.protobuf.message.Message)
    for descriptor in proto.DESCRIPTOR.fields:
        if descriptor.name == "doc_string":
            proto.ClearField(descriptor.name)
        elif descriptor.type == descriptor.TYPE_MESSAGE:
            if descriptor.label == descriptor.LABEL_REPEATED:
                for x in getattr(proto, descriptor.name):
                    strip_doc_string(x)
            elif proto.HasField(descriptor.name):
                strip_doc_string(getattr(proto, descriptor.name))


def make_training_info(
    algorithm: GraphProto,
    algorithm_bindings: AssignmentBindingType,
    initialization: Optional[GraphProto],
    initialization_bindings: Optional[AssignmentBindingType],
) -> TrainingInfoProto:
    """Construct a TrainingInfoProto from an algorithm graph, an optional
    initialization graph, and their (key, value) assignment bindings."""
    training_info = TrainingInfoProto()
    training_info.algorithm.CopyFrom(algorithm)
    for k, v in algorithm_bindings:
        binding = training_info.update_binding.add()
        binding.key = k
        binding.value = v

    if initialization:
        training_info.initialization.CopyFrom(initialization)
    if initialization_bindings:
        for k, v in initialization_bindings:
            binding = training_info.initialization_binding.add()
            binding.key = k
            binding.value = v

    return training_info