# Source code for onnx.compose

# SPDX-License-Identifier: Apache-2.0

from typing import List, MutableMapping, Optional, Set, Tuple

from onnx import GraphProto, ModelProto
from onnx import TensorProto as tp
from onnx import checker, helper, utils


def check_overlapping_names(
    g1: GraphProto, g2: GraphProto, io_map: Optional[List[Tuple[str, str]]] = None
) -> List[Tuple[str, List[str]]]:
    """Checks whether there are name collisions between two graphs.

    Returns a list of tuples where the first element represents the member containing
    overlapping names (one of: "edge", "value_info", "initializer", "sparse_initializer"),
    and the second element contains a list of names that appear in both graphs in that
    category. Categories without collisions are omitted, so an empty list means that no
    collisions were found.

    Optionally, it takes an io_map, representing the outputs/inputs to be connected. If
    provided, edges of ``g2`` whose names appear as inputs in the io_map are ignored when
    checking edge collisions (``merge_graphs`` renames those edges to the corresponding
    outputs of ``g1``).

    Arguments:
        g1 (GraphProto): First graph
        g2 (GraphProto): Second graph
        io_map (list of pairs of string): Optional pairs of names
            [(out0, in0), (out1, in1), ...] representing outputs of ``g1`` and inputs
            of ``g2`` to be connected

    Returns:
        list of (category, names) tuples

    Raises:
        ValueError: if ``g1`` or ``g2`` is not a ``GraphProto``
    """
    if not isinstance(g1, GraphProto):
        raise ValueError("g1 argument is not an ONNX graph")
    if not isinstance(g2, GraphProto):
        raise ValueError("g2 argument is not an ONNX graph")

    def _overlapping(c1: List[str], c2: List[str]) -> List[str]:
        return list(set(c1) & set(c2))

    def _edge_names(graph: GraphProto, exclude: Optional[Set[str]] = None) -> List[str]:
        # All non-empty node input/output names, minus those listed in ``exclude``.
        # Empty names are skipped (ONNX uses "" for omitted optional inputs/outputs).
        if exclude is None:
            exclude = set()
        return [
            edge
            for n in graph.node
            for edge in (*n.input, *n.output)
            if edge != "" and edge not in exclude
        ]

    result: List[Tuple[str, List[str]]] = []

    if not io_map:
        io_map = []
    io_map_inputs = {elem[1] for elem in io_map}

    # Edges already cover input/output
    overlap = _overlapping(_edge_names(g1), _edge_names(g2, exclude=io_map_inputs))
    if overlap:
        result.append(("edge", overlap))

    overlap = _overlapping(
        [e.name for e in g1.value_info], [e.name for e in g2.value_info]
    )
    if overlap:
        result.append(("value_info", overlap))

    overlap = _overlapping(
        [e.name for e in g1.initializer], [e.name for e in g2.initializer]
    )
    if overlap:
        result.append(("initializer", overlap))

    # A sparse initializer carries two named tensors (values and indices); check both.
    overlap = _overlapping(
        [e.values.name for e in g1.sparse_initializer],
        [e.values.name for e in g2.sparse_initializer],
    ) + _overlapping(
        [e.indices.name for e in g1.sparse_initializer],
        [e.indices.name for e in g2.sparse_initializer],
    )
    if overlap:
        result.append(("sparse_initializer", overlap))

    return result


