Coverage for mlprodict/onnxrt/ops_cpu/op_cdist.py: 97%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7from scipy.spatial.distance import cdist
8from ._op import OpRunBinaryNum
9from ._new_ops import OperatorSchema
10from ..shape_object import ShapeObject
class CDist(OpRunBinaryNum):
    """
    Runtime for the custom operator *CDist*: computes the pairwise
    distances between the rows of two matrices with
    :func:`scipy.spatial.distance.cdist`.
    """

    # Default attribute values; also used by CDistSchema.
    atts = {'metric': 'sqeuclidean', 'p': 2.}

    def __init__(self, onnx_node, desc=None, **options):
        OpRunBinaryNum.__init__(self, onnx_node, desc=desc,
                                expected_attributes=CDist.atts,
                                **options)

    def _metric_name(self):
        """
        Returns attribute *metric* as a :class:`str`.

        Values read from the node are decoded from bytes (as the previous
        code always did), but the default declared in ``CDist.atts`` is
        already a :class:`str`; calling ``decode`` on it would raise
        ``AttributeError``, so both types are accepted here.
        """
        metric = self.metric
        if isinstance(metric, bytes):
            return metric.decode('ascii')
        return metric

    def _run(self, a, b):  # pylint: disable=W0221
        """
        Computes the distances between every row of *a* and every row
        of *b*.

        :param a: first matrix, one observation per row
        :param b: second matrix, one observation per row
        :return: tuple with one matrix of shape
            ``(a.shape[0], b.shape[0])`` cast to the dtype of *a*
        """
        metric = self._metric_name()
        if metric == 'minkowski':
            # only the minkowski metric takes the order parameter p
            res = cdist(a, b, metric=metric, p=self.p)
        else:
            res = cdist(a, b, metric=metric)
        # scipy may change the output type
        res = res.astype(a.dtype)
        return (res, )

    def _find_custom_operator_schema(self, op_name):
        """
        Returns the schema of the custom operator *op_name*.

        :raises RuntimeError: if *op_name* is not ``'CDist'``
        """
        if op_name == "CDist":
            return CDistSchema()
        raise RuntimeError(  # pragma: no cover
            "Unable to find a schema for operator '{}'.".format(op_name))

    def _infer_shapes(self, a, b):  # pylint: disable=W0221,W0237
        """
        Returns the shape of the distance matrix: as many rows as
        *a* and as many columns as *b*, with the dtype of *a*.
        """
        return (ShapeObject((a[0], b[0]), dtype=a.dtype,
                            name=self.__class__.__name__), )

    def to_python(self, inputs):
        """
        Returns the python code equivalent to this operator.

        :param inputs: names of the two input variables
        :return: tuple ``(import statement, code returning the result)``
        """
        metric = self._metric_name()
        if metric == 'minkowski':
            return ('from scipy.spatial.distance import cdist',
                    "return cdist({}, {}, metric='{}', p={})".format(
                        inputs[0], inputs[1], metric, self.p))
        return ('from scipy.spatial.distance import cdist',
                "return cdist({}, {}, metric='{}')".format(
                    inputs[0], inputs[1], metric))
class CDistSchema(OperatorSchema):
    """
    Defines the schema of the custom operator *CDist*
    implemented by class @see cl CDist.
    """

    def __init__(self):
        OperatorSchema.__init__(self, 'CDist')
        # Copy the defaults so that any change made to the schema's
        # attributes cannot alter the class-level defaults of CDist.
        self.attributes = dict(CDist.atts)