# Source code for onnx.tools.net_drawer

# SPDX-License-Identifier: Apache-2.0

# A library and utility for drawing ONNX nets. Most of this implementation has
# been borrowed from the caffe2 implementation
# https://github.com/caffe2/caffe2/blob/master/caffe2/python/net_drawer.py
#
# The script takes two required arguments:
#   --input: a path to a serialized ModelProto .pb file.
#   --output: a path to write a dot file representation of the graph
#
# Given this dot file representation, you can, for example, export this to svg
# with the graphviz `dot` utility, like so:
#
#   $ dot -Tsvg my_output.dot -o my_output.svg
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import argparse
from collections import defaultdict
import json
from onnx import ModelProto, GraphProto, NodeProto
import pydot  # type: ignore
from typing import Text, Any, Callable, Optional, Dict


# Node attributes passed to pydot for operator nodes: a filled box with a
# white font.
OP_STYLE = {
    'shape': 'box',
    'color': '#0F9D58',
    'style': 'filled',
    'fontcolor': '#FFFFFF'
}

# Node attributes passed to pydot for the nodes that represent op inputs and
# outputs (blobs).
BLOB_STYLE = {'shape': 'octagon'}

# Signature of a callable that turns an op and its index within the graph
# into the pydot node used to draw that op.
_NodeProducer = Callable[[NodeProto, int], pydot.Node]


def _escape_label(name: Text) -> Text:
    # json.dumps is poor man's escaping
    return json.dumps(name)


def _form_and_sanitize_docstring(s: Text) -> Text:
    """Wrap *s* in a ``javascript:alert(...)`` URL.

    The text is first escaped with ``_escape_label``; afterwards every double
    quote is turned into a single quote and angle brackets are dropped.

    Args:
        s: The text to show in the alert.

    Returns:
        The ``javascript:alert(...)`` URL string.
    """
    # NOTE(review): single quotes already present in `s` are left untouched,
    # so they end up unescaped inside the single-quoted alert argument --
    # confirm whether that matters for the consumers of this URL.
    substitutions = {ord('"'): '\'', ord('<'): None, ord('>'): None}
    payload = _escape_label(s).translate(substitutions)
    return 'javascript:alert(' + payload + ')'


def GetOpNodeProducer(embed_docstring: bool = False, **kwargs: Any) -> _NodeProducer:
    """Build a callable that renders an op as a pydot node.

    Args:
        embed_docstring: When true, each produced node gets a URL that shows
            the op's ``doc_string`` in a javascript alert.
        **kwargs: Extra attributes forwarded to every ``pydot.Node`` created.

    Returns:
        A function taking ``(op, op_id)`` and returning the ``pydot.Node``
        for that op.
    """

    def ReallyGetOpNode(op: NodeProto, op_id: int) -> pydot.Node:
        # First line of the label: "<name>/<type> (op#N)", or just the type
        # when the op has no name.
        if op.name:
            title = '%s/%s (op#%d)' % (op.name, op.op_type, op_id)
        else:
            title = '%s (op#%d)' % (op.op_type, op_id)

        # One extra label line per input, then one per output.
        pieces = [title]
        pieces.extend(
            '\n input' + str(idx) + ' ' + tensor
            for idx, tensor in enumerate(op.input)
        )
        pieces.extend(
            '\n output' + str(idx) + ' ' + tensor
            for idx, tensor in enumerate(op.output)
        )

        node = pydot.Node(''.join(pieces), **kwargs)
        if embed_docstring:
            node.set_URL(_form_and_sanitize_docstring(op.doc_string))
        return node

    return ReallyGetOpNode
def GetPydotGraph(
    graph: GraphProto,
    name: Optional[Text] = None,
    rankdir: Text = 'LR',
    node_producer: Optional[_NodeProducer] = None,
    embed_docstring: bool = False,
) -> pydot.Dot:
    """Convert a graph into a ``pydot.Dot`` drawing.

    Every op becomes a node made by ``node_producer``; every input and output
    name becomes a blob node, with edges blob -> op for inputs and op -> blob
    for outputs.

    Args:
        graph: The graph whose ``node`` entries are drawn.
        name: Name given to the resulting ``pydot.Dot``.
        rankdir: Value of the ``rankdir`` attribute of the ``pydot.Dot``.
        node_producer: Callable creating the pydot node for an op. When
            omitted, one is built with ``GetOpNodeProducer`` and ``OP_STYLE``.
        embed_docstring: Forwarded to ``GetOpNodeProducer``; only used when
            ``node_producer`` is omitted.

    Returns:
        The populated ``pydot.Dot`` object.
    """
    if node_producer is None:
        node_producer = GetOpNodeProducer(embed_docstring=embed_docstring, **OP_STYLE)

    dot = pydot.Dot(name, rankdir=rankdir)
    # Most recent blob node registered under each name.
    latest_blob: Dict[Text, pydot.Node] = {}
    # Counter appended to a blob's node id; incremented each time a name that
    # already has a node is produced again.
    versions: Dict[Text, int] = defaultdict(int)

    def new_blob(blob_name: Text) -> pydot.Node:
        # The node id carries the counter suffix; the label shows the plain
        # name.
        blob = pydot.Node(
            _escape_label(blob_name + str(versions[blob_name])),
            label=_escape_label(blob_name),
            **BLOB_STYLE
        )
        latest_blob[blob_name] = blob
        return blob

    for op_id, op in enumerate(graph.node):
        op_node = node_producer(op, op_id)
        dot.add_node(op_node)

        for consumed in op.input:
            if consumed in latest_blob:
                source = latest_blob[consumed]
            else:
                source = new_blob(consumed)
            dot.add_node(source)
            dot.add_edge(pydot.Edge(source, op_node))

        for produced in op.output:
            if produced in latest_blob:
                versions[produced] += 1
            sink = new_blob(produced)
            dot.add_node(sink)
            dot.add_edge(pydot.Edge(op_node, sink))

    return dot
def main() -> None:
    """Command-line entry point: render a serialized model as a dot file.

    Reads the ``ModelProto`` stored at ``--input``, converts its graph with
    ``GetPydotGraph`` and writes the result in dot format to ``--output``.
    """
    parser = argparse.ArgumentParser(description="ONNX net drawer")
    parser.add_argument(
        "--input",
        type=Text,
        required=True,
        help="The input protobuf file.",
    )
    parser.add_argument(
        "--output",
        type=Text,
        required=True,
        # The result is written with write_dot, so this is a dot file rather
        # than a protobuf.
        help="The output dot file.",
    )
    parser.add_argument(
        "--rankdir",
        type=Text,
        default='LR',
        help="The rank direction of the pydot graph.",
    )
    parser.add_argument(
        "--embed_docstring",
        action="store_true",
        help="Embed docstring as javascript alert. Useful for SVG format.",
    )
    args = parser.parse_args()

    model = ModelProto()
    with open(args.input, 'rb') as fid:
        content = fid.read()
        model.ParseFromString(content)

    pydot_graph = GetPydotGraph(
        model.graph,
        name=model.graph.name,
        rankdir=args.rankdir,
        node_producer=GetOpNodeProducer(
            embed_docstring=args.embed_docstring,
            **OP_STYLE
        ),
    )
    pydot_graph.write_dot(args.output)


if __name__ == '__main__':
    main()