def merge_graphs(
    g1: GraphProto,
    g2: GraphProto,
    io_map: List[Tuple[str, str]],
    inputs: Optional[List[str]] = None,
    outputs: Optional[List[str]] = None,
    prefix1: Optional[str] = None,
    prefix2: Optional[str] = None,
    name: Optional[str] = None,
    doc_string: Optional[str] = None,
) -> GraphProto:
    """Combines two ONNX graphs into a single one.

    The combined graph is defined by connecting the specified set of outputs/inputs.
    Those inputs/outputs not specified in the io_map argument will remain as
    inputs/outputs of the combined graph.

    Arguments:
        g1 (GraphProto): First graph
        g2 (GraphProto): Second graph
        io_map (list of pairs of string): The pairs of names [(out0, in0), (out1, in1), ...]
            representing outputs of the first graph and inputs of the second to be connected
        inputs (list of string): Optional list of inputs to be included in the combined graph
            By default, all inputs not present in the ``io_map`` argument will be
            included in the combined model
        outputs (list of string): Optional list of outputs to be included in the combined graph
            By default, all outputs not present in the ``io_map`` argument will be
            included in the combined model
        prefix1 (string): Optional prefix to be added to all names in g1
        prefix2 (string): Optional prefix to be added to all names in g2
        name (string): Optional name for the combined graph
            By default, the name is g1.name and g2.name concatenated with an underscore delimiter
        doc_string (string): Optional docstring for the combined graph
            If not provided, a default docstring with the concatenation of g1 and g2 docstrings is used

    Returns:
        GraphProto

    Raises:
        ValueError: if g1/g2 are not ONNX graphs, if an io_map entry does not name an
            output of g1 / an input of g2, or if the two graphs have overlapping names
    """
    if not isinstance(g1, GraphProto):
        raise ValueError("g1 argument is not an ONNX graph")
    if not isinstance(g2, GraphProto):
        raise ValueError("g2 argument is not an ONNX graph")

    # Prefixing names in the graph if requested, adjusting io_map accordingly.
    # add_prefix_graph(inplace=False) works on a copy, so the caller's graphs are untouched.
    if prefix1 or prefix2:
        if prefix1:
            g1 = add_prefix_graph(g1, prefix=prefix1, inplace=False)
        if prefix2:
            g2 = add_prefix_graph(g2, prefix=prefix2, inplace=False)
        io_map = [
            (
                prefix1 + io[0] if prefix1 else io[0],
                prefix2 + io[1] if prefix2 else io[1],
            )
            for io in io_map
        ]

    io_map_g1_outs = {io[0] for io in io_map}
    io_map_g2_ins = {io[1] for io in io_map}
    reversed_io_map = {in_name: out_name for out_name, in_name in io_map}
    g1_outs = {o.name for o in g1.output}
    g2_ins = {i.name for i in g2.input}

    # If necessary extract subgraphs
    if inputs or outputs:
        if not inputs:
            g1_inputs = [i.name for i in g1.input]
            g2_inputs = [i.name for i in g2.input]
        else:
            input_set = set(inputs)
            g1_inputs = [i.name for i in g1.input if i.name in input_set]
            # g2 inputs fed through io_map must survive extraction so they can be connected
            g2_inputs = [
                i.name
                for i in g2.input
                if i.name in input_set or i.name in io_map_g2_ins
            ]

        if not outputs:
            # No explicit output list: keep every output of both graphs
            g1_outputs = [o.name for o in g1.output]
            g2_outputs = [o.name for o in g2.output]
        else:
            output_set = set(outputs)
            # g1 outputs feeding io_map must survive extraction so they can be connected
            g1_outputs = [
                o.name
                for o in g1.output
                if o.name in output_set or o.name in io_map_g1_outs
            ]
            g2_outputs = [o.name for o in g2.output if o.name in output_set]

        if len(g1_inputs) < len(g1.input) or len(g1_outputs) < len(g1.output):
            e1 = utils.Extractor(helper.make_model(g1))
            g1 = e1.extract_model(g1_inputs, g1_outputs).graph
        if len(g2_inputs) < len(g2.input) or len(g2_outputs) < len(g2.output):
            e2 = utils.Extractor(helper.make_model(g2))
            g2 = e2.extract_model(g2_inputs, g2_outputs).graph

    # Check that input/output names specified in the io_map argument are valid input/output names
    for g1_out_name, g2_in_name in io_map:
        if g1_out_name not in g1_outs:
            raise ValueError(f"Output {g1_out_name} is not present in g1")
        if g2_in_name not in g2_ins:
            raise ValueError(f"Input {g2_in_name} is not present in g2")

    # Check for name collision
    overlapping_names = check_overlapping_names(g1, g2, io_map)
    if len(overlapping_names) > 0:
        category, names = overlapping_names[0]
        raise ValueError(
            "Cant merge two graphs with overlapping names. "
            f"Found repeated {category} names: "
            + ", ".join(names)
            + "\n"
            + "Consider using ``onnx.compose.add_prefix`` to add a prefix to names in one of the graphs."
        )

    g = GraphProto()

    g.node.extend(g1.node)
    g2_nodes_begin = len(g.node)
    g.node.extend(g2.node)
    g2_nodes_end = len(g.node)

    # Connecting outputs of the first graph with the inputs of the second.
    # The loop variable must not be called ``name``: that would clobber the
    # ``name`` parameter used below for the combined graph's name.
    for node_idx in range(g2_nodes_begin, g2_nodes_end):
        node = g.node[node_idx]
        for index, edge_name in enumerate(node.input):
            if edge_name in reversed_io_map:
                node.input[index] = reversed_io_map[edge_name]

    if inputs:
        input_set = set(inputs)
        g.input.extend([i for i in g1.input if i.name in input_set])
        g.input.extend([i for i in g2.input if i.name in input_set])
    else:
        g.input.extend(g1.input)
        g.input.extend([i for i in g2.input if i.name not in io_map_g2_ins])

    if outputs:
        output_set = set(outputs)
        g.output.extend([o for o in g1.output if o.name in output_set])
        g.output.extend([o for o in g2.output if o.name in output_set])
    else:
        g.output.extend([o for o in g1.output if o.name not in io_map_g1_outs])
        g.output.extend(g2.output)

    g.initializer.extend(g1.initializer)
    g.initializer.extend(
        [init for init in g2.initializer if init.name not in io_map_g2_ins]
    )

    g.sparse_initializer.extend(g1.sparse_initializer)
    g.sparse_initializer.extend(
        [
            init
            for init in g2.sparse_initializer
            if init.values.name not in io_map_g2_ins
        ]
    )

    g.value_info.extend(g1.value_info)
    g.value_info.extend([vi for vi in g2.value_info if vi.name not in io_map_g2_ins])

    g.name = name if name is not None else "_".join([g1.name, g2.name])

    if doc_string is None:
        doc_string = (
            f"Graph combining {g1.name} and {g2.name}\n"
            + g1.name
            + "\n\n"
            + g1.doc_string
            + "\n\n"
            + g2.name
            + "\n\n"
            + g2.doc_string
        )
    g.doc_string = doc_string

    return g
