Coverage for mlprodict/onnxrt/ops_cpu/op_celu.py: 100%
Shortcuts on this page:
r, m, x: toggle line displays
j, k: next/previous highlighted chunk
0 (zero): top of page
1 (one): first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ._op import OpRunUnaryNum
def pycelu(x, alpha=1.):
    r"""
    Scalar reference implementation of the *CELU* activation.

    .. math::

        celu(x) = \left \{\begin{array}{ll} x \text{ if } x > 0 \\
        \alpha ( e^{\frac{x}{\alpha}} - 1) \, \text{ otherwise }
        \end{array} \right.

    :param x: a single (scalar) input value
    :param alpha: CELU ``alpha`` coefficient
    :return: *x* itself when strictly positive, otherwise
        ``(exp(x / alpha) - 1) * alpha``
    """
    # Anything that is not strictly positive (including NaN) goes through
    # the exponential branch.
    return x if x > 0 else (numpy.exp(x / alpha) - 1) * alpha
26def _vcelu1(x, alpha=1.):
27 positive_input = numpy.maximum(0, x)
28 negative_input = numpy.minimum(0, alpha * (
29 numpy.exp(x / alpha) - 1))
30 return positive_input + negative_input
class Celu(OpRunUnaryNum):
    """
    Runtime implementation of the *Celu* operator.

    The attribute ``alpha`` defaults to ``1.0`` (see ``atts``). The default
    path evaluates :func:`_vcelu1`; the in-place path applies :func:`pycelu`
    element-wise through :class:`numpy.vectorize`.
    """

    atts = {'alpha': numpy.float32(1.0)}

    def __init__(self, onnx_node, desc=None, **options):
        OpRunUnaryNum.__init__(self, onnx_node, desc=desc,
                               expected_attributes=Celu.atts,
                               **options)
        # ``numpy.float`` (an alias of the builtin ``float``) was removed in
        # NumPy 1.24; ``numpy.float64`` is the identical output type.
        # ``self.alpha`` is read lazily, at call time, by the lambda.
        self._vcelu2 = numpy.vectorize(
            lambda x: pycelu(x, self.alpha), otypes=[numpy.float64])

    def _run(self, x):  # pylint: disable=W0221
        if self.inplaces.get(0, False):
            return self._run_inplace(x)
        return (_vcelu1(x, self.alpha), )

    def _run_inplace(self, x):
        res = self._vcelu2(x)
        # numpy.vectorize with otypes=[float64] always yields float64. Cast
        # back so this path returns the same dtype as _run (e.g. float32 in,
        # float32 out). copy=False avoids a copy when already float64.
        dtype = getattr(x, 'dtype', res.dtype)
        return (res.astype(dtype, copy=False), )

    def to_python(self, inputs):
        # Returns (import line, body line) for code generation; *inputs* is
        # unused. NOTE(review): the generated body assumes names ``X`` and
        # ``alpha`` exist in the generated function's scope -- confirm in
        # the caller.
        return ('from mlprodict.onnxrt.ops_cpu.op_celu import _vcelu1',
                "return _vcelu1(X, alpha)")