Coverage for mlprodict/npy/xop_opset.py: 100%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

73 statements  

1# pylint: disable=E0602 

2""" 

3@file 

4@brief Xop API to build onnx graphs. Inspired by :epkg:`sklearn-onnx`. 

5 

6.. versionadded:: 0.9 

7""" 

8import numpy 

9from .xop import loadop 

10 

11 

def OnnxReduceSumApi11(*x, axes=None, keepdims=1, op_version=None,
                       output_names=None):
    """
    Adds operator ReduceSum with opset>=13 following API from opset 12.

    The caller always gives *axes* as a keyword. For opset >= 13 it is
    converted into an int64 array and passed as an extra positional input.
    For older opsets it is passed as the ``axes`` attribute, and only when
    it is not None.

    :param x: inputs of the operator
    :param axes: axes to reduce, or None to leave them unspecified
    :param keepdims: forwarded as the ``keepdims`` attribute
    :param op_version: targeted opset, mandatory
    :param output_names: output names of the created node
    :return: the created ReduceSum node
    :raises RuntimeError: if *op_version* is None
    """
    if op_version is None:
        raise RuntimeError(  # pragma: no cover
            "op_version must be specified.")
    options = dict(keepdims=keepdims, op_version=op_version,
                   output_names=output_names)
    if op_version >= 13:
        reduce_sum = loadop('ReduceSum')
        if axes is None:
            return reduce_sum(*x, **options)
        # axes is an input (not an attribute) from opset 13 on
        return reduce_sum(
            *x, numpy.array(axes, dtype=numpy.int64), **options)
    reduce_sum = loadop('ReduceSum_11' if op_version >= 11 else 'ReduceSum_1')
    if axes is not None:
        options['axes'] = axes
    return reduce_sum(*x, **options)

46 

47 

def OnnxSplitApi11(*x, axis=0, split=None, op_version=None,
                   output_names=None):
    """
    Adds operator Split with opset>=13 following API from opset 11.

    *split* is always given as a keyword. For opset >= 13 it becomes an
    int64 array appended to the inputs. Below opset 13 it is forwarded as
    the ``split`` attribute, and only when it is not None.

    :param x: inputs of the operator
    :param axis: forwarded as the ``axis`` attribute
    :param split: sizes of the outputs, or None to leave them unspecified
    :param op_version: targeted opset, mandatory
    :param output_names: output names of the created node
    :return: the created Split node
    :raises RuntimeError: if *op_version* is None
    """
    if op_version is None:
        raise RuntimeError(  # pragma: no cover
            "op_version must be specified.")
    kwargs = {'axis': axis, 'op_version': op_version,
              'output_names': output_names}
    if op_version >= 13:
        split_op = loadop('Split')
        if split is None:
            return split_op(*x, **kwargs)
        return split_op(
            *x, numpy.array(split, dtype=numpy.int64), **kwargs)
    split_op = loadop('Split_11' if op_version >= 11 else 'Split_2')
    if split is not None:
        kwargs['split'] = split
    return split_op(*x, **kwargs)

80 

81 

def OnnxSqueezeApi11(*x, axes=None, op_version=None,
                     output_names=None):
    """
    Adds operator Squeeze with opset>=13 following API from opset 11.

    For opset >= 13, *axes* becomes an int64 array appended to the inputs.
    Below opset 13 it is forwarded as the ``axes`` attribute. When *axes*
    is None, nothing is added in either case. This lets the caller rely on
    the operator's default, or, for opset >= 13, supply axes itself
    through *x*.

    :param x: inputs of the operator
    :param axes: axes to squeeze, or None to leave them unspecified
    :param op_version: targeted opset, mandatory
    :param output_names: output names of the created node
    :return: the created Squeeze node
    :raises RuntimeError: if *op_version* is None
    """
    if op_version is None:
        raise RuntimeError(  # pragma: no cover
            "op_version must be specified.")
    if op_version >= 13:
        OnnxSqueeze = loadop('Squeeze')
        if axes is None:
            # axes is an optional input since opset 13; converting None with
            # numpy.array(None, dtype=numpy.int64) would raise a TypeError.
            return OnnxSqueeze(
                *x, op_version=op_version, output_names=output_names)
        return OnnxSqueeze(
            *x, numpy.array(axes, dtype=numpy.int64),
            op_version=op_version, output_names=output_names)
    OnnxSqueezeOld = loadop(
        'Squeeze_11' if op_version >= 11 else 'Squeeze_1')
    if axes is None:
        # same convention as OnnxReduceSumApi11: omit an unset attribute
        return OnnxSqueezeOld(
            *x, op_version=op_version, output_names=output_names)
    return OnnxSqueezeOld(
        *x, axes=axes, op_version=op_version, output_names=output_names)

