Coverage for mlprodict/onnx_tools/optim/onnx_helper.py: 100%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

73 statements  

1""" 

2@file 

3@brief Statistics on :epkg:`ONNX` models. 

4""" 

5from collections import Counter 

6from onnx.helper import make_graph 

7from ..onnx2py_helper import from_pb, make_value_info 

8from ._onnx_optimisation_common import _apply_optimisation_on_graph 

9from .onnx_optimisation import onnx_remove_node 

10 

11 

def onnx_statistics(onnx_model, recursive=True, optim=True, node_type=False):
    """
    Computes statistics on :epkg:`ONNX` models,
    extracts information about the model such as
    the number of nodes.

    The input is dispatched on its attributes: an object with a ``graph``
    attribute is handled as a model, an object with a ``node`` attribute
    as a graph, anything else as a single node.

    :param onnx_model: onnx model, graph or node
    :param recursive: looks into subgraphs (attributes named ``body``)
    :param optim: adds statistics because of optimisation
        (keys suffixed with ``_optim``)
    :param node_type: add distribution of node types
    :return: dictionary

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        import pprint
        from sklearn.linear_model import LogisticRegression
        from sklearn.ensemble import RandomForestClassifier
        from sklearn.datasets import load_iris
        from mlprodict.onnx_tools.optim.onnx_helper import onnx_statistics
        from mlprodict.onnx_conv import to_onnx

        iris = load_iris()
        X = iris.data
        y = iris.target
        lr = LogisticRegression()
        lr.fit(X, y)
        onx = to_onnx(lr, X[:1])
        pprint.pprint((lr, onnx_statistics(onx)))

        iris = load_iris()
        X = iris.data
        y = iris.target
        rf = RandomForestClassifier()
        rf.fit(X, y)
        onx = to_onnx(rf, X[:1], target_opset=12)
        pprint.pprint((rf, onnx_statistics(onx)))
    """
    atts = ['doc_string', 'ir_version', 'metadata_props', 'domain',
            'model_version', 'producer_name', 'producer_version']

    def update(sts, st):
        # Accumulates the counters of *st* into *sts*; the size and the
        # model level attributes are never merged.
        for k, v in st.items():
            if k in ['size'] or k in atts:
                continue  # pragma: no cover
            if k in sts:
                sts[k] += v
            else:
                sts[k] = v

    if hasattr(onnx_model, 'graph'):
        # We're in a model.
        content = onnx_model.SerializeToString()
        nnodes = len(onnx_model.graph.node)
        ninits = len(onnx_model.graph.initializer)
        stats = {'size': len(content), 'nnodes': nnodes, 'ninits': ninits}
        for a in atts:
            v = getattr(onnx_model, a)
            if isinstance(v, str):
                # strings are always reported, even the empty ones
                li = None
            else:
                try:
                    li = list(v)
                except TypeError:
                    # not iterable (a number for example)
                    li = None
            if li is not None and len(li) == 0:
                # empty containers are skipped
                continue
            stats[a] = v

        for opi in onnx_model.opset_import:
            stats[opi.domain] = opi.version

        graph = onnx_model.graph
    elif not hasattr(onnx_model, 'node'):  # pragma: no cover
        # We're in a node.
        stats = {'nnodes': 1}
        if hasattr(onnx_model, 'attribute') and onnx_model.attribute:
            for att in onnx_model.attribute:
                if att.name == 'body':
                    # The caller's options are forwarded so that the
                    # subgraph is analysed the same way as everywhere else.
                    st = onnx_statistics(
                        att.g, recursive=recursive, optim=optim,
                        node_type=node_type)
                    update(stats, st)
        return stats
    else:
        # We're in a graph.
        graph = onnx_model
        nnodes = len(graph.node)
        stats = {'nnodes': nnodes}

    # Number of nodes per operator type, all of them if node_type is True,
    # a short fixed list otherwise.
    counts = Counter(node.op_type for node in graph.node)
    if node_type:
        for op, v in counts.items():
            stats['op_' + op] = v
    else:
        for op in ['Cast', 'Identity', 'ZipMap', 'Reshape']:
            if op in counts:
                stats['op_' + op] = counts[op]

    # Recursive
    if recursive:
        for node in graph.node:
            if not hasattr(node, 'attribute'):
                continue  # pragma: no cover
            for att in node.attribute:
                if att.name != 'body':
                    continue
                # optimisation statistics are computed once, on the whole
                # model, hence optim=False for the subgraphs
                substats = onnx_statistics(
                    att.g, recursive=True, optim=False, node_type=node_type)
                update(stats, {'subgraphs': 1})
                update(stats, substats)

    # optimisation: remove_identity nodes
    if optim:
        new_model = onnx_remove_node(
            onnx_model, recursive=recursive)
        st = onnx_statistics(new_model, recursive=recursive, optim=False)
        for key in ["op_Identity", "subgraphs", "size",
                    "nnodes", "ninits"]:
            if key in st:
                stats[key + "_optim"] = st[key]
    return stats

132 

133 

def change_input_first_dimension(onnx_model, N=None, debug_info=None):
    """
    Some models are converted under the assumption
    batch prediction is not necessary. This function
    changes the first dimension of the inputs of an ONNX graph.

    @param      onnx_model      model :epkg:`onnx`, if it has an attribute
                                ``graph``, the function is applied on it
                                through ``_apply_optimisation_on_graph``
    @param      N               new first dimension,
                                None to avoid changing it,
                                0 (or a negative value) to make the
                                first dimension undefined
    @param      debug_info      unused
    @return                     modified model onnx
    """
    if hasattr(onnx_model, 'graph'):
        return _apply_optimisation_on_graph(
            change_input_first_dimension, onnx_model, N=N)

    graph = onnx_model

    # NOTE(review): every description returned by from_pb is indexed as
    # desc[2][0] below, presumably (name, type, shape) with a mutable
    # shape -- confirm against from_pb.
    inputs = [from_pb(inp) for inp in graph.input]

    # N=None is the documented way to keep the shapes unchanged, comparing
    # None to 0 would raise a TypeError.
    if N is not None:
        first_dim = N if N > 0 else None
        for desc in inputs:
            desc[2][0] = first_dim
    inputs = [make_value_info(*v) for v in inputs]

    new_graph = make_graph(graph.node, graph.name,
                           inputs, graph.output, graph.initializer)

    new_graph.value_info.extend(graph.value_info)  # pylint: disable=E1101
    return new_graph