Coverage for onnxcustom/utils/onnx_helper.py: 100%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

75 statements  

1# pylint: disable=C0415,E0611,E1101 

2""" 

3@file 

4@brief Onnx implementation of common functions used to train a model. 

5""" 

6import math 

7import numpy 

8from onnx import TensorProto, numpy_helper, helper 

9from onnxruntime import OrtValue 

10from onnxruntime.capi._pybind_state import OrtValue as C_OrtValue 

11 

12 

def onnx_rename_weights(onx):
    """
    Renames ONNX initializers to make sure their name
    follows the alphabetical order. The model is
    modified inplace. This function calls
    :func:`onnx_rename_names
    <mlprodict.onnx_tools.onnx_manipulations.onnx_rename_names>`.

    Every initializer is renamed ``I<index>_<old name>`` where the
    index is its position in ``onx.graph.initializer``, zero-padded to
    a fixed width so that sorting the new names alphabetically gives
    back the initializer order.

    :param onx: ONNX model
    :return: same model

    .. note::
        The function does not go into subgraphs.
        A model without any initializer is returned as is.
    """
    init = [i.name for i in onx.graph.initializer]
    if not init:
        # Nothing to rename. Returning early also avoids math.log(0),
        # which raises ValueError (math domain error).
        return onx

    from mlprodict.onnx_tools.onnx_manipulations import (  # pylint: disable=C0415
        onnx_rename_names)

    # Width of the zero-padded index. The formula is kept unchanged so
    # that models renamed before keep exactly the same names.
    ninit = max(1, int(math.log(len(init)) / math.log(10) + 1))
    fmt = "I%0{}d_%s".format(ninit)
    new_names = [fmt % (i, name) for i, name in enumerate(init)]
    repl = dict(zip(init, new_names))
    return onnx_rename_names(onx, recursive=False, replace=repl)

36 

37 

def get_onnx_opset(onx, domain=''):
    """
    Returns the opset version a model imports for a given domain.

    :param onx: ONNX model, any object exposing ``opset_import``,
        a sequence of entries with ``domain`` and ``version`` attributes
    :param domain: domain to look for, ``''`` by default
    :return: the ``version`` of the first entry whose ``domain``
        equals *domain*
    :raises ValueError: if no entry matches *domain*
    """
    for opset in onx.opset_import:
        if opset.domain == domain:
            return opset.version
    raise ValueError(
        "Unable to find opset for domain=%r." % domain)

51 

52 

def proto_type_to_dtype(proto_type):
    """
    Converts an ONNX tensor element type into a numpy type.

    :param proto_type: either an integer code from ``TensorProto``
        (``TensorProto.FLOAT`` or ``TensorProto.DOUBLE``) or one of the
        strings ``'tensor(float)'`` or ``'tensor(double)'``
    :return: ``numpy.float32`` or ``numpy.float64``
    :raises ValueError: for any other value
    """
    if proto_type == TensorProto.FLOAT:
        return numpy.float32
    if proto_type == TensorProto.DOUBLE:
        return numpy.float64
    # Not efficient: the string forms are only tested after both
    # integer codes failed to match.
    if proto_type == 'tensor(float)':
        return numpy.float32
    if proto_type == 'tensor(double)':
        return numpy.float64
    raise ValueError(
        "Unexpected value proto_type=%r (type=%r)." % (
            proto_type, type(proto_type)))

72 

73 

def dtype_to_var_type(dtype):
    """
    Maps a numpy dtype onto the matching *skl2onnx* tensor type class.

    Supported dtypes are float32, float64, int64 and int32.

    :param dtype: numpy dtype (or numpy scalar type) to convert
    :return: the skl2onnx type class itself, not an instance
    :raises ValueError: when *dtype* matches none of the supported types
    """
    from skl2onnx.common.data_types import (
        FloatTensorType, DoubleTensorType,
        Int32TensorType, Int64TensorType)
    # Checked in this order with ``==``, exactly like an if-chain.
    candidates = (
        (numpy.float32, FloatTensorType),
        (numpy.float64, DoubleTensorType),
        (numpy.int64, Int64TensorType),
        (numpy.int32, Int32TensorType),
    )
    for np_type, var_type in candidates:
        if dtype == np_type:
            return var_type
    raise ValueError(
        "Unexpected value dtype=%r." % dtype)