def merge_models(
    m1: ModelProto,
    m2: ModelProto,
    io_map: List[Tuple[str, str]],
    inputs: Optional[List[str]] = None,
    outputs: Optional[List[str]] = None,
    prefix1: Optional[str] = None,
    prefix2: Optional[str] = None,
    name: Optional[str] = None,
    doc_string: Optional[str] = None,
    producer_name: Optional[str] = "onnx.compose.merge_models",
    producer_version: Optional[str] = "1.0",
    domain: Optional[str] = "",
    model_version: Optional[int] = 1,
) -> ModelProto:
    """Combines two ONNX models into a single one.

    The combined model is defined by connecting the specified set of outputs/inputs.
    Those inputs/outputs not specified in the io_map argument will remain as
    inputs/outputs of the combined model.

    Both models should have the same IR version, and same operator sets imported.

    Arguments:
        m1 (ModelProto): First model
        m2 (ModelProto): Second model
        io_map (list of pairs of string): The pairs of names [(out0, in0), (out1, in1), ...]
            representing outputs of the first graph and inputs of the second to be connected
        inputs (list of string): Optional list of inputs to be included in the combined graph
            By default, all inputs not present in the ``io_map`` argument will be
            included in the combined model
        outputs (list of string): Optional list of outputs to be included in the combined graph
            By default, all outputs not present in the ``io_map`` argument will be
            included in the combined model
        prefix1 (string): Optional prefix to be added to all names in m1
        prefix2 (string): Optional prefix to be added to all names in m2
        name (string): Optional name for the combined graph
            By default, the name is g1.name and g2.name concatenated with an underscore delimiter
        doc_string (string): Optional docstring for the combined graph
            If not provided, a default docstring with the concatenation of g1 and g2 docstrings is used
        producer_name (string): Optional producer name for the combined model.
            Default: 'onnx.compose.merge_models'
        producer_version (string): Optional producer version for the combined model. Default: "1.0"
        domain (string): Optional domain of the combined model. Default: ""
        model_version (int): Optional version of the graph encoded. Default: 1

    Returns:
        ModelProto

    Raises:
        ValueError: if m1/m2 are not ONNX models, if their IR versions differ, if they
            import different versions of the same operator set domain, if they define
            conflicting metadata properties or local functions with the same name, or
            if merging their graphs fails
    """
    if not isinstance(m1, ModelProto):
        raise ValueError("m1 argument is not an ONNX model")
    if not isinstance(m2, ModelProto):
        raise ValueError("m2 argument is not an ONNX model")

    if m1.ir_version != m2.ir_version:
        raise ValueError(
            f"IR version mismatch {m1.ir_version} != {m2.ir_version}."
            " Both models should have the same IR version"
        )
    ir_version = m1.ir_version

    # Both models must agree on the version of every domain they share. Each domain
    # is imported only once in the merged model (first occurrence kept).
    opset_import_map: MutableMapping[str, int] = {}
    opset_imports = []
    for entry in list(m1.opset_import) + list(m2.opset_import):
        if entry.domain in opset_import_map:
            found_version = opset_import_map[entry.domain]
            if entry.version != found_version:
                raise ValueError(
                    "Can't merge two models with different operator set ids for a given domain. "
                    f"Got: {m1.opset_import} and {m2.opset_import}"
                )
        else:
            opset_import_map[entry.domain] = entry.version
            opset_imports.append(entry)

    # Prefixing names in the graph if requested, adjusting io_map accordingly.
    # add_prefix(inplace=False) works on a copy, so the caller's models are untouched.
    if prefix1 or prefix2:
        if prefix1:
            m1 = add_prefix(m1, prefix=prefix1, inplace=False)
        if prefix2:
            m2 = add_prefix(m2, prefix=prefix2, inplace=False)
        io_map = [
            (
                prefix1 + io[0] if prefix1 else io[0],
                prefix2 + io[1] if prefix2 else io[1],
            )
            for io in io_map
        ]

    graph = merge_graphs(
        m1.graph,
        m2.graph,
        io_map,
        inputs=inputs,
        outputs=outputs,
        name=name,
        doc_string=doc_string,
    )
    model = helper.make_model(
        graph,
        producer_name=producer_name,
        producer_version=producer_version,
        domain=domain,
        model_version=model_version,
        opset_imports=opset_imports,
        ir_version=ir_version,
    )

    # Merging model metadata props: identical keys must carry identical values
    model_props = {}
    for meta_entry in m1.metadata_props:
        model_props[meta_entry.key] = meta_entry.value
    for meta_entry in m2.metadata_props:
        if meta_entry.key in model_props:
            value = model_props[meta_entry.key]
            if value != meta_entry.value:
                raise ValueError(
                    "Can't merge models with different values for the same model metadata property."
                    f" Found: property = {meta_entry.key}, with values {value} and {meta_entry.value}."
                )
        else:
            model_props[meta_entry.key] = meta_entry.value
    helper.set_model_props(model, model_props)

    # Merging functions
    function_overlap = list(
        {f.name for f in m1.functions} & {f.name for f in m2.functions}
    )
    if function_overlap:
        raise ValueError(
            "Can't merge models with overlapping local function names."
            " Found in both graphs: " + ", ".join(function_overlap)
        )
    model.functions.MergeFrom(m1.functions)
    model.functions.MergeFrom(m2.functions)

    checker.check_model(model)
    return model
