Coverage for mlprodict/onnx_tools/onnx_tools.py: 89%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

93 statements  

1""" 

2@file 

3@brief Functions to manipulate ONNX file. 

4""" 

5from onnx import helper 

6 

7 

def find_node_name(model, name):
    """
    Looks up a node of the graph by its name.

    :param model: onnx graph, any object exposing a ``graph`` attribute
    :param name: name of the node to look for
    :return: first node whose name equals *name*,
        None if no node has this name
    :raises TypeError: if *model* has no ``graph`` attribute
    """
    if hasattr(model, "graph"):
        # Lazily scan the nodes and stop at the first match.
        matches = (candidate for candidate in model.graph.node
                   if candidate.name == name)
        return next(matches, None)
    raise TypeError(  # pragma: no cover
        "Parameter model is not an ONNX model but "
        "{}".format(type(model)))

23 

24 

def find_node_input_name(node, name):
    """
    Finds a node input by its name.

    :param node: onnx node, its ``input`` attribute is the sequence
        of the input names
    :param name: input name
    :return: index of the first input equal to *name*,
        -1 if the node has no such input
    """
    # ``node.input`` already holds the names themselves: there is no
    # nested ``node`` container and the entries have no ``name`` attribute.
    for index, input_name in enumerate(node.input):
        if input_name == name:
            return index
    return -1

36 

37 

def insert_node(model, op_type, node, input_index=0, new_name=None, **attrs):
    """
    Inserts a node before one node input.

    The inserted node consumes the result currently plugged into the
    selected input of *node* and *node* is rewired to consume the output
    of the inserted node. The rewired node is modified in place, the
    returned model is a new object built from the reordered nodes.

    :param model: onnx graph
    :param op_type: operator type of the inserted node
    :param node: node or node name
    :param input_index: input index or input name
    :param new_name: name given to the output of the inserted node,
        ``op_type.lower()`` if None, a numeric suffix is appended
        if the name is already used in the graph
    :param attrs: node attributes
    :return: updated graph
    :raises RuntimeError: if the node or the input cannot be found
    """
    inode = find_node_name(model, node) if isinstance(node, str) else node
    if inode is None:
        raise RuntimeError(
            "Unable to find node %r in the graph." % (node, ))
    if isinstance(input_index, str):
        # The lookup must use the resolved node: *node* may be a name.
        input_index_ = find_node_input_name(inode, input_index)
        if input_index_ == -1:
            raise RuntimeError(  # pragma: no cover
                "Unable to find input_index %r in node %r." % (
                    input_index, inode.name))  # pylint: disable=E1120
        input_index = input_index_

    # guess a new name, it must not collide with any existing result,
    # including graph inputs and initializers no node consumes
    names = set()
    for n in model.graph.node:
        names.update(n.input)
        names.update(n.output)
    names.update(inp.name for inp in model.graph.input)
    names.update(init.name for init in model.graph.initializer)
    if new_name is None:
        new_name = op_type.lower()
    root_name = new_name
    i = 0
    while new_name in names:
        new_name = "%s_%d" % (root_name, i)
        i += 1

    # the new node takes over the former input, the existing node
    # now reads the output of the new node
    new_node = helper.make_node(
        op_type, [inode.input[input_index]], [new_name], **attrs)
    inode.input[input_index] = new_name
    keep_nodes = list(model.graph.node)
    keep_nodes.append(new_node)
    keep_nodes = ensure_topological_order(
        model.graph.input, model.graph.initializer, keep_nodes)

    graph = helper.make_graph(
        keep_nodes, model.graph.name, model.graph.input,
        model.graph.output, model.graph.initializer)
    onnx_model = helper.make_model(graph)

    # copies the model metadata
    onnx_model.ir_version = model.ir_version
    onnx_model.producer_name = model.producer_name
    onnx_model.producer_version = model.producer_version
    onnx_model.domain = model.domain
    onnx_model.model_version = model.model_version
    onnx_model.doc_string = model.doc_string
    if len(model.metadata_props) > 0:
        values = {p.key: p.value for p in model.metadata_props}
        helper.set_model_props(onnx_model, values)

    # replaces the default opsets by the ones of the original model
    del onnx_model.opset_import[:]  # pylint: disable=E1101
    for oimp in model.opset_import:
        op_set = onnx_model.opset_import.add()  # pylint: disable=E1101
        op_set.domain = oimp.domain
        op_set.version = oimp.version

    if len(onnx_model.graph.input) != len(model.graph.input):  # pylint: disable=E1101
        raise RuntimeError(  # pragma: no cover
            "Input mismatch {} != {}".format(
                len(onnx_model.graph.input),  # pylint: disable=E1101
                len(model.graph.input)))
    return onnx_model

