Coverage for mlprodict/npy/xop_opset.py: 100%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# pylint: disable=E0602
2"""
3@file
4@brief Xop API to build onnx graphs. Inspired from :epkg:`sklearn-onnx`.
6.. versionadded:: 0.9
7"""
8import numpy
9from .xop import loadop
def OnnxReduceSumApi11(*x, axes=None, keepdims=1, op_version=None,
                       output_names=None):
    """
    Adds operator ReduceSum with opset>=13 following API from opset 12.

    Callers always give *axes* as a list of integers; for opset >= 13 the
    list is converted into an int64 input tensor, for older opsets it is
    passed as the ``axes`` attribute.

    :param x: node inputs
    :param axes: axes to reduce, None to leave them unspecified
    :param keepdims: forwarded as the ``keepdims`` attribute
    :param op_version: target opset, mandatory
    :param output_names: forwarded to the node
    :return: the ReduceSum node
    :raises RuntimeError: if *op_version* is None
    """
    if op_version is None:
        raise RuntimeError(  # pragma: no cover
            "op_version must be specified.")
    shared = dict(keepdims=keepdims, op_version=op_version,
                  output_names=output_names)
    if op_version >= 13:
        node_cls = loadop('ReduceSum')
        if axes is None:
            return node_cls(*x, **shared)
        # opset 13 expects axes as an additional int64 input
        return node_cls(*x, numpy.array(axes, dtype=numpy.int64), **shared)
    node_cls = loadop('ReduceSum_11' if op_version >= 11 else 'ReduceSum_1')
    if axes is None:
        return node_cls(*x, **shared)
    return node_cls(*x, axes=axes, **shared)
def OnnxSplitApi11(*x, axis=0, split=None, op_version=None,
                   output_names=None):
    """
    Adds operator Split with opset>=13 following API from opset 11.

    *split* is given as a list of integers; for opset >= 13 it becomes an
    int64 input tensor, for older opsets the ``split`` attribute.

    :param x: node inputs
    :param axis: forwarded as the ``axis`` attribute
    :param split: sizes of the outputs, None to leave them unspecified
    :param op_version: target opset, mandatory
    :param output_names: forwarded to the node
    :return: the Split node
    :raises RuntimeError: if *op_version* is None
    """
    if op_version is None:
        raise RuntimeError(  # pragma: no cover
            "op_version must be specified.")
    if op_version >= 13:
        node_cls = loadop('Split')
        if split is not None:
            # opset 13 expects split as an additional int64 input
            return node_cls(
                *x, numpy.array(split, dtype=numpy.int64), axis=axis,
                op_version=op_version, output_names=output_names)
        return node_cls(*x, axis=axis, op_version=op_version,
                        output_names=output_names)
    node_cls = loadop('Split_11' if op_version >= 11 else 'Split_2')
    if split is not None:
        return node_cls(*x, split=split, axis=axis, op_version=op_version,
                        output_names=output_names)
    return node_cls(*x, axis=axis, op_version=op_version,
                    output_names=output_names)
def OnnxSqueezeApi11(*x, axes=None, op_version=None,
                     output_names=None):
    """
    Adds operator Squeeze with opset>=13 following API from opset 11.

    *axes* is given as a list of integers; for opset >= 13 it becomes an
    optional int64 input tensor, for older opsets the ``axes`` attribute.

    :param x: node inputs
    :param axes: axes to squeeze, None to leave them unspecified
    :param op_version: target opset, mandatory
    :param output_names: forwarded to the node
    :return: the Squeeze node
    :raises RuntimeError: if *op_version* is None
    """
    if op_version is None:
        raise RuntimeError(  # pragma: no cover
            "op_version must be specified.")
    if op_version >= 13:
        OnnxSqueeze = loadop('Squeeze')
        if axes is None:
            # axes is an optional input since opset 13; building
            # numpy.array(None, dtype=int64) would raise, so omit it.
            return OnnxSqueeze(
                *x, op_version=op_version, output_names=output_names)
        return OnnxSqueeze(
            *x, numpy.array(axes, dtype=numpy.int64),
            op_version=op_version, output_names=output_names)
    if op_version >= 11:
        OnnxSqueeze_11 = loadop('Squeeze_11')
        return OnnxSqueeze_11(
            *x, axes=axes, op_version=op_version,
            output_names=output_names)
    OnnxSqueeze_1 = loadop('Squeeze_1')
    return OnnxSqueeze_1(*x, axes=axes,
                         op_version=op_version, output_names=output_names)
