Coverage for mlprodict/onnxrt/ops_cpu/op_topk.py: 86%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from onnx.defs import onnx_opset_version
9from ._op import OpRun
10from ._op_onnx_numpy import ( # pylint: disable=E0611,E0401
11 topk_element_min_double, topk_element_max_double, topk_element_fetch_double,
12 topk_element_min_float, topk_element_max_float, topk_element_fetch_float,
13 topk_element_min_int64, topk_element_max_int64, topk_element_fetch_int64)
def topk_sorted_implementation(X, k, axis, largest):
    """
    Retrieves the top-k elements.

    @param      X           data
    @param      k           k in top-k, an integer or an array holding
                            exactly one integer (0-d arrays included)
    @param      axis        axis chosen to select the top-k elements
    @param      largest     largest (1) or smallest (0)
    @return                 top-k values, top-k indices

    A :epkg:`RuntimeError` is raised if *k* is an array
    with more than one element.

    See function `_kneighbors_reduce_func
    <https://github.com/scikit-learn/scikit-learn/tree/master/
    sklearn/neighbors/base.py#L304>`_.
    """
    if isinstance(k, numpy.ndarray):
        if k.size != 1:
            raise RuntimeError(  # pragma: no cover
                "k must be an integer not %r." % k)
        # ``k[0]`` raises IndexError on a 0-d array,
        # ``item`` extracts the scalar whatever the shape is
        k = k.item()

    if len(X.shape) == 2 and axis == 1:
        # fast path for matrices: partial sort with argpartition
        rows = numpy.arange(X.shape[0])[:, None]
        # sorting -X in increasing order is sorting X in decreasing order
        keys = X if largest == 0 else -X
        candidates = numpy.argpartition(keys, axis=axis, kth=k - 1)[:, :k]
        # argpartition doesn't guarantee sorted order, so we sort again
        order = numpy.argsort(keys[rows, candidates])
        sorted_indices = candidates[rows, order]
        sorted_distances = X[rows, sorted_indices]
        return sorted_distances, sorted_indices

    # general case: full sort along the axis, then the k first are kept
    sorted_indices = numpy.argsort(X, axis=axis)
    sorted_values = numpy.sort(X, axis=axis)
    if largest:
        sorted_indices = numpy.flip(sorted_indices, axis=axis)
        sorted_values = numpy.flip(sorted_values, axis=axis)
    ark = numpy.arange(k)
    topk_sorted_indices = numpy.take(sorted_indices, ark, axis=axis)
    topk_sorted_values = numpy.take(sorted_values, ark, axis=axis)
    return topk_sorted_values, topk_sorted_indices
def topk_sorted_implementation_cpp(X, k, axis, largest, th_para=50):
    """
    Retrieves the top-k elements using a C++
    implementation when the axis is the last dimension,
    otherwise, it falls back to
    @see fn topk_sorted_implementation.

    @param      X           data
    @param      k           k in top-k, an integer or an array holding
                            exactly one integer
    @param      axis        axis chosen to select the top-k elements
    @param      largest     largest (1) or smallest (0)
    @param      th_para     threshold for parallelisation
    @return                 top-k values, top-k indices

    The C++ functions are only used for types *float64*, *float32*,
    *int64*, any other type goes through
    @see fn topk_sorted_implementation.
    If *k* is null, the function always returns a pair
    of empty arrays (values, indices).
    """
    if isinstance(k, numpy.ndarray):
        if k.size != 1:
            raise RuntimeError(  # pragma: no cover
                "k must be an integer not %r." % k)
        # a plain scalar makes the comparison ``k == 0`` unambiguous
        k = k.item()

    if k == 0:
        # Checked once for every path: the caller unpacks two results,
        # returning a single array would make it fail.
        return (numpy.empty((0,), dtype=X.dtype),
                numpy.empty((0,), dtype=numpy.int64))

    if axis != len(X.shape) - 1:
        # the C++ implementation only handles the last axis
        return topk_sorted_implementation(X, k, axis, largest)

    if X.dtype == numpy.float64:
        fct_min, fct_max, fct_fetch = (
            topk_element_min_double, topk_element_max_double,
            topk_element_fetch_double)
    elif X.dtype == numpy.float32:
        fct_min, fct_max, fct_fetch = (
            topk_element_min_float, topk_element_max_float,
            topk_element_fetch_float)
    elif X.dtype == numpy.int64:
        fct_min, fct_max, fct_fetch = (
            topk_element_min_int64, topk_element_max_int64,
            topk_element_fetch_int64)
    else:
        # no C++ implementation for this type
        return topk_sorted_implementation(X, k, axis, largest)

    fct_select = fct_max if largest else fct_min
    topk_sorted_indices = fct_select(X, k, True, th_para)
    topk_sorted_values = fct_fetch(X, topk_sorted_indices)
    return topk_sorted_values, topk_sorted_indices