107 

108 

def ensure_topological_order(inputs, initializers, nodes):
    """
    Ensures and modifies the order of nodes to have
    a topological order (every node in the list
    can only be an input for a node later in this list).
    The function raises an exception if a cycle is detected.
    Nodes which do not depend on each other keep the relative
    order they have in *nodes*.

    :param inputs: graph inputs
    :param initializers: graph initializers
    :param nodes: graph nodes
    :return: list ordered nodes
    :raises RuntimeError: if an output name is already known when its
        node is placed, or if some nodes still have unknown inputs
        after ``2 * len(nodes)`` passes
    """
    # Rank of every known result (key is its name) and of every placed
    # node (key is ``id(node)``). Inputs and initializers have rank 0.
    order = {}
    for inp in inputs:
        order[inp.name] = 0
    for init in initializers:
        order[init.name] = 0

    # Defined before the loop so that an empty list of nodes
    # goes through the final check instead of failing.
    missing_names = set()
    missing_ops = []
    n_iter = 0
    while n_iter < len(nodes) * 2:
        n_iter += 1
        missing_names = set()
        missing_ops = []
        for node in nodes:
            maxi = 0
            for name in node.input:
                if name == "":
                    # ONNX uses an empty name for an omitted
                    # optional input: it is not a dependency.
                    continue
                if name in order:
                    maxi = max(maxi, order[name])
                else:
                    maxi = None
                    missing_names.add(name)
                    break
            if maxi is None:
                # at least one input is not known yet, retried next pass
                missing_ops.append(node)
                continue
            key = id(node)
            if key in order:
                # already placed in a previous pass
                continue
            # the node comes right after its latest input,
            # its outputs right after the node
            order[key] = maxi + 1
            for name in node.output:
                if name == "":
                    # omitted optional output, nothing to register
                    continue
                if name in order:
                    raise RuntimeError(  # pragma: no cover
                        "Unable to sort a node (cycle). An output was "
                        "already ordered %r (iteration=%r)." % (
                            name, n_iter))
                order[name] = maxi + 2
        if not missing_names:
            # every node is placed, further passes would change nothing
            break

    if len(missing_ops) > 0:  # pragma: no cover
        def nstr(name):
            if name in order:
                return "%s#%d" % (name, order[name])
            return name
        rows = ["%s(%s) -> [%s]" % (
            n.name or n.op_type,
            ', '.join(map(nstr, n.input)),
            ', '.join(n.output))
            for n in missing_ops]
        rows.insert(0, "")
        rows.append("--")
        rows.append("--all-nodes--")
        rows.append("--")
        rows.extend("%s(%s) -> [%s]" % (
            n.name or n.op_type,
            ', '.join(map(nstr, n.input)),
            ', '.join(n.output))
            for n in nodes)
        raise RuntimeError(
            "After %d iterations for %d nodes, still unable "
            "to sort names %r. The graph may be disconnected. "
            "List of operators: %s" % (
                n_iter, len(nodes), missing_names,
                "\n".join(rows)))

    # Update order: the sort is stable, nodes with the same rank keep
    # their input order instead of depending on memory addresses.
    return sorted(nodes, key=lambda node: order[id(node)])

190 map_nodes = {str(id(node)): node for node in nodes} 

191 return [map_nodes[_[1]] for _ in topo]