Coverage for mlprodict/onnxrt/ops_cpu/op_leaky_relu.py: 100%
Shortcuts on this page
r, m, x: toggle line displays
j, k: next/previous highlighted chunk
0 (zero): top of page
1 (one): first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7from textwrap import dedent
8from ._op import OpRunUnaryNum
11def _leaky_relu(x, alpha):
12 sign = (x > 0).astype(x.dtype)
13 sign -= ((sign - 1) * alpha).astype(x.dtype)
14 return x * sign
17def _leaky_relu_inplace(x, alpha):
18 sign = (x > 0).astype(x.dtype)
19 sign -= ((sign - 1) * alpha).astype(x.dtype)
20 x *= sign
class LeakyRelu(OpRunUnaryNum):
    """
    Runtime for the *LeakyRelu* operator: ``x`` where ``x > 0`` and
    ``alpha * x`` elsewhere. Attribute *alpha* defaults to 0.01.
    """

    atts = {'alpha': 0.01}

    def __init__(self, onnx_node, desc=None, **options):
        OpRunUnaryNum.__init__(self, onnx_node, desc=desc,
                               expected_attributes=LeakyRelu.atts,
                               **options)

    def _run(self, x):  # pylint: disable=W0221
        # NOTE(review): presumably self.inplaces[0] is set by the runtime when
        # the first input may be overwritten, and self.alpha is filled by the
        # base class from expected_attributes; confirm in OpRunUnaryNum.
        allow_inplace = self.inplaces.get(0, False)
        if not allow_inplace:
            return (_leaky_relu(x, self.alpha), )
        return self._run_inplace(x)

    def _run_inplace(self, x):
        # The helper mutates x and returns None, so x itself is the output.
        _leaky_relu_inplace(x, self.alpha)
        return (x, )

    def to_python(self, inputs):
        """
        Returns the python code equivalent to this operator.

        :param inputs: names of the input variables, only the first is used
        :return: tuple ``(preamble, body)``, where *preamble* defines a
            standalone ``_leaky_relu`` and *body* is the return statement
            calling it. *body* refers to a variable ``alpha``, presumably
            bound by the code generator (TODO confirm).
        """
        preamble = dedent(
            """
            import numpy
            def _leaky_relu(x, alpha):
                sign = (x > 0).astype(x.dtype)
                sign -= ((sign - 1) * alpha).astype(x.dtype)
                return x * sign
            """)
        body = "return _leaky_relu(%s, alpha)" % inputs[0]
        return (preamble, body)