Coverage for mlprodict/onnxrt/ops_cpu/op_solve.py: 92%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7from scipy.linalg import solve
8from ._op import OpRunBinaryNum
9from ._new_ops import OperatorSchema
class Solve(OpRunBinaryNum):
    """
    Runtime for the custom operator *Solve*.

    The computation is delegated to ``scipy.linalg.solve``; the
    attributes *lower* and *transposed* are forwarded to it as
    keyword arguments.
    """

    # Expected attributes and their default values.
    atts = dict(lower=False, transposed=False)

    def __init__(self, onnx_node, desc=None, **options):
        """
        @param      onnx_node   node this runtime is built for
        @param      desc        forwarded to the parent constructor
        @param      options     forwarded to the parent constructor
        """
        OpRunBinaryNum.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=Solve.atts, **options)

    def _find_custom_operator_schema(self, op_name):
        """
        Returns a @see cl SolveSchema when *op_name* is ``"Solve"``,
        raises ``RuntimeError`` for any other name.
        """
        if op_name != "Solve":
            raise RuntimeError(  # pragma: no cover
                "Unable to find a schema for operator '{}'.".format(op_name))
        return SolveSchema()

    def _run(self, a, b):  # pylint: disable=W0221
        """
        Calls ``scipy.linalg.solve(a, b, ...)`` and returns the result
        in a one-element tuple.
        """
        # NOTE(review): entry 1 of ``inplaces`` presumably flags that the
        # second input *b* may be overwritten -- confirm against the base
        # class.
        overwrite = bool(self.inplaces.get(1, False))
        # NOTE(review): ``self.lower`` / ``self.transposed`` are not set in
        # this class; presumably the parent constructor fills them from
        # the node attributes -- confirm.
        solve_options = {'lower': self.lower, 'transposed': self.transposed}
        if overwrite:
            solve_options['overwrite_b'] = True
        return (solve(a, b, **solve_options), )

    def _infer_shapes(self, a, b):  # pylint: disable=W0221,W0237
        """
        Returns the second argument, unchanged, in a one-element tuple.
        """
        return (b, )

    def _infer_types(self, a, b):  # pylint: disable=W0221,W0237
        """
        Returns the second argument, unchanged, in a one-element tuple.
        """
        return (b, )

    def to_python(self, inputs):
        """
        Returns a pair of strings: the import statement and the
        ``return solve(...)`` statement built from the two first
        names in *inputs* and the current attribute values.
        """
        first, second = inputs[0], inputs[1]
        statement = "return solve({}, {}, lower={}, transposed={})".format(
            first, second, self.lower, self.transposed)
        return ('from scipy.linalg import solve', statement)
class SolveSchema(OperatorSchema):
    """
    Schema describing the custom operator implemented by
    @see cl Solve, an operator added by this package.
    """

    def __init__(self):
        """
        Registers the schema under the name ``'Solve'`` and exposes the
        attributes declared by @see cl Solve.
        """
        schema_name = 'Solve'
        OperatorSchema.__init__(self, schema_name)
        # Same dictionary object as ``Solve.atts`` (no copy is made).
        self.attributes = Solve.atts