class _CommonTopK(OpRun):
    """
    Base class shared by every version of operator *TopK*.
    It stores ``th_para``, the threshold given to
    @see fn topk_sorted_implementation_cpp to decide
    when the parallelisation starts.
    """

    atts = {'axis': -1}

    def __init__(self, *args, **options):
        OpRun.__init__(self, *args, **options)
        self.th_para = 50

    def _common_run(self, data, ink, largest=1):  # pylint: disable=W0221
        """
        Runtime for operator *TopK*.
        A negative axis is converted into a positive one,
        the computation is done by
        @see fn topk_sorted_implementation_cpp
        and the indices are converted into *int64*.

        @param      data        data
        @param      ink         sequence, its first element is *k*
        @param      largest     largest (1) or smallest (0)
        @return                 top-k values, top-k indices

        .. warning::
            ONNX specifications may be imprecise in case of negative value
            for axis. The implementation follows what :epkg:`onnxruntime`
            does in `top_k.cc
            <https://github.com/Microsoft/onnxruntime/blob/master/onnxruntime/core/providers/cpu/math/top_k.cc#L63>`_.
        """
        axis = self.axis
        if axis < 0:
            # negative axis counts from the last dimension
            axis += len(data.shape)
        values, indices = topk_sorted_implementation_cpp(
            data, ink[0], axis, largest, self.th_para)
        return (values, indices.astype(numpy.int64))

    def _infer_shapes(self, data, ink):  # pylint: disable=W0221
        axis = self.axis
        if axis < 0:
            axis += len(data)
        shape_values = data.copy()
        # the dimension along the axis becomes a name unique to this node
        suffix = hex(id(self))[2:]
        shape_values[axis] = "ntopk%s" % suffix
        shape_indices = shape_values.copy(dtype=numpy.int64)
        return (shape_values, shape_indices)

    def _infer_types(self, x, ink):  # pylint: disable=E0202,W0221
        # values keep the input type, indices are int64
        return (x, numpy.int64)
class TopK_1(_CommonTopK):
    """
    Runtime for operator *TopK* when *k* is an attribute
    of the node and not an input.
    """

    atts = {'axis': -1, 'k': None}

    def __init__(self, onnx_node, desc=None, **options):
        # The expected attributes are the ones of this class:
        # ``TopK_10.atts`` does not list ``k`` which this version reads.
        _CommonTopK.__init__(self, onnx_node, desc=desc,
                             expected_attributes=TopK_1.atts,
                             **options)

    def _run(self, data):  # pylint: disable=W0221
        """
        Runtime for operator *TopK*, *k* is taken from attribute ``k``.

        @param      data        data
        @return                 top-k values, top-k indices

        .. warning::
            ONNX specifications may be imprecise in case of negative value
            for axis. The implementation follows what :epkg:`onnxruntime`
            does in `top_k.cc
            <https://github.com/Microsoft/onnxruntime/blob/master/onnxruntime/core/providers/cpu/math/top_k.cc#L63>`_.
        """
        return _CommonTopK._common_run(self, data, [self.k])

    def _infer_shapes(self, data):  # pylint: disable=W0221
        return _CommonTopK._infer_shapes(self, data, [self.k])

    def _infer_types(self, data):  # pylint: disable=W0221
        # The operator has two outputs (values, indices) as
        # _run and _infer_shapes show, one type per output is returned.
        return _CommonTopK._infer_types(self, data, [self.k])

    def _infer_sizes(self, *args):  # pylint: disable=W0221
        # runs the operator, then puts in front of its results a dictionary
        # whose key ``temp`` is ``itemsize * k * 2``
        res = self.run(*args)
        x = args[0]
        return (dict(temp=x.dtype.itemsize * self.k * 2), ) + res