def add_prefix_graph(
    graph: GraphProto,
    prefix: str,
    rename_nodes: Optional[bool] = True,
    rename_edges: Optional[bool] = True,
    rename_inputs: Optional[bool] = True,
    rename_outputs: Optional[bool] = True,
    rename_initializers: Optional[bool] = True,
    rename_value_infos: Optional[bool] = True,
    inplace: Optional[bool] = False,
) -> GraphProto:
    """Adds a prefix to names of elements in a graph: nodes, edges, inputs, outputs,
    initializers, sparse initializer, value infos.

    It can be used as a utility before merging graphs that have overlapping names.
    Empty names are not prefixed.

    Arguments:
        graph (GraphProto): Graph
        prefix (str): Prefix to be added to each name in the graph
        rename_nodes (bool): Whether to prefix node names
        rename_edges (bool): Whether to prefix node edge names
        rename_inputs (bool): Whether to prefix input names
        rename_outputs (bool): Whether to prefix output names
        rename_initializers (bool): Whether to prefix initializer and sparse initializer names
        rename_value_infos (bool): Whether to prefix value info names
        inplace (bool): If True, mutates the graph directly.
            Otherwise, a copy will be created

    Returns:
        GraphProto

    Raises:
        ValueError: if ``graph`` is not a ``GraphProto``
    """
    if not isinstance(graph, GraphProto):
        raise ValueError("graph argument is not an ONNX graph")

    if not inplace:
        g = GraphProto()
        g.CopyFrom(graph)
    else:
        g = graph

    def _prefixed(prefix: str, name: str) -> str:
        return prefix + name if len(name) > 0 else name

    # Collect old -> new names first, then apply the map to every place a name occurs,
    # so all references to the same value stay consistent.
    name_map = {}
    if rename_edges:
        for n in g.node:
            for e in n.input:
                name_map[e] = _prefixed(prefix, e)
            for e in n.output:
                name_map[e] = _prefixed(prefix, e)

    # Inputs/outputs are handled independently of rename_edges so that graph
    # inputs/outputs no node touches (e.g. unused inputs, pass-through values)
    # are prefixed too.
    if rename_outputs:
        for entry in g.output:
            name_map[entry.name] = _prefixed(prefix, entry.name)
    if rename_inputs:
        for entry in g.input:
            name_map[entry.name] = _prefixed(prefix, entry.name)

    if rename_nodes:
        for n in g.node:
            n.name = _prefixed(prefix, n.name)

    if rename_initializers:
        for init in g.initializer:
            name_map[init.name] = _prefixed(prefix, init.name)
        for sparse_init in g.sparse_initializer:
            name_map[sparse_init.values.name] = _prefixed(
                prefix, sparse_init.values.name
            )
            name_map[sparse_init.indices.name] = _prefixed(
                prefix, sparse_init.indices.name
            )

    if rename_value_infos:
        for entry in g.value_info:
            name_map[entry.name] = _prefixed(prefix, entry.name)

    for n in g.node:
        for i, output in enumerate(n.output):
            if output in name_map:
                n.output[i] = name_map[output]
        for i, input_ in enumerate(n.input):
            if input_ in name_map:
                n.input[i] = name_map[input_]

    for in_desc in g.input:
        if in_desc.name in name_map:
            in_desc.name = name_map[in_desc.name]
    for out_desc in g.output:
        if out_desc.name in name_map:
            out_desc.name = name_map[out_desc.name]

    for initializer in g.initializer:
        if initializer.name in name_map:
            initializer.name = name_map[initializer.name]
    for sparse_initializer in g.sparse_initializer:
        if sparse_initializer.values.name in name_map:
            sparse_initializer.values.name = name_map[sparse_initializer.values.name]
        if sparse_initializer.indices.name in name_map:
            sparse_initializer.indices.name = name_map[sparse_initializer.indices.name]

    for value_info in g.value_info:
        if value_info.name in name_map:
            value_info.name = name_map[value_info.name]

    return g
