Coverage for mlprodict/onnx_tools/optim/onnx_helper.py: 100%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Statistics on :epkg:`ONNX` models.
4"""
5from collections import Counter
6from onnx.helper import make_graph
7from ..onnx2py_helper import from_pb, make_value_info
8from ._onnx_optimisation_common import _apply_optimisation_on_graph
9from .onnx_optimisation import onnx_remove_node
def onnx_statistics(onnx_model, recursive=True, optim=True, node_type=False):
    """
    Computes statistics on :epkg:`ONNX` models,
    extracts information about the model such as
    the number of nodes.

    :param onnx_model: onnx model, graph or node
    :param recursive: looks into subgraphs (attribute ``body``)
    :param optim: adds statistics because of optimisation
        (keys suffixed by ``_optim``, computed after removing
        unnecessary nodes with *onnx_remove_node*)
    :param node_type: add distribution of node types
        (``op_<type>`` for every type, otherwise only
        Cast, Identity, ZipMap, Reshape)
    :return: dictionary

    For a model, the dictionary also contains ``size`` (length of the
    serialized model), ``ninits``, the non-empty metadata attributes
    and one entry per opset domain mapped to its version.
    Subgraphs found in attribute ``body`` are summed into the
    statistics of the parent and counted in ``subgraphs``.

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        import pprint
        from sklearn.linear_model import LogisticRegression
        from sklearn.ensemble import RandomForestClassifier
        from sklearn.datasets import load_iris
        from mlprodict.onnx_tools.optim.onnx_helper import onnx_statistics
        from mlprodict.onnx_conv import to_onnx

        iris = load_iris()
        X = iris.data
        y = iris.target
        lr = LogisticRegression()
        lr.fit(X, y)
        onx = to_onnx(lr, X[:1])
        pprint.pprint((lr, onnx_statistics(onx)))

        iris = load_iris()
        X = iris.data
        y = iris.target
        rf = RandomForestClassifier()
        rf.fit(X, y)
        onx = to_onnx(rf, X[:1], target_opset=12)
        pprint.pprint((rf, onnx_statistics(onx)))
    """
    atts = ['doc_string', 'ir_version', 'metadata_props', 'domain',
            'model_version', 'producer_name', 'producer_version']

    def update(sts, st):
        "Adds statistics *st* into *sts*, values with the same key are summed."
        for k, v in st.items():
            # model-level entries are meaningless once summed
            if k in ['size'] or k in atts:
                continue  # pragma: no cover
            if k in sts:
                sts[k] += v
            else:
                sts[k] = v

    def update_subgraphs(sts, node):
        "Merges into *sts* the statistics of the subgraph of attribute body."
        for att in getattr(node, 'attribute', None) or []:
            if att.name != 'body':
                continue
            # optim=False: the optimisation is computed once on the
            # whole model, not for every subgraph
            substats = onnx_statistics(
                att.g, recursive=True, optim=False, node_type=node_type)
            update(sts, {'subgraphs': 1})
            update(sts, substats)

    if hasattr(onnx_model, 'graph'):
        content = onnx_model.SerializeToString()
        nnodes = len(onnx_model.graph.node)
        ninits = len(onnx_model.graph.initializer)
        stats = {'size': len(content), 'nnodes': nnodes, 'ninits': ninits}
        for a in atts:
            v = getattr(onnx_model, a)
            if isinstance(v, str):
                li = None
            else:
                try:
                    li = list(v)
                except TypeError:
                    li = None
            # empty repeated fields (such as metadata_props) are skipped
            if li is not None and len(li) == 0:
                continue
            stats[a] = v

        # one entry per opset domain: domain -> version
        for opi in onnx_model.opset_import:
            stats[opi.domain] = opi.version

        graph = onnx_model.graph
    elif not hasattr(onnx_model, 'node'):  # pragma: no cover
        # We're in a node: it counts for one node plus its subgraph,
        # options are propagated the same way as for a graph.
        stats = {'nnodes': 1}
        if recursive:
            update_subgraphs(stats, onnx_model)
        return stats
    else:
        graph = onnx_model
        stats = {'nnodes': len(graph.node)}

    # Number of nodes per operator type
    counts = Counter(node.op_type for node in graph.node)
    if node_type:
        for op, v in counts.items():
            stats['op_' + op] = v
    else:
        for op in ['Cast', 'Identity', 'ZipMap', 'Reshape']:
            if op in counts:
                stats['op_' + op] = counts[op]

    # Recursive
    if recursive:
        for node in graph.node:
            update_subgraphs(stats, node)

    # optimisation: remove_identity nodes
    if optim:
        new_model = onnx_remove_node(
            onnx_model, recursive=recursive)
        st = onnx_statistics(new_model, recursive=recursive, optim=False)
        for key in ["op_Identity", "subgraphs", "size",
                    "nnodes", "ninits"]:
            if key in st:
                stats[key + "_optim"] = st[key]
    return stats
def change_input_first_dimension(onnx_model, N=None, debug_info=None):
    """
    Some models are converted under the assumption
    batch prediction is not necessary. This function
    changes the first dimension of every input of an ONNX graph.

    @param      onnx_model      model or graph :epkg:`onnx`
    @param      N               new first dimension,
                                None to avoid changing it
                                (*onnx_model* is then returned as is),
                                0 (or a negative value) to replace
                                the first dimension by None,
                                i.e. an undefined first dimension
    @param      debug_info      unused
    @return                     modified model onnx
    """
    if N is None:
        # Nothing to change. This must be checked before ``N <= 0``
        # because comparing None with an integer raises a TypeError.
        return onnx_model

    if hasattr(onnx_model, 'graph'):
        return _apply_optimisation_on_graph(
            change_input_first_dimension, onnx_model, N=N)

    graph = onnx_model
    if N <= 0:
        N = None

    inputs = [from_pb(inp) for inp in graph.input]
    for inp in inputs:
        # NOTE(review): assumes from_pb returns (name, type, shape) with a
        # mutable shape whose first entry is the batch dimension.
        inp[2][0] = N
    inputs = [make_value_info(*v) for v in inputs]

    new_graph = make_graph(graph.node, graph.name,
                           inputs, graph.output, graph.initializer)
    new_graph.value_info.extend(graph.value_info)  # pylint: disable=E1101
    return new_graph