103 

104 

def OnnxUnsqueezeApi11(*x, axes=None, op_version=None,
                       output_names=None):
    """
    Adds operator Unsqueeze with opset>=13 following API from opset 11.

    For opset >= 13, *axes* becomes an int64 array appended to the inputs.
    If *axes* is None, the inputs are forwarded unchanged, so the caller can
    provide axes as an input in *x*. Below opset 13, *axes* is forwarded as
    the ``axes`` attribute.

    :param x: inputs of the operator
    :param axes: axes to insert, or None if they are already part of *x*
        (opset >= 13 only)
    :param op_version: targeted opset, mandatory
    :param output_names: output names of the created node
    :return: the created Unsqueeze node
    :raises RuntimeError: if *op_version* is None
    """
    if op_version is None:
        raise RuntimeError(  # pragma: no cover
            "op_version must be specified.")
    if op_version >= 13:
        OnnxUnsqueeze = loadop('Unsqueeze')
        if axes is None:
            # numpy.array(None, dtype=numpy.int64) raises a TypeError;
            # leave the inputs untouched so axes may come from *x*.
            return OnnxUnsqueeze(
                *x, op_version=op_version, output_names=output_names)
        return OnnxUnsqueeze(
            *x, numpy.array(axes, dtype=numpy.int64),
            op_version=op_version, output_names=output_names)
    if op_version >= 11:
        OnnxUnsqueeze_11 = loadop('Unsqueeze_11')
        return OnnxUnsqueeze_11(
            *x, axes=axes, op_version=op_version,
            output_names=output_names)
    OnnxUnsqueeze_1 = loadop('Unsqueeze_1')
    return OnnxUnsqueeze_1(*x, axes=axes,
                           op_version=op_version, output_names=output_names)

126 

127 

def OnnxReduceL2_typed(dtype, x, axes=None, keepdims=1, op_version=None,
                       output_names=None):
    """
    Adds operator ReduceL2 for float or double.

    For float32 the operator ReduceL2 is used directly. For any other dtype
    the norm is computed as ``Sqrt(ReduceSum(Mul(x, x)))``. Both branches
    reduce over the same *axes* with the same *keepdims*.

    :param dtype: numpy dtype of *x*, selects the implementation
    :param x: input
    :param axes: axes to reduce, forwarded to the reduction
    :param keepdims: keep reduced dimensions, forwarded to the reduction
    :param op_version: targeted opset
    :param output_names: output names of the final node
    :return: the node producing the L2 norm
    """
    if dtype == numpy.float32:
        OnnxReduceL2 = loadop('ReduceL2')
        # NOTE(review): axes is given as an attribute here; newer ReduceL2
        # opsets take axes as an input - confirm against op_version.
        return OnnxReduceL2(
            x, axes=axes, keepdims=keepdims,
            op_version=op_version, output_names=output_names)
    OnnxMul, OnnxSqrt = loadop('Mul', 'Sqrt')
    x2 = OnnxMul(x, x, op_version=op_version)
    # Forward the caller's axes/keepdims (previously hard-coded to [1] and 1)
    # so this branch matches the float32 one.
    red = OnnxReduceSumApi11(
        x2, axes=axes, keepdims=keepdims, op_version=op_version)
    return OnnxSqrt(
        red, op_version=op_version, output_names=output_names)

144 

145 

def OnnxReshapeApi13(*x, allowzero=0, op_version=None,
                     output_names=None):
    """
    Adds operator Reshape with opset>=14 following API from opset 13.

    *allowzero* is forwarded only for opset >= 14. For lower opsets it is
    not passed to the operator at all.

    :param x: inputs of the operator
    :param allowzero: forwarded as the ``allowzero`` attribute (opset >= 14)
    :param op_version: targeted opset, mandatory
    :param output_names: output names of the created node
    :return: the created Reshape node
    :raises RuntimeError: if *op_version* is None
    """
    if op_version is None:
        raise RuntimeError(  # pragma: no cover
            "op_version must be specified.")
    if op_version < 13:
        reshape = loadop('Reshape_5')
        return reshape(
            *x, op_version=op_version, output_names=output_names)
    if op_version < 14:
        reshape = loadop('Reshape_13')
        return reshape(
            *x, op_version=op_version, output_names=output_names)
    reshape = loadop('Reshape')
    return reshape(
        *x, allowzero=allowzero,
        op_version=op_version, output_names=output_names)