Coverage for mlprodict/npy/numpy_onnx_impl.py: 98%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief :epkg:`numpy` functions implemented with :epkg:`onnx`.
5.. versionadded:: 0.6
7.. versionchanged:: 0.7
8"""
9import warnings
10import numpy
11from onnx import onnx_pb as onnx_proto # pylint: disable=E1101
12from onnx.helper import make_tensor
13from .onnx_variable import OnnxVar, MultiOnnxVar as xtuple
14from .xop import loadop
15from .numpy_onnx_impl_body import if_then_else, OnnxVarGraph
def abs(x):
    """
    Element-wise absolute value, see :func:`numpy.abs`.

    :param x: input
    :return: :class:`OnnxVar` applying ONNX ``Abs`` to *x*
    """
    return OnnxVar(x, op=loadop('Abs'))
def acos(x):
    """
    Element-wise inverse cosine, see :func:`numpy.arccos`.

    :param x: input
    :return: :class:`OnnxVar` applying ONNX ``Acos`` to *x*
    """
    return OnnxVar(x, op=loadop('Acos'))
def acosh(x):
    """
    Element-wise inverse hyperbolic cosine, see :func:`numpy.arccosh`.

    :param x: input
    :return: :class:`OnnxVar` applying ONNX ``Acosh`` to *x*
    """
    return OnnxVar(x, op=loadop('Acosh'))
def amax(x, axis=None, keepdims=0):
    """
    See :func:`numpy.amax`.

    :param x: input
    :param axis: None to reduce over every axis, an integer, or a
        list/tuple of integers
    :param keepdims: forwarded to the ONNX ``keepdims`` attribute
    :return: :class:`OnnxVar` applying ONNX ``ReduceMax``
    """
    OnnxReduceMax = loadop('ReduceMax')
    if axis is None:
        return OnnxVar(x, op=OnnxReduceMax, keepdims=keepdims)
    if isinstance(axis, tuple):
        # numpy accepts a tuple of axes; wrapping it in a list would
        # produce a nested list, so convert it instead.
        axis = list(axis)
    elif not isinstance(axis, list):
        axis = [axis]
    return OnnxVar(x, op=OnnxReduceMax, keepdims=keepdims, axes=axis)
def amin(x, axis=None, keepdims=0):
    """
    See :func:`numpy.amin`.

    :param x: input
    :param axis: None to reduce over every axis, an integer, or a
        list/tuple of integers
    :param keepdims: forwarded to the ONNX ``keepdims`` attribute
    :return: :class:`OnnxVar` applying ONNX ``ReduceMin``
    """
    OnnxReduceMin = loadop('ReduceMin')
    if axis is None:
        return OnnxVar(x, op=OnnxReduceMin, keepdims=keepdims)
    if isinstance(axis, tuple):
        # numpy accepts a tuple of axes; wrapping it in a list would
        # produce a nested list, so convert it instead.
        axis = list(axis)
    elif not isinstance(axis, list):
        axis = [axis]
    return OnnxVar(x, op=OnnxReduceMin, keepdims=keepdims, axes=axis)
def arange(start, stop, step=1):
    """
    See :func:`numpy.arange`, *start*, *stop* must be specified.

    The sequence is built from ONNX nodes:

    * ``ConstantOfShape`` creates *n* values equal to *step*.
    * ``CumSum`` turns them into ``step, 2*step, ..., n*step``.
    * The result is shifted by ``start - step``.

    *n* is ``ceil((stop - start) / step)``, which is the length
    :func:`numpy.arange` returns.

    :param start: first value, an integer or an :class:`OnnxVar`
    :param stop: end of the interval (excluded), an integer or an
        :class:`OnnxVar`
    :param step: integer spacing, known when the graph is built
    :return: :class:`OnnxVar`
    :raises TypeError: if *step* is not an integer
    :raises ZeroDivisionError: if *step* is zero

    NOTE(review): an empty range (e.g. *stop* < *start* with a positive
    *step*) gives a negative length, which ``ConstantOfShape`` presumably
    rejects, whereas :func:`numpy.arange` returns an empty array.
    """
    int_types = (int, numpy.int64, numpy.int32)
    if not isinstance(step, int_types):
        raise TypeError(  # pragma: no cover
            "step must be an integer not %r." % type(step))
    if step == 0:
        raise ZeroDivisionError("step must not be zero.")
    if isinstance(start, int_types):
        # When start is 0, the element count needs no subtraction.
        zero = start == 0
        start = numpy.array([start], dtype=numpy.int64)
    else:
        zero = False
    if isinstance(stop, int_types):
        stop = numpy.array([stop], dtype=numpy.int64)
    # ConstantOfShape fills its output with this value (the step).
    value = make_tensor(
        "value", onnx_proto.TensorProto.INT64, (1, ), [int(step)])  # pylint: disable=E1101

    OnnxAdd, OnnxCumSum, OnnxConstantOfShape, OnnxSub = loadop(
        'Add', 'CumSum', 'ConstantOfShape', 'Sub')

    if step == 1:
        shape = stop if zero else stop - start
        if isinstance(shape, OnnxVar):
            shape = shape.reshape(numpy.array([-1], dtype=numpy.int64))
        _cst = OnnxVar(shape, op=OnnxConstantOfShape, value=value)
        cs = OnnxVar(_cst, numpy.array([0], dtype=numpy.int64),
                     op=OnnxCumSum)
        diff = start - numpy.array([step], dtype=numpy.int64)
        return OnnxVar(cs, diff, op=OnnxAdd)

    step_arr = numpy.array([step], dtype=numpy.int64)
    # The element count is ceil(span / step). Plain floor division drops the
    # last element when step does not divide span (arange(0, 10, 3) must
    # have 4 values). Adding step - 1 (positive step) or step + 1 (negative
    # step) before dividing turns floor into ceil.
    adjust = numpy.array(
        [step - 1 if step > 0 else step + 1], dtype=numpy.int64)
    span = stop if zero else stop - start
    shape = (span + adjust) // step_arr
    if isinstance(shape, OnnxVar):
        shape = shape.reshape(numpy.array([-1], dtype=numpy.int64))
    _cst = OnnxVar(shape, op=OnnxConstantOfShape, value=value)
    cs = OnnxVar(_cst, numpy.array([0], dtype=numpy.int64),
                 op=OnnxCumSum)
    add = OnnxVar(cs, start, op=OnnxAdd)
    return OnnxVar(add, step_arr, op=OnnxSub)
def argmax(x, axis=0, keepdims=0):
    """
    See :func:`numpy.argmax`.

    .. warning::
        ONNX does not implement default value axis=None.

    :param x: input
    :param axis: axis to reduce; unlike numpy the default is 0 and
        None is rejected
    :param keepdims: forwarded to the ONNX ``keepdims`` attribute
    :return: :class:`OnnxVar` applying ONNX ``ArgMax``
    :raises NotImplementedError: if *axis* is None
    """
    if axis is not None:
        return OnnxVar(x, op=loadop('ArgMax'), axis=axis, keepdims=keepdims)
    raise NotImplementedError(  # pragma: no cover
        "ONNX does not allow axis=None.")
def argmin(x, axis=0, keepdims=0):
    """
    See :func:`numpy.argmin`.

    .. warning::
        ONNX does not implement default value axis=None.

    :param x: input
    :param axis: axis to reduce; unlike numpy the default is 0 and
        None is rejected
    :param keepdims: forwarded to the ONNX ``keepdims`` attribute
    :return: :class:`OnnxVar` applying ONNX ``ArgMin``
    :raises NotImplementedError: if *axis* is None
    """
    if axis is not None:
        return OnnxVar(x, op=loadop('ArgMin'), axis=axis, keepdims=keepdims)
    raise NotImplementedError(  # pragma: no cover
        "ONNX does not allow axis=None.")
def asin(x):
    """
    Element-wise inverse sine, see :func:`numpy.arcsin`.

    :param x: input
    :return: :class:`OnnxVar` applying ONNX ``Asin`` to *x*
    """
    return OnnxVar(x, op=loadop('Asin'))
def asinh(x):
    """
    Element-wise inverse hyperbolic sine, see :func:`numpy.arcsinh`.

    :param x: input
    :return: :class:`OnnxVar` applying ONNX ``Asinh`` to *x*
    """
    return OnnxVar(x, op=loadop('Asinh'))
def atan(x):
    """
    Element-wise inverse tangent, see :func:`numpy.arctan`.

    :param x: input
    :return: :class:`OnnxVar` applying ONNX ``Atan`` to *x*
    """
    return OnnxVar(x, op=loadop('Atan'))
def atanh(x):
    """
    Element-wise inverse hyperbolic tangent, see :func:`numpy.arctanh`.

    :param x: input
    :return: :class:`OnnxVar` applying ONNX ``Atanh`` to *x*
    """
    return OnnxVar(x, op=loadop('Atanh'))
def ceil(x):
    """
    Element-wise ceiling, see :func:`numpy.ceil`.

    :param x: input
    :return: :class:`OnnxVar` applying ONNX ``Ceil`` to *x*
    """
    return OnnxVar(x, op=loadop('Ceil'))
def clip(x, a_min=None, a_max=None):
    """
    See :func:`numpy.clip`.

    :param x: input
    :param a_min: lower bound or None
    :param a_max: upper bound or None
    :return: :class:`OnnxVar`
    """
    if a_min is None and a_max is not None:
        # ONNX Clip binds its optional inputs by position: a second input
        # is always the lower bound, so Clip(x, a_max) would clip from
        # below. An upper bound alone is an element-wise Min.
        OnnxMin = loadop('Min')
        return OnnxVar(x, a_max, op=OnnxMin)
    args = [x]
    if a_min is not None:
        args.append(a_min)
    if a_max is not None:
        args.append(a_max)
    OnnxClip = loadop('Clip')
    return OnnxVar(*args, op=OnnxClip)
def compress(condition, x, axis=None):
    """
    See :func:`numpy.compress`.

    Arguments follow numpy's order (*condition* first). ONNX ``Compress``
    takes the data first, so the node receives ``(x, condition)``.

    :param condition: boolean selector
    :param x: data to select from
    :param axis: axis along which to select; omitted when None
    :return: :class:`OnnxVar`
    """
    OnnxCompress = loadop('Compress')
    attributes = {} if axis is None else {'axis': axis}
    return OnnxVar(x, condition, op=OnnxCompress, **attributes)
def cos(x):
    """
    Element-wise cosine, see :func:`numpy.cos`.

    :param x: input
    :return: :class:`OnnxVar` applying ONNX ``Cos`` to *x*
    """
    return OnnxVar(x, op=loadop('Cos'))
def cosh(x):
    """
    Element-wise hyperbolic cosine, see :func:`numpy.cosh`.

    :param x: input
    :return: :class:`OnnxVar` applying ONNX ``Cosh`` to *x*
    """
    return OnnxVar(x, op=loadop('Cosh'))
def concat(*x, axis=0):
    """
    Operator concat, handle :func:`numpy.vstack` and
    :func:`numpy.hstack`.

    :param x: at least two inputs to concatenate
    :param axis: concatenation axis
    :return: :class:`OnnxVar` applying ONNX ``Concat``
    :raises RuntimeError: if fewer than two inputs are given
    """
    OnnxConcat = loadop('Concat')
    count = len(x)
    if count > 1:
        return OnnxVar(*x, op=OnnxConcat, axis=axis)
    raise RuntimeError(  # pragma: no cover
        "N=%d<=1 elements to concatenate." % count)
def cumsum(x, axis):
    """
    See :func:`numpy.cumsum`.

    :param x: input
    :param axis: an integer or an int64 tensor (array or :class:`OnnxVar`).
        ONNX ``CumSum`` takes the axis as an input, so an integer is
        wrapped into a one-element int64 array.
    :return: :class:`OnnxVar`
    """
    OnnxCumSum = loadop('CumSum')
    if isinstance(axis, (int, numpy.integer)):
        axis = numpy.array([axis], dtype=numpy.int64)
    return OnnxVar(x, axis, op=OnnxCumSum)
def cst(x, dtype=None):
    """
    Creates a constant. `log(x) + numpy.float32(1)` works
    but `numpy.float32(32) + log(x)` fails because Python
    calls `numpy.float32.__add__` instead of
    `OnnxVar.__add__`. With this function, expression
    `cst(1.) + log(x)` is valid. Parameter `dtype` is
    used to overwrite the default dtype (`numpy.float32`
    for floats and `numpy.int64` for ints).

    :param x: Python float or int, :class:`numpy.ndarray`, or any other
        object exposing a ``dtype`` attribute (e.g. a numpy scalar)
    :param dtype: optional dtype. Floats and ints are created with it,
        and an ndarray is cast to it. It must be None for other objects
        that carry their own dtype.
    :return: :class:`OnnxVar` (``Identity`` of the constant)
    :raises RuntimeError: if *dtype* is given for a non-array object
        exposing ``dtype``
    :raises NotImplementedError: if *x* cannot be converted
    """
    OnnxIdentity = loadop('Identity')
    # Test dtype against None explicitly: `dtype or default` would depend
    # on how the dtype object evaluates as a boolean.
    if isinstance(x, float):
        return OnnxVar(
            numpy.array([x], dtype=numpy.float32 if dtype is None else dtype),
            op=OnnxIdentity)
    if isinstance(x, int):
        return OnnxVar(
            numpy.array([x], dtype=numpy.int64 if dtype is None else dtype),
            op=OnnxIdentity)
    if isinstance(x, numpy.ndarray):
        if dtype is not None:
            # Honour the requested dtype instead of silently ignoring it.
            x = x.astype(dtype)
        return OnnxVar(x, op=OnnxIdentity)
    if hasattr(x, 'dtype'):
        if dtype is not None:
            raise RuntimeError(  # pragma: no cover
                "dtype is not used because x is of type %r." % type(x))
        return OnnxVar(numpy.array([x], dtype=x.dtype),
                       op=OnnxIdentity)
    raise NotImplementedError(  # pragma: no cover
        "Unable to convert type %r into a constant." % type(x))
def det(x):
    """
    Determinant, see :func:`numpy.linalg.det`.

    :param x: input matrix (or batch of matrices)
    :return: :class:`OnnxVar` applying ONNX ``Det`` to *x*
    """
    return OnnxVar(x, op=loadop('Det'))
def dot(a, b):
    """
    See :func:`numpy.dot`.

    Implemented with ONNX ``MatMul``, so it follows :func:`numpy.matmul`
    semantics. A warning is emitted on every call because
    :func:`numpy.dot` and :func:`numpy.matmul` disagree as soon as an
    input has more than two dimensions.

    :param a: left operand
    :param b: right operand
    :return: :class:`OnnxVar`
    """
    warnings.warn(
        "npnx.dot is equivalent to npnx.matmul == numpy.matmul "
        "!= numpy.dot with arrays with more than 2 dimensions.",
        stacklevel=2)
    OnnxMatMul = loadop('MatMul')
    return OnnxVar(a, b, op=OnnxMatMul)
def matmul(a, b):
    """
    Matrix product, see :func:`numpy.matmul`.

    :param a: left operand
    :param b: right operand
    :return: :class:`OnnxVar` applying ONNX ``MatMul``
    """
    return OnnxVar(a, b, op=loadop('MatMul'))
def einsum(*x, equation=None):
    """
    See :func:`numpy.einsum`.

    :param x: operands
    :param equation: einsum equation, forwarded as the ONNX ``equation``
        attribute
    :return: :class:`OnnxVar` applying ONNX ``Einsum``
    """
    return OnnxVar(*x, op=loadop('Einsum'), equation=equation)
def erf(x):
    """
    Element-wise error function, see :epkg:`scipy:special:erf`.

    :param x: input
    :return: :class:`OnnxVar` applying ONNX ``Erf`` to *x*
    """
    return OnnxVar(x, op=loadop('Erf'))
def exp(x):
    """
    Element-wise exponential, see :func:`numpy.exp`.

    :param x: input
    :return: :class:`OnnxVar` applying ONNX ``Exp`` to *x*
    """
    return OnnxVar(x, op=loadop('Exp'))
def expand_dims(x, axis):
    """
    See :func:`numpy.expand_dims`.

    :param x: input
    :param axis: an integer, or a non-empty tuple/list of integers
        (Python or numpy integers)
    :return: :class:`OnnxVar` applying ONNX ``Unsqueeze``
    :raises NotImplementedError: for any other kind of *axis*
    """
    axes = list(axis) if isinstance(axis, (tuple, list)) else [axis]
    if not axes or not all(isinstance(a, (int, numpy.integer)) for a in axes):
        raise NotImplementedError(  # pragma: no cover
            "This function only allows integer for axis not %r." % type(axis))
    OnnxUnsqueeze = loadop('Unsqueeze')
    # Unsqueeze takes its axes as a 1-D int64 input tensor.
    return OnnxVar(x, numpy.array(axes, dtype=numpy.int64),
                   op=OnnxUnsqueeze)
def expit(x):
    """
    Logistic sigmoid, see :epkg:`scipy:special:expit`.

    :param x: input
    :return: :class:`OnnxVar` applying ONNX ``Sigmoid`` to *x*
    """
    return OnnxVar(x, op=loadop('Sigmoid'))
def floor(x):
    """
    Element-wise floor, see :func:`numpy.floor`.

    :param x: input
    :return: :class:`OnnxVar` applying ONNX ``Floor`` to *x*
    """
    return OnnxVar(x, op=loadop('Floor'))
def hstack(*x):
    """
    See :func:`numpy.hstack`.

    Concatenates along the last axis (``axis=-1``). This matches
    :func:`numpy.hstack` for 1-D and 2-D inputs. For 3-D and higher,
    numpy uses axis 1 instead.

    :param x: at least two inputs
    :return: :class:`OnnxVar` applying ONNX ``Concat``
    :raises RuntimeError: if fewer than two inputs are given
    """
    n_inputs = len(x)
    if n_inputs <= 1:
        raise RuntimeError(  # pragma: no cover
            "N=%d<=1 elements to concatenate." % n_inputs)
    return OnnxVar(*x, op=loadop('Concat'), axis=-1)
def isnan(x):
    """
    Element-wise NaN test, see :func:`numpy.isnan`.

    :param x: input
    :return: :class:`OnnxVar` applying ONNX ``IsNaN`` to *x*
    """
    return OnnxVar(x, op=loadop('IsNaN'))
def identity(x):
    """
    Identity.

    :param x: input
    :return: :class:`OnnxVar` applying ONNX ``Identity`` to *x*
    """
    return OnnxVar(x, op=loadop('Identity'))
def log(x):
    """
    Element-wise natural logarithm, see :func:`numpy.log`.

    :param x: input
    :return: :class:`OnnxVar` applying ONNX ``Log`` to *x*
    """
    return OnnxVar(x, op=loadop('Log'))
def log1p(x):
    """
    See :func:`numpy.log1p`.

    Computed literally as ``Log(x + 1)``. Unlike :func:`numpy.log1p`, it
    therefore has no special treatment for very small *x*.

    :param x: input; must expose a ``dtype`` attribute, used to build
        the constant 1
    :return: :class:`OnnxVar`
    """
    OnnxLog, OnnxAdd = loadop('Log', 'Add')
    one = numpy.array([1], dtype=x.dtype)
    return OnnxVar(OnnxVar(x, one, op=OnnxAdd), op=OnnxLog)
def mean(x, axis=None, keepdims=0):
    """
    See :func:`numpy.mean`.

    :param x: input
    :param axis: None to reduce over every axis, an integer, or a
        list/tuple of integers
    :param keepdims: forwarded to the ONNX ``keepdims`` attribute
    :return: :class:`OnnxVar` applying ONNX ``ReduceMean``
    """
    OnnxReduceMean = loadop('ReduceMean')
    if axis is None:
        return OnnxVar(x, op=OnnxReduceMean, keepdims=keepdims)
    if isinstance(axis, tuple):
        # numpy accepts a tuple of axes; wrapping it in a list would
        # produce a nested list, so convert it instead.
        axis = list(axis)
    elif not isinstance(axis, list):
        axis = [axis]
    return OnnxVar(x, op=OnnxReduceMean, keepdims=keepdims, axes=axis)
def onnx_if(condition, then_branch, else_branch):
    """
    Implements a test with onnx syntax.

    :param condition: condition (@see cl OnnxVar)
    :param then_branch: then branch, of type @see cl if_then_else;
        a :class:`numpy.ndarray` is wrapped into one
    :param else_branch: else branch, of type @see cl if_then_else;
        a :class:`numpy.ndarray` is wrapped into one
    :return: result (@see cl OnnxVarGraph wrapping an ONNX ``If``)
    :raises TypeError: if a branch is neither an array nor an
        @see cl if_then_else
    """
    OnnxIf = loadop('If')
    if isinstance(then_branch, numpy.ndarray):
        then_branch = if_then_else(then_branch)
    if not isinstance(then_branch, if_then_else):
        raise TypeError(
            "Parameter then_branch is not of type "
            "'if_then_else' but %r." % type(then_branch))
    if isinstance(else_branch, numpy.ndarray):
        else_branch = if_then_else(else_branch)
    if not isinstance(else_branch, if_then_else):
        raise TypeError(
            "Parameter else_branch is not of type "
            "'if_then_else' but %r." % type(else_branch))
    return OnnxVarGraph(
        condition, then_branch=then_branch,
        else_branch=else_branch, op=OnnxIf)
def pad(x, pads, constant_value=None, mode='constant'):
    """
    It does not implement :func:`numpy.pad` but the ONNX version
    :func:`onnx_pad <mlprodict.onnxrt.ops_cpu.op_pad.onnx_pad>`.

    :param x: input
    :param pads: pads input of ONNX ``Pad``
    :param constant_value: optional constant input; left out when None
    :param mode: forwarded to the ONNX ``mode`` attribute
    :return: :class:`OnnxVar`
    """
    OnnxPad = loadop('Pad')
    inputs = (x, pads) if constant_value is None else (x, pads, constant_value)
    return OnnxVar(*inputs, op=OnnxPad, mode=mode)
def prod(x, axis=None, keepdims=0):
    """
    See :func:`numpy.prod`.

    :param x: input
    :param axis: None to reduce over every axis, an integer, or a
        list/tuple of integers
    :param keepdims: forwarded to the ONNX ``keepdims`` attribute
    :return: :class:`OnnxVar` applying ONNX ``ReduceProd``
    """
    OnnxReduceProd = loadop('ReduceProd')
    if axis is None:
        return OnnxVar(x, op=OnnxReduceProd, keepdims=keepdims)
    if isinstance(axis, tuple):
        # numpy accepts a tuple of axes; wrapping it in a list would
        # produce a nested list, so convert it instead.
        axis = list(axis)
    elif not isinstance(axis, list):
        axis = [axis]
    return OnnxVar(x, op=OnnxReduceProd, keepdims=keepdims, axes=axis)
def relu(x):
    """
    Rectified linear unit.

    :param x: input
    :return: :class:`OnnxVar` applying ONNX ``Relu`` to *x*
    """
    return OnnxVar(x, op=loadop('Relu'))
def reciprocal(x):
    """
    Element-wise ``1 / x``, see :func:`numpy.reciprocal`.

    :param x: input
    :return: :class:`OnnxVar` applying ONNX ``Reciprocal`` to *x*
    """
    return OnnxVar(x, op=loadop('Reciprocal'))
def round(x):
    """
    Element-wise rounding, see :func:`numpy.round`
    (the ``decimals`` parameter is not supported).

    :param x: input
    :return: :class:`OnnxVar` applying ONNX ``Round`` to *x*
    """
    return OnnxVar(x, op=loadop('Round'))
def sigmoid(x):
    """
    Logistic sigmoid, see :epkg:`scipy:special:expit`.

    :param x: input
    :return: :class:`OnnxVar` applying ONNX ``Sigmoid`` to *x*
    """
    return OnnxVar(x, op=loadop('Sigmoid'))
def sign(x):
    """
    Element-wise sign, see :func:`numpy.sign`.

    :param x: input
    :return: :class:`OnnxVar` applying ONNX ``Sign`` to *x*
    """
    return OnnxVar(x, op=loadop('Sign'))
def sin(x):
    """
    Element-wise sine, see :func:`numpy.sin`.

    :param x: input
    :return: :class:`OnnxVar` applying ONNX ``Sin`` to *x*
    """
    return OnnxVar(x, op=loadop('Sin'))
def sinh(x):
    """
    Element-wise hyperbolic sine, see :func:`numpy.sinh`.

    :param x: input
    :return: :class:`OnnxVar` applying ONNX ``Sinh`` to *x*
    """
    return OnnxVar(x, op=loadop('Sinh'))
def sqrt(x):
    """
    Element-wise square root, see :func:`numpy.sqrt`.

    :param x: input
    :return: :class:`OnnxVar` applying ONNX ``Sqrt`` to *x*
    """
    return OnnxVar(x, op=loadop('Sqrt'))
def squeeze(x, axis=None):
    """
    See :func:`numpy.squeeze`.

    :param x: input
    :param axis: axes to remove. Accepted forms are an integer, a
        list/tuple of integers, or an int64 tensor (array or
        :class:`OnnxVar`). ONNX ``Squeeze`` receives them as an input.
    :return: :class:`OnnxVar`
    :raises NotImplementedError: if *axis* is None
    """
    OnnxSqueeze = loadop('Squeeze')
    if axis is None:
        raise NotImplementedError(  # pragma: no cover
            "The case where all empty dimensions are removed is not "
            "implemented.")
    if isinstance(axis, (int, numpy.integer, list, tuple)):
        # Integers and sequences of integers become a 1-D int64 tensor.
        axis = numpy.array(axis, dtype=numpy.int64).reshape((-1, ))
    return OnnxVar(x, axis, op=OnnxSqueeze)
def sum(x, axis=None, keepdims=0):
    """
    See :func:`numpy.sum`.

    :param x: input
    :param axis: None to reduce over every axis, an integer, or a
        list/tuple of integers
    :param keepdims: forwarded to the ONNX ``keepdims`` attribute
    :return: :class:`OnnxVar` applying ONNX ``ReduceSum``
    """
    OnnxReduceSum = loadop('ReduceSum')
    if axis is None:
        return OnnxVar(x, op=OnnxReduceSum, keepdims=keepdims)
    # ReduceSum receives the axes as a 1-D int64 input. reshape flattens
    # both a scalar axis and a list/tuple of axes into that form;
    # numpy.array([axis]) would build a 2-D tensor for a sequence.
    axes = numpy.array(axis, dtype=numpy.int64).reshape((-1, ))
    return OnnxVar(x, axes, op=OnnxReduceSum, keepdims=keepdims)
def tan(x):
    """
    Element-wise tangent, see :func:`numpy.tan`.

    :param x: input
    :return: :class:`OnnxVar` applying ONNX ``Tan`` to *x*
    """
    return OnnxVar(x, op=loadop('Tan'))
def tanh(x):
    """
    Element-wise hyperbolic tangent, see :func:`numpy.tanh`.

    :param x: input
    :return: :class:`OnnxVar` applying ONNX ``Tanh`` to *x*
    """
    return OnnxVar(x, op=loadop('Tanh'))
def topk(x, k, axis=-1, largest=1, sorted=1):
    """
    ONNX ``TopK``; the closest numpy function is :func:`numpy.argsort`.

    :param x: input
    :param k: number of elements to keep, passed as the second input
    :param axis: forwarded to the ONNX ``axis`` attribute
    :param largest: forwarded to the ONNX ``largest`` attribute
    :param sorted: forwarded to the ONNX ``sorted`` attribute
    :return: multiple outputs (:class:`MultiOnnxVar`), presumably values
        and indices as in the ONNX ``TopK`` outputs (to confirm)
    """
    OnnxTopK = loadop('TopK')
    attributes = dict(axis=axis, largest=largest, sorted=sorted)
    return xtuple(x, k, op=OnnxTopK, **attributes)
def transpose(x, perm=(1, 0)):
    """
    See :func:`numpy.transpose`.

    :param x: input
    :param perm: axes permutation; unlike numpy the default ``(1, 0)``
        only fits 2-D inputs
    :return: :class:`OnnxVar` applying ONNX ``Transpose``
    """
    return OnnxVar(x, op=loadop('Transpose'), perm=[*perm])
def unsqueeze(x, axes):
    """
    See :func:`numpy.expand_dims`.

    :param x: input
    :param axes: an integer (wrapped into a one-element int64 array) or
        an axes tensor passed as is
    :return: :class:`OnnxVar` applying ONNX ``Unsqueeze``
    """
    OnnxUnsqueeze = loadop('Unsqueeze')
    axes_input = (numpy.array([axes], dtype=numpy.int64)
                  if isinstance(axes, int) else axes)
    return OnnxVar(x, axes_input, op=OnnxUnsqueeze)
def vstack(*x):
    """
    See :func:`numpy.vstack`.

    Concatenates along axis 0. Unlike :func:`numpy.vstack`, 1-D inputs
    are not first reshaped into rows.

    :param x: at least two inputs
    :return: :class:`OnnxVar` applying ONNX ``Concat``
    :raises RuntimeError: if fewer than two inputs are given
    """
    OnnxConcat = loadop('Concat')
    n_inputs = len(x)
    if n_inputs > 1:
        return OnnxVar(*x, op=OnnxConcat, axis=0)
    raise RuntimeError(  # pragma: no cover
        "N=%d<=1 elements to concatenate." % n_inputs)
def where(cond, x, y):
    """
    Element-wise selection, see :func:`numpy.where`.

    :param cond: boolean condition
    :param x: values taken where *cond* is true
    :param y: values taken where *cond* is false
    :return: :class:`OnnxVar` applying ONNX ``Where``
    """
    return OnnxVar(cond, x, y, op=loadop('Where'))