Coverage for mlprodict/onnxrt/ops_cpu/op_rnn.py: 85%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from onnx.defs import onnx_opset_version
9from ._op import OpRun
10from ..shape_object import ShapeObject
class CommonRNN(OpRun):
    """
    Shared runtime for the ONNX *RNN* operator, specialised by the
    opset-specific subclasses through *expected_attributes*.
    At runtime only one direction (``forward`` or ``reverse``) is
    implemented and only ``layout=0`` is supported.
    """

    def __init__(self, onnx_node, expected_attributes=None, desc=None,
                 **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=expected_attributes,
                       **options)

        # String attributes may be bytes (read from the node, see the
        # b"Tanh" comparison in choose_act) or str (subclass defaults).
        direction = self._decode_name(self.direction)
        if direction in ("forward", "reverse"):
            self.num_directions = 1
        elif direction == "bidirectional":
            self.num_directions = 2
        else:
            raise RuntimeError(  # pragma: no cover
                "Unknown direction '{}'.".format(self.direction))
        # A reverse RNN walks the sequence from its last step to its first.
        self._is_reverse = direction == "reverse"

        if len(self.activation_alpha) != self.num_directions:
            raise RuntimeError(  # pragma: no cover
                "activation_alpha must have the same size as num_directions={}".format(
                    self.num_directions))
        if len(self.activation_beta) != self.num_directions:
            raise RuntimeError(  # pragma: no cover
                "activation_beta must have the same size as num_directions={}".format(
                    self.num_directions))

        self.f1 = self.choose_act(self.activations[0],
                                  self.activation_alpha[0],
                                  self.activation_beta[0])
        # The second activation belongs to the second direction: alpha and
        # beta only hold a second coefficient when num_directions == 2,
        # indexing [1] otherwise raises IndexError (e.g. with the default
        # activations ['tanh', 'tanh'] and a forward RNN).
        if self.num_directions > 1 and len(self.activations) > 1:
            self.f2 = self.choose_act(self.activations[1],
                                      self.activation_alpha[1],
                                      self.activation_beta[1])
        self.nb_outputs = len(onnx_node.output)
        if getattr(self, 'layout', 0) != 0:
            raise NotImplementedError(
                "The runtime is not implemented when layout=%r != 0." % self.layout)

    @staticmethod
    def _decode_name(name):
        """
        Returns *name* as :class:`str`, decoding it when it is bytes,
        any other value is returned unchanged.
        """
        if isinstance(name, bytes):
            return name.decode('utf-8')
        return name

    def choose_act(self, name, alpha, beta):
        """
        Returns the activation function called *name*.
        The lookup accepts bytes or str and ignores the case so that
        node values (such as ``b"Tanh"``) and lower case defaults
        (such as ``'tanh'``) are both recognized.

        :param name: activation name, one of ``Tanh``, ``Affine``,
            ``Relu``, ``Sigmoid``
        :param alpha: coefficient *alpha*, only used by ``Affine``
        :param beta: coefficient *beta*, only used by ``Affine``
        :return: a function applied elementwise to an array
        :raises RuntimeError: if the activation is unknown
        """
        key = self._decode_name(name)
        if isinstance(key, str):
            key = key.lower()
        if key == "tanh":
            return self._f_tanh
        if key == "affine":
            return lambda x: x * alpha + beta
        if key == "relu":
            return lambda x: numpy.maximum(x, 0)
        if key == "sigmoid":
            # 1 / (1 + exp(-x)) written with tanh to avoid overflow in exp.
            return lambda x: (numpy.tanh(x * 0.5) + 1) * 0.5
        raise RuntimeError(  # pragma: no cover
            "Unknown activation function '{}'.".format(name))

    def _f_tanh(self, x):
        """Hyperbolic tangent activation."""
        return numpy.tanh(x)

    def _clip_threshold(self):
        """
        Returns the *clip* attribute as a float, or None when it is
        missing or empty (an empty list means no clipping).
        """
        clip = getattr(self, 'clip', None)
        if clip is None:
            return None
        values = numpy.asarray(clip, dtype=numpy.float64).ravel()
        if values.size == 0:
            return None
        return float(values[0])

    def _step(self, X, R, B, W, H_0, reverse=False):
        """
        Runs the recurrence over the first axis of *X* for one direction.

        :param X: input sequence, iterated along axis 0
        :param R: recurrence weights, used as ``H_t . R^T``
        :param B: bias, its two halves are summed
        :param W: input weights, used as ``x_t . W^T``
        :param H_0: initial hidden state
        :param reverse: processes the sequence from the last step to
            the first one; outputs are still stored in time order
        :return: tuple ``(Y, Y_h)``, *Y_h* is the last computed state
        :raises NotImplementedError: if ``num_directions != 1``
        """
        if self.num_directions != 1:
            raise NotImplementedError(  # pragma: no cover
                "Only one direction is implemented, num_directions={}.".format(
                    self.num_directions))

        # Loop invariants.
        w_t = numpy.transpose(W)
        r_t = numpy.transpose(R)
        bias = numpy.add(*numpy.split(B, 2))
        clip = self._clip_threshold()

        steps = numpy.split(X, X.shape[0], axis=0)
        if reverse:
            steps = steps[::-1]

        h_list = []
        H_t = H_0
        for x in steps:
            pre = numpy.dot(x, w_t) + numpy.dot(H_t, r_t) + bias
            if clip is not None:
                # Clipping is applied to the input of the activation.
                pre = numpy.clip(pre, -clip, clip)
            H = self.f1(pre)
            h_list.append(H)
            H_t = H

        last = h_list[-1]
        if reverse:
            # Stores the output of step t at position t.
            h_list = h_list[::-1]
        output = numpy.expand_dims(numpy.concatenate(h_list), 1)
        return output, last

    def _run(self, X, W, R, B=None, sequence_lens=None, initial_h=None):  # pylint: disable=W0221
        """
        Computes the RNN outputs.

        :return: ``(Y, )`` when the node has one output, ``(Y, Y_h)``
            otherwise
        :raises NotImplementedError: if ``W.shape[0] != 1``
        """
        self.num_directions = W.shape[0]
        if self.num_directions != 1:
            raise NotImplementedError()  # pragma: no cover

        # Removes the leading num_directions axis (of size 1 here).
        R = numpy.squeeze(R, axis=0)
        W = numpy.squeeze(W, axis=0)
        if B is not None:
            B = numpy.squeeze(B, axis=0)
        if initial_h is not None:
            initial_h = numpy.squeeze(initial_h, axis=0)
        # sequence_lens has no num_directions axis (its shape is
        # [batch_size]), squeezing axis 0 fails when batch_size != 1.
        # NOTE(review): sequence_lens is still ignored, every batch entry
        # is processed over the full sequence length.

        hidden_size = R.shape[-1]
        batch_size = X.shape[1]
        # Default bias and initial state follow the input dtype.
        if B is None:
            B = numpy.zeros(2 * hidden_size, dtype=X.dtype)
        H_0 = (initial_h if initial_h is not None else
               numpy.zeros((batch_size, hidden_size), dtype=X.dtype))

        Y, Y_h = self._step(X, R, B, W, H_0,
                            reverse=getattr(self, '_is_reverse', False))
        return (Y, ) if self.nb_outputs == 1 else (Y, Y_h)

    def _infer_shapes(self, X, W, R, B=None, sequence_lens=None, initial_h=None):  # pylint: disable=W0221
        """
        Infers the output shapes, ``(seq, num_directions, batch, hidden)``
        for *Y* and ``(num_directions, batch, hidden)`` for *Y_h*.
        """
        num_directions = W.shape[0]

        if num_directions == 1:
            hidden_size = R[-1]
            batch_size = X[1]
            y_shape = ShapeObject((X[0], num_directions, batch_size, hidden_size),
                                  dtype=X.dtype)
        else:
            raise NotImplementedError()  # pragma: no cover
        if self.nb_outputs == 1:
            return (y_shape, )
        y_h_shape = ShapeObject((num_directions, batch_size, hidden_size),
                                dtype=X.dtype)
        return (y_shape, y_h_shape)

    def _infer_types(self, X, W, R, B=None, sequence_lens=None, initial_h=None):  # pylint: disable=W0221
        """
        Every output has the type of *X*. Returns as many results as
        the node has outputs, like :meth:`_infer_shapes`.
        """
        return (X, ) if self.nb_outputs == 1 else (X, X)
class RNN_7(CommonRNN):
    """
    Runtime for operator *RNN* selected when the installed onnx
    package is older than opset 14; its expected attributes do not
    include *layout*.
    """

    atts = dict(
        activation_alpha=[0.],
        activation_beta=[0.],
        activations=['tanh', 'tanh'],
        clip=[],
        direction='forward',
        hidden_size=None,
    )

    def __init__(self, onnx_node, desc=None, **options):
        expected = RNN_7.atts
        CommonRNN.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=expected, **options)
class RNN_14(CommonRNN):
    """
    Runtime for operator *RNN* selected from opset 14, which adds
    the attribute *layout* (only ``layout=0`` is supported).
    """

    atts = dict(
        activation_alpha=[0.],
        activation_beta=[0.],
        activations=['tanh', 'tanh'],
        clip=[],
        direction='forward',
        hidden_size=None,
        layout=0,
    )

    def __init__(self, onnx_node, desc=None, **options):
        expected = RNN_14.atts
        CommonRNN.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=expected, **options)
# Exposes RNN_14 when the installed onnx package supports opset 14,
# RNN_7 otherwise.
RNN = RNN_14 if onnx_opset_version() >= 14 else RNN_7