def add_prefix(
    model: ModelProto,
    prefix: str,
    rename_nodes: Optional[bool] = True,
    rename_edges: Optional[bool] = True,
    rename_inputs: Optional[bool] = True,
    rename_outputs: Optional[bool] = True,
    rename_initializers: Optional[bool] = True,
    rename_value_infos: Optional[bool] = True,
    rename_functions: Optional[bool] = True,
    inplace: Optional[bool] = False,
) -> ModelProto:
    """Adds a prefix to names of elements in a graph: nodes, edges, inputs, outputs,
    initializers, sparse initializer, value infos, and local functions.

    It can be used as a utility before merging graphs that have overlapping names.
    Empty names are not prefixed.

    Arguments:
        model (ModelProto): Model
        prefix (str): Prefix to be added to each name in the graph
        rename_nodes (bool): Whether to prefix node names
        rename_edges (bool): Whether to prefix node edge names
        rename_inputs (bool): Whether to prefix input names
        rename_outputs (bool): Whether to prefix output names
        rename_initializers (bool): Whether to prefix initializer and sparse initializer names
        rename_value_infos (bool): Whether to prefix value info names
        rename_functions (bool): Whether to prefix local function names
        inplace (bool): If True, mutates the model directly.
            Otherwise, a copy will be created

    Returns:
        ModelProto

    Raises:
        ValueError: if ``model`` is not a ``ModelProto``
    """
    if not isinstance(model, ModelProto):
        raise ValueError("model argument is not an ONNX model")

    if not inplace:
        m = ModelProto()
        m.CopyFrom(model)
        model = m

    add_prefix_graph(
        model.graph,
        prefix,
        rename_nodes=rename_nodes,
        rename_edges=rename_edges,
        rename_inputs=rename_inputs,
        rename_outputs=rename_outputs,
        rename_initializers=rename_initializers,
        rename_value_infos=rename_value_infos,
        inplace=True,  # No need to create a copy, since it's a new model
    )

    if rename_functions:
        # A node calls a local function when its (domain, op_type) equals the
        # function's (domain, name). Keying on the pair avoids renaming nodes from
        # another domain (e.g. standard ops) that merely share the function's name.
        f_name_map = {}
        for f in model.functions:
            new_f_name = prefix + f.name
            f_name_map[(f.domain, f.name)] = new_f_name
            f.name = new_f_name

        def _rename_calls(nodes) -> None:
            for n in nodes:
                new_name = f_name_map.get((n.domain, n.op_type))
                if new_name is not None:
                    n.op_type = new_name

        # Adjust references to local functions in other local function definitions
        for f in model.functions:
            _rename_calls(f.node)
        # Adjust references to local functions in the graph
        _rename_calls(model.graph.node)

    return model
