Coverage for mlprodict/onnxrt/ops_cpu/op_array_feature_extractor.py: 98%
Shortcuts on this page
r, m, x: toggle line displays
j, k: next/previous highlighted chunk
0 (zero): top of page
1 (one): first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ._op import OpRun
9from ..shape_object import ShapeObject
10from ._op_onnx_numpy import ( # pylint: disable=E0611,E0401
11 array_feature_extractor_double,
12 array_feature_extractor_int64,
13 array_feature_extractor_float)
def _array_feature_extrator(data, indices):
    """
    Implementation of operator *ArrayFeatureExtractor*
    with :epkg:`numpy`.

    :param data: input array, elements are gathered along its last axis
    :param indices: array of integer positions, any shape
        (``(n,)``, ``(1, n)``, ``(n, 1)``, ...), it is flattened
    :return: array of shape ``data.shape[:-1] + (n,)`` where *n* is the
        number of indices, or ``(1, n)`` when *data* is a vector
    """
    # Whatever the shape of *indices*, its flattened values are the
    # positions to gather and their count is the new last dimension.
    index = indices.ravel().tolist()
    add = len(index)
    if len(data.shape) == 1:
        # A vector input still produces a matrix with one row.
        new_shape = (1, add)
    else:
        new_shape = list(data.shape[:-1]) + [add]
    return data[..., index].reshape(new_shape)
def sizeof_dtype(dty):
    """
    Returns the size in bytes of one element of the given type.

    :param dty: numpy type or dtype, only *float64*, *float32*
        and *int64* are supported
    :return: 8 for *float64* and *int64*, 4 for *float32*
    :raises ValueError: for any other type
    """
    for known, size in ((numpy.float64, 8), (numpy.float32, 4),
                        (numpy.int64, 8)):
        if dty == known:
            return size
    # Report the type actually received to make the failure actionable.
    raise ValueError(
        "Unable to get bytes size for type {}.".format(dty))
class ArrayFeatureExtractor(OpRun):
    """
    Runtime for operator *ArrayFeatureExtractor*.
    Dispatches to a C++ kernel for *float64*, *float32* and *int64*
    inputs and to a :epkg:`numpy` implementation for any other type.
    """

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       **options)

    def _run(self, data, indices):  # pylint: disable=W0221
        """
        Runtime for operator *ArrayFeatureExtractor*.

        .. warning::
            ONNX specifications may be imprecise in some cases.
            When the input data is a vector (one dimension),
            the output has still two like a matrix with one row.
            The implementation follows what :epkg:`onnxruntime` does in
            `array_feature_extractor.cc
            <https://github.com/microsoft/onnxruntime/blob/master/
            onnxruntime/core/providers/cpu/ml/array_feature_extractor.cc#L84>`_.
        """
        if data.dtype == numpy.float64:
            res = array_feature_extractor_double(data, indices)
        elif data.dtype == numpy.float32:
            res = array_feature_extractor_float(data, indices)
        elif data.dtype == numpy.int64:
            res = array_feature_extractor_int64(data, indices)
        else:
            # No C++ kernel for the other types (strings for example):
            # fall back to the numpy implementation.
            res = _array_feature_extrator(data, indices)
        return (res, )

    def _infer_shapes(self, data, indices):  # pylint: disable=W0221
        """
        Infers the shape of the output, consistently with
        `_array_feature_extrator`: the last dimension of *data* is
        replaced by the number of indices, and a vector input gives
        a ``(1, n)`` output.
        """
        add = indices.product()
        rank = len(data)
        if rank == 1:
            return (ShapeObject((1, add), dtype=data.dtype), )
        if rank == 0:
            # NOTE(review): presumably an unknown (or scalar) shape;
            # keep the previous append-based result in that case.
            dim = data.copy()
            dim.append(add)
            return (dim, )
        # Keep every leading dimension and replace the last one
        # (appending it would produce one dimension too many).
        # NOTE(review): assumes ShapeObject supports integer indexing
        # returning dimensions accepted by its constructor — confirm.
        leading = tuple(data[i] for i in range(rank - 1))
        return (ShapeObject(leading + (add, ), dtype=data.dtype), )

    def _infer_types(self, data, indices):  # pylint: disable=W0221
        """
        Returns the type of the output, which is the type of *data*.
        """
        return (data, )