91 

92 

def _finalize_new_onnx(graph, onx):
    """
    Wraps *graph* into a new model that carries over the model-level
    properties, metadata and opset imports of *onx*.

    :param graph: graph of the new model
    :param onx: model whose properties and opsets are copied
    :return: new ONNX model built with ``helper.make_model``
    """
    new_model = helper.make_model(graph)
    copied = ('ir_version', 'producer_name', 'producer_version',
              'domain', 'model_version', 'doc_string')
    for attr in copied:
        setattr(new_model, attr, getattr(onx, attr))
    if onx.metadata_props:  # pragma: no cover
        props = {prop.key: prop.value for prop in onx.metadata_props}
        helper.set_model_props(new_model, props)

    # Discard whatever opset_import make_model produced and copy the
    # opsets declared by the source model instead.
    del new_model.opset_import[:]  # pylint: disable=E1101
    for src in onx.opset_import:
        dst = new_model.opset_import.add()  # pylint: disable=E1101
        dst.domain = src.domain
        dst.version = src.version
    return new_model

111 

112 

def add_initializer(model, name, value):
    """
    Adds an initializer to a model.

    The graph is rebuilt with the same nodes, inputs, outputs,
    ``value_info`` entries and documentation string. The new
    initializer is appended after the existing ones.

    :param model: ONNX model, left unmodified
    :param name: initializer name, must not already be the name of
        another initializer
    :param value: array converted with ``numpy_helper.from_array``
    :return: new ONNX model
    :raises ValueError: if *name* is already an initializer name
    """
    inits = {init.name for init in model.graph.initializer}
    if name in inits:
        raise ValueError(  # pragma: no cover
            "Name %r is already taken among %r." % (
                name, inits))
    list_inits = list(model.graph.initializer)
    list_inits.append(
        numpy_helper.from_array(value, name=name))
    # Forward doc_string and value_info too, otherwise rebuilding the
    # graph silently drops them.
    graph_def = helper.make_graph(
        model.graph.node, model.graph.name,
        model.graph.input, model.graph.output,
        list_inits, doc_string=model.graph.doc_string,
        value_info=model.graph.value_info)
    return _finalize_new_onnx(graph_def, model)

135 

136 

def replace_initializers_into_onnx(model, results):
    """
    Replaces initializers by other initializers,
    usually trained ones.

    The graph is rebuilt with the same nodes, inputs, outputs,
    ``value_info`` entries and documentation string. Only the
    initializers named in *results* are swapped.

    :param model: ONNX model, left unmodified
    :param results: dictionary ``{initializer name: new value}``.
        A numpy array is converted with ``numpy_helper.from_array``.
        An ``OrtValue`` is converted through its ``numpy()`` method.
        Any other value is stored as is (presumably already an
        initializer with the same name; this is not checked).
    :return: new ONNX model
    :raises RuntimeError: if a key of *results* is not an initializer
        of *model*
    """
    inputs = list(model.graph.input)
    outputs = list(model.graph.output)
    inits = list(model.graph.initializer)

    inits_dict = {init.name: i for i, init in enumerate(inits)}
    for k, v in results.items():
        if k not in inits_dict:
            raise RuntimeError(  # pragma: no cover
                "Unable to find initializer %r in "
                "%r." % (k, inits_dict))
        if isinstance(v, numpy.ndarray):
            v = numpy_helper.from_array(v, k)
        elif isinstance(v, (C_OrtValue, OrtValue)):
            v = numpy_helper.from_array(v.numpy(), k)
        inits[inits_dict[k]] = v

    # Forward doc_string and value_info too, otherwise rebuilding the
    # graph silently drops them.
    graph = helper.make_graph(
        list(model.graph.node), model.graph.name, inputs,
        outputs, inits, doc_string=model.graph.doc_string,
        value_info=model.graph.value_info)
    return _finalize_new_onnx(graph, model)