def expand_out_dim_graph(
    graph: GraphProto,
    dim_idx: int,
    inplace: Optional[bool] = False,
) -> GraphProto:
    """Inserts an extra dimension with extent 1 to each output in the graph.

    Inserts an Unsqueeze node for each output. It can be used as a utility before merging
    graphs, for example when the second one expects a batch dimension.

    Arguments:
        graph (GraphProto): Graph
        dim_idx (int): Index of the dimension to be inserted.
            A negative value means counting dimensions from the back
            (resolved against the rank of the expanded output, as Unsqueeze does).
        inplace (bool): If True, mutates the model directly.
            Otherwise, a copy will be created

    Returns:
        GraphProto

    Raises:
        ValueError: if ``graph`` is not a ``GraphProto``
    """
    if not isinstance(graph, GraphProto):
        raise ValueError("graph argument is not an ONNX graph")

    if not inplace:
        g = GraphProto()
        g.CopyFrom(graph)
    else:
        g = graph

    suffix = f"_collapsed_dim_{dim_idx}"
    orig_out_names = {output.name for output in g.output}

    # Rename every reference to an original output so the original names can be
    # reused for the results of the new Unsqueeze nodes.
    for n in g.node:
        for i, out_name in enumerate(n.output):
            if out_name in orig_out_names:
                n.output[i] = out_name + suffix
        for i, in_name in enumerate(n.input):
            if in_name in orig_out_names:
                n.input[i] = in_name + suffix

    expand_dim_k = g.name + "_expand_out_dim_idx"
    g.node.append(
        helper.make_node(
            "Constant",
            inputs=[],
            outputs=[expand_dim_k],
            name=f"{expand_dim_k}-constant",
            value=helper.make_tensor(
                name=f"{expand_dim_k}-value",
                data_type=tp.INT64,
                dims=[1],
                vals=[dim_idx],
            ),
        )
    )

    def _expanded_shape(value_info):
        # Returns the output shape with a 1 inserted at dim_idx, or None when the
        # original output declares no shape (rank unknown, so nothing to expand).
        tensor_type = value_info.type.tensor_type
        if not tensor_type.HasField("shape"):
            return None
        dims = []
        for d in tensor_type.shape.dim:
            kind = d.WhichOneof("value")
            if kind == "dim_value":
                dims.append(d.dim_value)
            elif kind == "dim_param":
                dims.append(d.dim_param)  # keep symbolic dimensions symbolic
            else:
                dims.append(None)  # unknown dimension stays unknown
        # Unsqueeze resolves a negative axis against the output rank (input rank + 1)
        pos = dim_idx if dim_idx >= 0 else len(dims) + 1 + dim_idx
        dims.insert(pos, 1)
        return dims

    new_outputs = []
    for o in g.output:
        g.node.append(
            helper.make_node(
                "Unsqueeze",
                inputs=[o.name + suffix, expand_dim_k],
                outputs=[o.name],
                name=f"unsqueeze-{o.name}",
            )
        )
        new_outputs.append(
            helper.make_tensor_value_info(
                o.name, o.type.tensor_type.elem_type, _expanded_shape(o)
            )
        )
    # Replace the output descriptions, preserving their original order
    del g.output[:]
    g.output.extend(new_outputs)

    return g
def expand_out_dim(
    model: ModelProto,
    dim_idx: int,
    inplace: Optional[bool] = False,
) -> ModelProto:
    """Inserts an extra dimension with extent 1 to each output of the model's graph.

    The work is delegated to ``expand_out_dim_graph``, which adds one Unsqueeze node
    per output. Useful before merging models, e.g. when the second model expects a
    batch dimension.

    Arguments:
        model (ModelProto): Model
        dim_idx (int): Index of the dimension to be inserted.
            A negative value means counting dimensions from the back.
        inplace (bool): If True, mutates the model directly.
            Otherwise, a copy will be created

    Returns:
        ModelProto

    Raises:
        ValueError: if ``model`` is not a ``ModelProto``
    """
    if type(model) is not ModelProto:
        raise ValueError("model argument is not an ONNX model")

    target = model
    if not inplace:
        target = ModelProto()
        target.CopyFrom(model)

    # The graph already belongs to `target` (the caller's model or a fresh copy),
    # so it can be edited in place without a second copy.
    expand_out_dim_graph(target.graph, dim_idx, inplace=True)
    return target