def OnnxUnsqueezeApi11(*x, axes=None, op_version=None,
                       output_names=None):
    """
    Adds operator Unsqueeze with opset>=13 following API from opset 11.

    *axes* is given as a list of integers; for opset >= 13 it becomes an
    int64 input tensor, for older opsets the ``axes`` attribute.

    :param x: node inputs
    :param axes: axes to insert
    :param op_version: target opset, mandatory
    :param output_names: forwarded to the node
    :return: the Unsqueeze node
    :raises RuntimeError: if *op_version* is None
    """
    if op_version is None:
        raise RuntimeError(  # pragma: no cover
            "op_version must be specified.")
    if op_version >= 13:
        node_cls = loadop('Unsqueeze')
        return node_cls(*x, numpy.array(axes, dtype=numpy.int64),
                        op_version=op_version, output_names=output_names)
    node_cls = loadop('Unsqueeze_11' if op_version >= 11 else 'Unsqueeze_1')
    return node_cls(*x, axes=axes, op_version=op_version,
                    output_names=output_names)
def OnnxReduceL2_typed(dtype, x, axes=None, keepdims=1, op_version=None,
                       output_names=None):
    """
    Adds operator ReduceL2 for float or double.

    For ``numpy.float32`` a ReduceL2 node is emitted directly. For any other
    dtype the norm is built as ``Sqrt(ReduceSum(x * x))``.

    :param dtype: numpy dtype of *x*
    :param x: input
    :param axes: axes to reduce, forwarded to the reduction in both branches
    :param keepdims: forwarded to the reduction in both branches
    :param op_version: target opset
    :param output_names: forwarded to the final node
    :return: the last node of the subgraph
    """
    if dtype == numpy.float32:
        OnnxReduceL2 = loadop('ReduceL2')
        return OnnxReduceL2(
            x, axes=axes, keepdims=keepdims,
            op_version=op_version, output_names=output_names)
    OnnxMul, OnnxSqrt = loadop('Mul', 'Sqrt')
    x2 = OnnxMul(x, x, op_version=op_version)
    # Forward the caller's axes/keepdims (previously hard-coded to
    # axes=[1], keepdims=1) so this branch matches the ReduceL2 one.
    red = OnnxReduceSumApi11(
        x2, axes=axes, keepdims=keepdims, op_version=op_version)
    return OnnxSqrt(
        red, op_version=op_version, output_names=output_names)
def OnnxReshapeApi13(*x, allowzero=0, op_version=None,
                     output_names=None):
    """
    Adds operator Reshape with opset>=14 following API from opset 13.

    The ``allowzero`` attribute is only passed for opset >= 14; for older
    opsets it is not forwarded.

    :param x: node inputs (data and shape)
    :param allowzero: forwarded as ``allowzero`` when opset >= 14
    :param op_version: target opset, mandatory
    :param output_names: forwarded to the node
    :return: the Reshape node
    :raises RuntimeError: if *op_version* is None
    """
    if op_version is None:
        raise RuntimeError(  # pragma: no cover
            "op_version must be specified.")
    if op_version >= 14:
        node_cls = loadop('Reshape')
        return node_cls(*x, allowzero=allowzero, op_version=op_version,
                        output_names=output_names)
    # older versions take no allowzero attribute
    node_cls = loadop('Reshape_13' if op_version >= 13 else 'Reshape_5')
    return node_cls(*x, op_version=op_version, output_names=output_names)