Coverage for mlprodict/onnxrt/ops_cpu/_op_numpy_helper.py: 97%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief numpy redundant functions.
4"""
5import numpy
6from scipy.sparse.coo import coo_matrix
def numpy_dot_inplace(inplaces, a, b):
    """
    Computes the dot product of *a* and *b*, delegating to the
    in-place helpers when the caller allows an operand to be reused.
    See :epkg:`numpy:dot`.

    @param      inplaces    mapping queried with ``get(0, False)`` and
                            ``get(1, False)``; a true value for key 0
                            (resp. 1) selects the helper working on *a*
                            (resp. *b*)
    @param      a           left operand
    @param      b           right operand
    @return                 the product of *a* and *b*
    """
    def _reusable(position, operand):
        # The helpers read ``operand.flags``, so objects without that
        # attribute (lists, scalars, ...) are never sent to them.
        return inplaces.get(position, False) and hasattr(operand, 'flags')

    # The left operand is considered first, then the right one.
    if _reusable(0, a):
        return _numpy_dot_inplace_left(a, b)
    if _reusable(1, b):
        return _numpy_dot_inplace_right(a, b)
    return numpy.dot(a, b)
21def _numpy_dot_inplace_left(a, b):
22 "Subpart of @see fn numpy_dot_inplace."
23 if a.flags['F_CONTIGUOUS']:
24 if len(b.shape) == len(a.shape) == 2 and b.shape[1] <= a.shape[1]:
25 try:
26 numpy.dot(a, b, out=a[:, :b.shape[1]])
27 return a[:, :b.shape[1]]
28 except ValueError:
29 return numpy.dot(a, b)
30 if len(b.shape) == 1:
31 try:
32 numpy.dot(a, b.reshape(b.shape[0], 1), out=a[:, :1])
33 return a[:, :1].reshape(a.shape[0])
34 except ValueError: # pragma no cover
35 return numpy.dot(a, b)
36 return numpy.dot(a, b)
39def _numpy_dot_inplace_right(a, b):
40 "Subpart of @see fn numpy_dot_inplace."
41 if b.flags['C_CONTIGUOUS']:
42 if len(b.shape) == len(a.shape) == 2 and a.shape[0] <= b.shape[0]:
43 try:
44 numpy.dot(a, b, out=b[:a.shape[0], :])
45 return b[:a.shape[0], :]
46 except ValueError: # pragma no cover
47 return numpy.dot(a, b)
48 if len(a.shape) == 1:
49 try:
50 numpy.dot(a, b, out=b[:1, :])
51 return b[:1, :]
52 except ValueError: # pragma no cover
53 return numpy.dot(a, b)
54 return numpy.dot(a, b)
def numpy_matmul_inplace(inplaces, a, b):
    """
    Implements a matmul product, deals with inplace information.
    See :epkg:`numpy:matmul`.
    Inplace computation does not work well as modifying one of the
    containers modifies the results. This part still needs to be
    improved.

    @param      inplaces    in-place information, forwarded unchanged to
                            @see fn numpy_dot_inplace
    @param      a           left operand (dense array or ``coo_matrix``)
    @param      b           right operand (dense array or ``coo_matrix``)
    @return                 the matrix product of *a* and *b*
    @raise      ValueError  if the product fails with a ``ValueError``,
                            the message reports both shapes
    """
    try:
        if isinstance(a, coo_matrix) or isinstance(b, coo_matrix):
            # numpy.dot is not aware of scipy sparse matrices: it treats
            # them as object scalars, so sparse x dense does not yield a
            # matrix product. The ``@`` operator dispatches to the sparse
            # implementation for either operand order.
            return a @ b
        if len(a.shape) <= 2 and len(b.shape) <= 2:
            # Up to two dimensions, matmul and dot agree, so the
            # in-place aware dot product can be used.
            return numpy_dot_inplace(inplaces, a, b)
        return numpy.matmul(a, b)
    except ValueError as e:  # pragma: no cover
        raise ValueError(
            "Unable to multiply shapes %r, %r." % (a.shape, b.shape)) from e