# Scan - 8 vs 9
The next section compares an older version of the operator to a newer version, after both definitions are converted into markdown text. Lines prefixed with `+` (rendered green) are additions in the newer version; lines prefixed with `-` (rendered red) are deletions from it. Everything else is unchanged.
- Scan8 → Scan9 +66 -71
**Scan8 → Scan9** (RENAMED)

```diff
@@ -1 +1 @@
 Scan can be used to iterate over one or more scan_input tensors,
 constructing zero or more scan_output tensors. It combines ideas from general recurrences,
 functional programming constructs such as scan, fold, map, and zip, and is intended to enable
 generalizations of RNN-like constructs for sequence-to-sequence processing.
 Other tensors (referred to as state_variables here) can be used to carry a state
 when iterating from one element to another (similar to hidden-state in RNNs, also referred
-to as loop-carried dependences in the context of loops).
+to as loop-carried dependences in the context of loops). All these tensors are required to
+have the same shape in each iteration of the loop (a restriction imposed to enable efficient
-Many common usages involve a single scan_input tensor (where functionality
+memory allocation). Many common usages involve a single scan_input tensor (where functionality
 similar to scan, fold and map can be obtained). When more than one scan_input is used,
 a behavior similar to zip is obtained.
 The attribute body must be a graph, specifying the computation to be performed in
 every iteration. It takes as input the current values of the state_variables and
 the current iterated element of the scan_inputs. It must return the (updated) values
 of the state_variables and zero or more scan_output_element tensors. The values of the
 scan_output_element tensors are concatenated over all the iterations to produce the
 scan_output values of the scan construct (similar to the concatenated intermediate
+hidden-state values of RNN-like constructs).
-hidden-state values of RNN-like constructs). All the output tensors (state_variables as
-well as scan_output_element tensors) are required to have the same shape in each iteration
-of the loop (a restriction imposed to enable efficient memory allocation).
-
-Note that the iterated element passed to the body subgraph does not have a sequence
-axis. It will have a rank one less than the rank of the corresponding scan_input.
 The scan operation returns the final values of the state_variables as well as the
 scan_outputs.
+The operation supports batching, and the batch-axis is required to be 0.
-
+When multiple scan_input tensors are used, they must all have the same batch-size,
+and they must all have the same maximum-sequence-length (the dimensionality of the
-
+sequence axis or scan axis). The sequence axis or scan axis is required to be 1.
-direction. A bidirectional scan may be performed by specifying the same tensor input twice
-in the scan_inputs, once with a forward direction, and once with a backward direction.
+The operation has an optional sequence_lens input (of shape [BATCH_SIZE]) to
+allow variable length sequences of length <= the maximum-sequence-length. If this
+input is not specified, all sequences are assumed to be of length equal to
+maximum-sequence-length. For variable length input sequences, the scan_outputs
+will consist of a sequence of same length as the input, padded to the
+maximum-sequence-length.
-The
+The optional attribute directions can be used to scan a sequence in the reverse direction.
-values produced by the body in each iteration. The optional attribute scan_output_directions
-specifies the direction in which scan_output is constructed (by appending or prepending the
-scan_output_element to scan_output in each iteration) for each scan_output. If this attribute
-is omitted,
+If this attribute is omitted, all sequences are scanned in the forward direction.
+A bidirectional scan be performed by specifying the same tensor input twice in the
+scan_inputs, once with a forward direction, and once with a backward direction.
-The optional attribute scan_input_axes specifies the axis to be scanned for each scan_input.
-If omitted, every scan_input will be scanned in axis 0. For example, if axis 0 is the
-batch axis and axis 1 is the time axis (to be scanned), specify an axis value of 1.
-Note that scanning a non-zero axis may be less efficient than scanning axis zero.
-
-The optional attribute scan_output_axes specifies the axis along which the scan_outputs
-are accumulated for each scan_output. For example, if axis 1 is the time axis (to be
-scanned) for both inputs and outputs, specify a scan_input axis and scan_output axis
-value of 1.
 Note that because of the ONNX restriction that only the last parameter of an operator can
 be variadic, the initial-states and scan-inputs are listed together as one input parameter.
 Similarly, the final-states and scan-outputs are listed together as one output parameter.
 The attribute num_scan_inputs indicates the number M of scan-inputs.
 The behavior of
 Scan <
     num_scan_inputs = m,
-    body = loop-body
+    body = loop-body
-    scan_input_axes = [axis_1, ..., axis_m]
-> (init_1, ..., init_n, scan_1, ..., scan_m)
+> (sequence_lengths, init_1, ..., init_n, scan_1, ..., scan_m)
 is equivalent to the following pseudo-code:
+// T.shape[0] denotes the batch-size of T
+// The batch-size of scan_1, ..., scan_m are all required to be equal
+batch_size = scan_1.shape[0];
-// scan_i.shape[
+// scan_i.shape[1] denotes the (max) sequence-length of scan_i
-// scan_i.shape[
+// scan_i.shape[1] is required to be equal to scan_j.shape[1] for all i,j.
-
+max_sequence_length = scan_1.shape[1];
+for (int batch = 0; batch < batch_size; ++batch) {
-
+    // initialize state-variables
-
+    st_1 = init_1; ... st_n = init_n;
-
+    // initialize scan-output variables: [] denotes an empty tensor
-
+    scan_out_1 = []; ...; scan_out_k = [];
-
+    // identify number of iterations:
+    N = (sequence_lengths specified) ? sequence_lengths[batch] : max_sequence_length;
+
-
+    // execute loop
-
+    for (int t = 0; t < N; ++t) {
-
+        // generate the scan-input elements: the notation T<axis=k>[t] indicates the sub-tensor
-
+        // of rank one less than T obtained by indexing T at position t along axis k.
-
+        si_1 = (scan_1<axis=0>[batch])<axis=1>[t];
-
+        ... ;
-
+        si_m = (scan_m<axis=0>[batch])<axis=1>[t];
-
+        // execute loop-body
-
+        st_1, ..., st_n, so_1, ..., so_k = loop-body(st_1, ..., st_n, si_1, ..., si_m)
-
+        // accumulate the scan-output elements
-
+        scan_out_1 = Concat<axis=0>(scan_out_1, so_1); ... ; scan_out_k = Concat<axis=0>(scan_out_k, so_k);
+    }
+    // accumulate the outputs for this batch:
+    bst_1[batch] = st_1; ..., bst_n[batch] = st_n;
+    // Note scan-outputs will have size max_sequence_length, but only first N values will be meaningful.
+    // The remaining values have an undefined value.
+    b_scan_out_1[batch] = scan_out_1; ...; b_scan_out_k[batch] = scan_out_k;
 }
-
-return
+return bst_1, ..., bst_n, b_scan_out_1, ..., b_scan_out_k;
```
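For readers who want to sanity-check the batched pseudo-code above, here is a small NumPy sketch of the same semantics. This is a hypothetical reference implementation written for this comparison, not an ONNX API: the name `scan_batched` is illustrative, and padding with zeros is a choice made here, since the spec leaves padded values undefined.

```python
import numpy as np

def scan_batched(loop_body, init_states, scan_inputs, sequence_lens=None):
    # scan_inputs: tensors with batch axis 0 and sequence (scan) axis 1,
    # matching the requirements stated above.
    batch_size = scan_inputs[0].shape[0]
    max_seq_len = scan_inputs[0].shape[1]
    per_batch_states, per_batch_outs = [], []
    for b in range(batch_size):
        states = [s[b] for s in init_states]                # initialize state-variables
        n = max_seq_len if sequence_lens is None else int(sequence_lens[b])
        step_outs = []
        for t in range(n):                                  # execute loop
            elems = [x[b, t] for x in scan_inputs]          # scan-input elements
            states, outs = loop_body(states, elems)         # execute loop-body
            step_outs.append(outs)
        padded = []
        for k in range(len(step_outs[0])):                  # accumulate scan-outputs
            acc = np.stack([so[k] for so in step_outs])     # Concat along axis 0
            pad = [(0, max_seq_len - n)] + [(0, 0)] * (acc.ndim - 1)
            padded.append(np.pad(acc, pad))                 # pad to max_seq_len (zeros)
        per_batch_states.append(states)
        per_batch_outs.append(padded)
    final_states = [np.stack([st[i] for st in per_batch_states])
                    for i in range(len(init_states))]
    scan_outputs = [np.stack([o[k] for o in per_batch_outs])
                    for k in range(len(per_batch_outs[0]))]
    return final_states, scan_outputs

# Running sum over a batch of two sequences; the second uses only 2 of 3 steps.
body = lambda states, elems: ([states[0] + elems[0]], [states[0] + elems[0]])
x = np.array([[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]])[:, :, None]  # [batch, seq, 1]
finals, outs = scan_batched(body, [np.zeros((2, 1))], [x],
                            sequence_lens=np.array([3, 2]))
```

The shorter sequence in the batch stops after its own length, and its scan-output is padded out to the maximum sequence length, as the description above requires.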
```diff
 *Sample usage: Encoding RNN using a Scan*
 The following example shows how a simple RNN over an input tensor %X, with weight tensor %Wi,
 recurrence weight tensor %Ri, bias tensors %Wbi and %Rbi, and initial hidden-state %H_0 can
 be encoded as a ScanLoop. Note that the loop-body is a nested graph, and it directly computes
 %Wi, %Ri, %Wbi, and %Rbi (typically constants or initializers in the body graph). If these
 values are computed in the outer graph, they need to be passed in as extra state_variables.
 graph rnn-encoding {
   %H_0 = ...
   %X = ...
-  %Y_h, %Y = Scan[body = <graph rnn-cell-1>, num_scan_inputs=1](%H_0, %X)
+  %Y_h, %Y = Scan[body = <graph rnn-cell-1>, num_scan_inputs=1]("", %H_0, %X)
   return %Y, %Y_h
 }
 graph rnn-cell-1 (
   %H_tminus1[FLOAT, tensor]
   %X_t[FLOAT, tensor]
 ) {
   %Wi = ...
   %Ri = ...
   %Wbi = ...
   %Rbi = ...
   %t1 = X_t * (Wi^T)
   %t2 = H_tminus1*(Ri^T)
   %t3 = Add(%t1, %t2)
   %t4 = Add(%t3, %Wbi)
   %t5 = Add(%t4, %Rbi)
   %Ht = Tanh(%t5)
   %Accumulate = Identity(%Ht)
   return %Ht, %Accumulate
 }
```
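As a cross-check of the body graph's arithmetic, the rnn-cell-1 computation can be mirrored in NumPy. The shapes and random weights below are illustrative assumptions made for this sketch, not values from the spec; `rnn_cell_1` is a hypothetical helper name.

```python
import numpy as np

def rnn_cell_1(X_t, H_tminus1, Wi, Ri, Wbi, Rbi):
    t1 = X_t @ Wi.T                      # %t1 = X_t * (Wi^T)
    t2 = H_tminus1 @ Ri.T                # %t2 = H_tminus1 * (Ri^T)
    H_t = np.tanh(t1 + t2 + Wbi + Rbi)   # %t3..%t5 are Adds, then Tanh
    return H_t, H_t                      # %Ht and %Accumulate = Identity(%Ht)

# Scan the cell over a length-4 sequence (input size 3, hidden size 2).
rng = np.random.default_rng(0)
Wi = rng.standard_normal((2, 3))
Ri = rng.standard_normal((2, 2))
Wbi = np.zeros(2)
Rbi = np.zeros(2)
H = np.zeros(2)                          # %H_0
accumulated = []
for X_t in rng.standard_normal((4, 3)):  # one iteration per sequence element
    H, out = rnn_cell_1(X_t, H, Wi, Ri, Wbi, Rbi)
    accumulated.append(out)
Y = np.stack(accumulated)                # concatenated hidden states (%Y)
```

Here `H` plays the role of the final state %Y_h, and `Y` of the concatenated scan-output %Y.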
```diff
 **Attributes**
 * **body** (required):
   The graph run each iteration. It has N+M inputs: (loop state
   variables..., scan_input_elts...). It has N+K outputs: (loop state
   variables..., scan_output_elts...). Each scan_output is created by
   concatenating the value of the specified scan_output_elt value at
   the end of each iteration of the loop. It is an error if the
   dimensions of these values change across loop iterations.
-* **num_scan_inputs** (required):
-  An attribute specifying the number of scan_inputs M.
-* **scan_input_axes**:
-  An optional list of M flags. The i-th element of the list specifies
-  the axis to be scanned (the sequence axis) for the i-th scan_input.
-  If omitted, 0 will be used as the scan axis for every scan_input.
-* **
+* **directions**:
   An optional list of M flags. The i-th element of the list specifies
   the direction to be scanned for the i-th scan_input tensor: 0
   indicates forward direction and 1 indicates reverse direction. If
   omitted, all scan_input tensors will be scanned in the forward
   direction.
+* **num_scan_inputs** (required):
+  An attribute specifying the number of scan_inputs M.
-* **scan_output_axes**:
-  An optional list of K flags. The i-th element of the list specifies
-  the axis for the i-th scan_output. The scan outputs are accumulated
-  along the specified axis. If omitted, 0 will be used as the scan
-  axis for every scan_output.
-* **scan_output_directions**:
-  An optional list of K flags, one for each scan_output. The i-th
-  element of the list specifies whether the i-th scan_output should be
-  constructed by appending or prepending a new value in each
-  iteration: 0 indicates appending and 1 indicates prepending. If
-  omitted, all scan_output tensors will be produced by appending a
-  value in each iteration.
 **Inputs**
-Between
+Between 2 and 2147483647 inputs.
+* **sequence_lens** (optional, heterogeneous) - **I**:
+  Optional tensor specifying lengths of the sequences in a batch. If
+  this input is not specified, all sequences are assumed to be of the
+  maximum sequence length (the dimension of the sequence axis of the
+  scan_input tensors).
 * **initial_state_and_scan_inputs** (variadic) - **V**:
   Initial values of the loop's N state variables followed by M
   scan_inputs
 **Outputs**
 Between 1 and 2147483647 outputs.
 * **final_state_and_scan_outputs** (variadic) - **V**:
   Final values of the loop's N state variables followed by K
   scan_outputs
 **Type Constraints**
+* **I** in (
+  tensor(int64)
+):
+  Int64 tensor
 * **V** in (
   tensor(bool),
   tensor(complex128),
   tensor(complex64),
   tensor(double),
   tensor(float),
   tensor(float16),
   tensor(int16),
   tensor(int32),
   tensor(int64),
   tensor(int8),
   tensor(string),
   tensor(uint16),
   tensor(uint32),
   tensor(uint64),
   tensor(uint8)
 ):
   All Tensor types
```