class TopK_10(_CommonTopK):
    """
    Runtime for operator *TopK* when *k* is the second input
    and the only attribute is ``axis``.
    """

    atts = {'axis': -1}

    def __init__(self, onnx_node, desc=None, **options):
        _CommonTopK.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=TopK_10.atts, **options)

    def _run(self, data, ink):  # pylint: disable=W0221
        """
        Runtime for operator *TopK*, the computation is done
        by method *_common_run* of the base class
        with its default value for *largest*.

        @param      data        data
        @param      ink         sequence, its first element is *k*
        @return                 top-k values, top-k indices

        .. warning::
            ONNX specifications may be imprecise in case of negative value
            for axis. The implementation follows what :epkg:`onnxruntime`
            does in `top_k.cc
            <https://github.com/Microsoft/onnxruntime/blob/master/onnxruntime/core/providers/cpu/math/top_k.cc#L63>`_.
        """
        return _CommonTopK._common_run(self, data, ink)

    def _infer_sizes(self, data, ink):  # pylint: disable=W0221
        # the operator is run first, the dictionary put in front of
        # its results has one key ``temp`` equal to ``itemsize * k * 2``
        outputs = self.run(data, ink)
        sizes = {'temp': data.dtype.itemsize * ink[0] * 2}
        return (sizes, ) + outputs
class TopK_11(_CommonTopK):
    """
    Runtime for operator *TopK* with attributes
    ``axis``, ``largest``, ``sorted``. Only ``sorted=1``
    is implemented, the constructor raises an exception otherwise.
    """

    atts = {'axis': -1, 'largest': 1, 'sorted': 1}

    def __init__(self, onnx_node, desc=None, **options):
        _CommonTopK.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=TopK_11.atts, **options)
        is_sorted = self.sorted in (True, 1)
        if not is_sorted:
            raise RuntimeError(  # pragma: no cover
                "TopK does not implement anything for sorted=0.")

    def _run(self, data, ink):  # pylint: disable=W0221
        """
        Runtime for operator *TopK*, the computation is done
        by method *_common_run* of the base class
        which receives attribute ``largest``.

        @param      data        data
        @param      ink         sequence, its first element is *k*
        @return                 top-k values, top-k indices

        .. warning::
            ONNX specifications may be imprecise in case of negative value
            for axis. The implementation follows what :epkg:`onnxruntime`
            does in `top_k.cc
            <https://github.com/Microsoft/onnxruntime/blob/master/onnxruntime/core/providers/cpu/math/top_k.cc#L63>`_.
        """
        return _CommonTopK._common_run(self, data, ink, self.largest)

    def _infer_sizes(self, data, ink):  # pylint: disable=W0221
        # the operator is run first, the dictionary put in front of
        # its results has one key ``temp`` equal to ``itemsize * k * 2``
        outputs = self.run(data, ink)
        sizes = {'temp': data.dtype.itemsize * ink[0] * 2}
        return (sizes, ) + outputs
# TopK is the implementation matching the opset of the installed onnx:
# below 10 -> TopK_1, exactly 10 -> TopK_10, 11 and above -> TopK_11.
if onnx_opset_version() < 10:  # pragma: no cover
    TopK = TopK_1
elif onnx_opset_version() < 11:  # pragma: no cover
    TopK = TopK_10
else:
    TopK = TopK_11