Scan - 9 vs 11
The next section compares an older and a newer version of the same operator after both definitions are converted into markdown text. Green means an addition to the newer version, red means a deletion. Anything else is unchanged. In this plain-text rendering, added lines are marked with `+` and deleted lines with `-`.

Scan9 → Scan11 (+1 -4)

Scan can be used to iterate over one or more scan_input tensors, constructing zero or more scan_output tensors. It combines ideas from general recurrences, functional programming constructs such as scan, fold, map, and zip, and is intended to enable generalizations of RNN-like constructs for sequence-to-sequence processing.

Other tensors (referred to as state_variables here) can be used to carry a state when iterating from one element to another (similar to hidden-state in RNNs, also referred to as loop-carried dependences in the context of loops).

Many common usages involve a single scan_input tensor (where functionality similar to scan, fold and map can be obtained). When more than one scan_input is used, a behavior similar to zip is obtained.
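The split between state_variables and scan_outputs can be sketched in plain Python. This is not the ONNX API: `scan`, `running_sum` and `pairwise_product` are hypothetical names, tensors are modelled as Python lists, and only forward scanning along axis 0 is covered. A single scan_input gives scan/fold-style behavior; two scan_inputs give zip-style behavior.

```python
from typing import Any, Callable, List, Sequence, Tuple


def scan(body: Callable[..., Tuple[Any, ...]],
         init_states: Sequence[Any],
         scan_inputs: Sequence[Sequence[Any]],
         num_scan_outputs: int) -> Tuple[List[Any], List[List[Any]]]:
    """Hypothetical helper: walk axis 0 of every scan input in lockstep."""
    length = len(scan_inputs[0])
    if any(len(s) != length for s in scan_inputs):
        raise ValueError("all scan inputs must have the same sequence length")
    n = len(init_states)
    states = list(init_states)
    scan_outputs: List[List[Any]] = [[] for _ in range(num_scan_outputs)]
    for t in range(length):
        results = body(*states, *(s[t] for s in scan_inputs))
        states = list(results[:n])           # updated state_variables
        for acc, value in zip(scan_outputs, results[n:]):
            acc.append(value)                # one scan_output_element per iteration
    return states, scan_outputs


# Fold/scan: one state (the total) and one scan output (the prefix sums).
def running_sum(total, x):
    new_total = total + x
    return new_total, new_total


final, (prefix,) = scan(running_sum, [0], [[1, 2, 3, 4]], num_scan_outputs=1)
print(final, prefix)  # [10] [1, 3, 6, 10]


# Zip: two scan inputs, no state, one scan output.
def pairwise_product(a, b):
    return (a * b,)


_, (products,) = scan(pairwise_product, [], [[1, 2, 3], [4, 5, 6]], num_scan_outputs=1)
print(products)  # [4, 10, 18]
```

A map corresponds to the same helper with no state and a body that returns only a scan_output element.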

The body attribute must be a graph specifying the computation to be performed in every iteration. It takes as input the current values of the state_variables and the current iterated element of the scan_inputs. It must return the (updated) values of the state_variables and zero or more scan_output_element tensors. The values of the scan_output_element tensors are concatenated over all the iterations to produce the scan_output values of the scan construct (similar to the concatenated intermediate hidden-state values of RNN-like constructs). All the output tensors (state_variables as well as scan_output_element tensors) are required to have the same shape in each iteration of the loop (a restriction imposed to enable efficient memory allocation).
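Because a scan_output_element has no sequence axis, accumulating the elements over iterations amounts to stacking them along a new sequence axis. The sketch below uses hypothetical helpers `shape_of` and `stack_scan_output`, with tensors modelled as nested Python lists; it shows the accumulation and a check that rejects an element whose shape changes between iterations.

```python
from typing import Any, List, Tuple


def shape_of(value: Any) -> Tuple[int, ...]:
    """Shape of a rectangular nested list; a scalar has shape ()."""
    if not isinstance(value, list):
        return ()
    if not value:
        return (0,)
    inner = shape_of(value[0])
    if any(shape_of(v) != inner for v in value[1:]):
        raise ValueError("ragged nested list has no tensor shape")
    return (len(value),) + inner


def stack_scan_output(per_iteration: List[Any]) -> List[Any]:
    """Stack scan_output_element values along a new leading (sequence) axis."""
    if per_iteration:
        expected = shape_of(per_iteration[0])
        for t, element in enumerate(per_iteration):
            if shape_of(element) != expected:
                raise ValueError(f"iteration {t}: shape {shape_of(element)} != {expected}")
    return list(per_iteration)


stacked = stack_scan_output([[1, 2], [3, 4], [5, 6]])
print(shape_of(stacked))  # (3, 2): sequence axis first, then the element shape

try:
    stack_scan_output([[1, 2], [3, 4, 5]])
except ValueError as err:
    print(err)  # iteration 1: shape (3,) != (2,)
```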

Note that the iterated element passed to the body subgraph does not have a sequence axis. Its rank is one less than the rank of the corresponding scan_input.
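A small illustration of the rank reduction, with a 2-D tensor modelled as a nested Python list; `index_along_axis` is a hypothetical helper, not an ONNX function. Indexing position t along the scanned axis removes that axis, so a [2, 3] scan_input yields [3]-shaped elements when scanned along axis 0 and [2]-shaped elements when scanned along axis 1.

```python
from typing import Any


def index_along_axis(tensor: Any, axis: int, t: int) -> Any:
    """Sub-tensor at position t along `axis`; the result has one fewer dimension."""
    if axis == 0:
        return tensor[t]
    # Keep the leading axis and index the deeper one in every sub-tensor.
    return [index_along_axis(sub, axis - 1, t) for sub in tensor]


x = [[1, 2, 3],
     [4, 5, 6]]                     # shape (2, 3)
print(index_along_axis(x, 0, 1))    # [4, 5, 6]  -> shape (3,)
print(index_along_axis(x, 1, 2))    # [3, 6]     -> shape (2,)
```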

The scan operation returns the final values of the state_variables as well as the scan_outputs.

The optional attribute scan_input_directions specifies the direction (forward or backward) for each scan input. If this attribute is omitted, all sequences are scanned in the forward direction. A bidirectional scan may be performed by specifying the same tensor input twice in the scan_inputs, once with a forward direction, and once with a backward direction.
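The following sketch lists, per iteration, the two elements the body would receive when the same tensor appears twice in scan_inputs with scan_input_directions = [0, 1]. It models a 1-D tensor as a Python list; `bidirectional_elements` is a hypothetical helper.

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def bidirectional_elements(x: Sequence[T]) -> List[Tuple[T, T]]:
    """Pairs (forward element, backward element) seen by the body at each iteration t."""
    n = len(x)
    # Direction 0 reads position t; direction 1 reads position n - 1 - t.
    return [(x[t], x[n - 1 - t]) for t in range(n)]


print(bidirectional_elements(["a", "b", "c"]))
# [('a', 'c'), ('b', 'b'), ('c', 'a')]
```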

The scan_output of the operation is produced by concatenating the scan_output_element values produced by the body in each iteration. The optional attribute scan_output_directions specifies, for each scan_output, the direction in which it is constructed (by appending or prepending the scan_output_element to scan_output in each iteration). If this attribute is omitted, the scan_output_element is appended to the scan_output in each iteration.
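Prepending pairs naturally with a reverse-direction scan input, because it puts each element back at the position it was computed for. A minimal sketch, assuming a 1-D input modelled as a Python list and a body that keeps a running total (`reverse_running_sum` is a hypothetical name):

```python
from typing import List


def reverse_running_sum(x: List[int], output_direction: int) -> List[int]:
    """Scan x in reverse (scan_input_directions = [1]) carrying a running total,
    and build the scan_output by appending (0) or prepending (1)."""
    total = 0
    scan_output: List[int] = []
    for t in range(len(x)):
        total += x[len(x) - 1 - t]      # reverse-direction element for iteration t
        if output_direction == 0:
            scan_output.append(total)
        else:
            scan_output.insert(0, total)
    return scan_output


print(reverse_running_sum([1, 2, 3, 4], 0))  # [4, 7, 9, 10]
print(reverse_running_sum([1, 2, 3, 4], 1))  # [10, 9, 7, 4]
```

With appending, the output follows iteration order; with prepending, entry i holds the sum of x[i:], so the result is aligned with the input positions.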

The optional attribute scan_input_axes specifies the axis to be scanned for each scan_input. If omitted, every scan_input will be scanned along axis 0. For example, if axis 0 is the batch axis and axis 1 is the time axis (to be scanned), specify an axis value of 1. Note that scanning a non-zero axis may be less efficient than scanning axis zero.
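For the batch-major case described above, the per-iteration elements can be sketched as follows: with scan_input_axes = [1], iteration t receives column t of a [batch, time] tensor, which is a [batch]-shaped tensor. Tensors are nested Python lists and `time_major_elements` is a hypothetical helper.

```python
from typing import Any, List


def time_major_elements(x: List[List[Any]]) -> List[List[Any]]:
    """Elements passed to the body for scan_input_axes = [1] on a [batch, time] tensor."""
    batch, time = len(x), len(x[0])
    # Iteration t receives column t, a [batch]-shaped tensor.
    return [[x[b][t] for b in range(batch)] for t in range(time)]


x = [[1, 2, 3],   # batch entry 0 over three time steps
     [4, 5, 6]]   # batch entry 1
print(time_major_elements(x))  # [[1, 4], [2, 5], [3, 6]]
```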

The optional attribute scan_output_axes specifies the axis along which the scan_outputs are accumulated for each scan_output. For example, if axis 1 is the time axis (to be scanned) for both inputs and outputs, specify a value of 1 for both the scan_input axis and the scan_output axis.
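Accumulating along axis 1 instead of axis 0 only changes where the sequence axis ends up in the result. In the sketch below (nested Python lists, hypothetical `stack_along_axis` helper), three [2]-shaped elements stack to shape [3, 2] with scan_output_axes = [0] and to shape [2, 3] with scan_output_axes = [1]; an identity body scanned with both axes set to 1 would therefore reproduce its [batch, time] input.

```python
from typing import Any, List


def stack_along_axis(elements: List[List[Any]], axis: int) -> List[List[Any]]:
    """Stack [batch]-shaped per-iteration elements along axis 0 or 1."""
    if axis == 0:
        return [list(e) for e in elements]                       # shape [time, batch]
    if axis == 1:
        batch = len(elements[0])
        return [[e[b] for e in elements] for b in range(batch)]  # shape [batch, time]
    raise ValueError("this sketch only handles axis 0 or 1 for 1-D elements")


per_step = [[1, 4], [2, 5], [3, 6]]      # three iterations, batch of 2
print(stack_along_axis(per_step, 0))     # [[1, 4], [2, 5], [3, 6]]
print(stack_along_axis(per_step, 1))     # [[1, 2, 3], [4, 5, 6]]
```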

Note that because of the ONNX restriction that only the last parameter of an operator can be variadic, the initial-states and scan-inputs are listed together as one input parameter. Similarly, the final-states and scan-outputs are listed together as one output parameter. The attribute num_scan_inputs indicates the number M of scan-inputs.
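Since only M is given as an attribute, N has to be derived from the number of inputs, and K from the number of outputs once N is known. A sketch of that bookkeeping, with hypothetical function names and strings standing in for tensors:

```python
from typing import Any, List, Sequence, Tuple


def split_inputs(initial_state_and_scan_inputs: Sequence[Any],
                 num_scan_inputs: int) -> Tuple[List[Any], List[Any]]:
    """Split the variadic input into N initial states and M scan inputs."""
    n = len(initial_state_and_scan_inputs) - num_scan_inputs
    if n < 0:
        raise ValueError("num_scan_inputs exceeds the number of inputs")
    return list(initial_state_and_scan_inputs[:n]), list(initial_state_and_scan_inputs[n:])


def split_outputs(final_state_and_scan_outputs: Sequence[Any],
                  num_states: int) -> Tuple[List[Any], List[Any]]:
    """Split the variadic output into N final states and K scan outputs."""
    return (list(final_state_and_scan_outputs[:num_states]),
            list(final_state_and_scan_outputs[num_states:]))


states, scans = split_inputs(["H_0", "X"], num_scan_inputs=1)
print(states, scans)                             # ['H_0'] ['X']
print(split_outputs(["Y_h", "Y"], len(states)))  # (['Y_h'], ['Y'])
```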

The behavior of

    Scan <
        num_scan_inputs = m,
        body = loop-body,
        scan_input_axes = [axis_1, ..., axis_m]
    > (init_1, ..., init_n, scan_1, ..., scan_m)

is equivalent to the following pseudo-code:

    // scan_i.shape[axis_i] denotes the (max) sequence-length of scan_i
    // scan_i.shape[axis_i] is required to be equal to scan_j.shape[axis_j] for all i,j.
    sequence_length = scan_1.shape[axis_1];

    // initialize state-variables
    st_1 = init_1; ... st_n = init_n;
    // initialize scan-output variables: [] denotes an empty tensor
    scan_out_1 = []; ...; scan_out_k = [];
    // identify number of iterations:

    // execute loop
    for (int t = 0; t < sequence_length; ++t) {
        // generate the scan-input elements: the notation T<axis=k>[t] indicates the sub-tensor
        // of rank one less than T obtained by indexing T at position t along axis k.
        si_1 = scan_1<axis=axis_1>[t];
        ... ;
        si_m = scan_m<axis=axis_m>[t];
        // execute loop-body
        st_1, ..., st_n, so_1, ..., so_k = loop-body(st_1, ..., st_n, si_1, ..., si_m)
        // accumulate the scan-output elements
        scan_out_1 = Concat<axis=0>(scan_out_1, so_1); ... ; scan_out_k = Concat<axis=0>(scan_out_k, so_k);
    }

    return st_1, ..., st_n, scan_out_1, ..., scan_out_k;
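A direct Python transcription of the pseudo-code may help when reading it. It is a sketch, not a reference implementation: tensors are nested Python lists, axes are assumed non-negative, every direction is forward, and scan outputs are accumulated along axis 0. `index_along_axis`, `dim`, `scan_reference` and `body` are hypothetical names. Since each so_j has no sequence axis, appending it to a Python list plays the role of `Concat<axis=0>` and yields an output whose leading axis is the sequence axis.

```python
from typing import Any, Callable, List, Sequence, Tuple


def index_along_axis(tensor: Any, axis: int, t: int) -> Any:
    """T<axis=k>[t]: the sub-tensor of rank one less at position t along `axis`."""
    if axis == 0:
        return tensor[t]
    return [index_along_axis(sub, axis - 1, t) for sub in tensor]


def dim(tensor: Any, axis: int) -> int:
    """tensor.shape[axis] for a rectangular nested list."""
    for _ in range(axis):
        tensor = tensor[0]
    return len(tensor)


def scan_reference(loop_body: Callable[..., Sequence[Any]],
                   inits: Sequence[Any],
                   scans: Sequence[Any],
                   scan_input_axes: Sequence[int],
                   num_scan_outputs: int) -> Tuple[Any, ...]:
    # scan_i.shape[axis_i] must be equal for all i.
    lengths = {dim(s, a) for s, a in zip(scans, scan_input_axes)}
    if len(lengths) != 1:
        raise ValueError(f"inconsistent sequence lengths: {sorted(lengths)}")
    sequence_length = lengths.pop()
    n = len(inits)
    states = list(inits)                                    # st_1 ... st_n
    scan_out: List[List[Any]] = [[] for _ in range(num_scan_outputs)]
    for t in range(sequence_length):
        si = [index_along_axis(s, a, t) for s, a in zip(scans, scan_input_axes)]
        results = list(loop_body(*states, *si))
        states = results[:n]                                # updated st_1 ... st_n
        for acc, so in zip(scan_out, results[n:]):
            acc.append(so)                                  # accumulate so_j along axis 0
    return (*states, *scan_out)


# One state (column-wise running max) and one scan output (column sums),
# scanning axis 1 of a [2, 3] input.
def body(running_max, column):
    return [max(a, b) for a, b in zip(running_max, column)], sum(column)


x = [[3, 1, 4],
     [1, 5, 9]]
final_max, column_sums = scan_reference(body, [[float("-inf")] * 2], [x], [1], 1)
print(final_max, column_sums)  # [4, 9] [4, 6, 13]
```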

*Sample usage: Encoding RNN using a Scan*

The following example shows how a simple RNN over an input tensor %X, with weight tensor %Wi, recurrence weight tensor %Ri, bias tensors %Wbi and %Rbi, and initial hidden-state %H_0 can be encoded as a Scan. Note that the loop-body is a nested graph, and it directly computes %Wi, %Ri, %Wbi, and %Rbi (typically constants or initializers in the body graph). If these values are computed in the outer graph, they need to be passed in as extra state_variables.

    graph rnn-encoding {
      %H_0 = ...
      %X = ...
      %Y_h, %Y = Scan[body = <graph rnn-cell-1>, num_scan_inputs=1](%H_0, %X)
      return %Y, %Y_h
    }

    graph rnn-cell-1 (
      %H_tminus1[FLOAT, tensor]
      %X_t[FLOAT, tensor]
    ) {
      %Wi = ...
      %Ri = ...
      %Wbi = ...
      %Rbi = ...
      %t1 = %X_t * (%Wi^T)
      %t2 = %H_tminus1 * (%Ri^T)
      %t3 = Add(%t1, %t2)
      %t4 = Add(%t3, %Wbi)
      %t5 = Add(%t4, %Rbi)
      %Ht = Tanh(%t5)
      %Accumulate = Identity(%Ht)
      return %Ht, %Accumulate
    }
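The same RNN can be sketched in plain Python to show what the Scan computes: the body maps (H_{t-1}, X_t) to H_t = tanh(X_t·Wi^T + H_{t-1}·Ri^T + Wbi + Rbi) and emits H_t twice, once as the new state and once as the scan_output element. Assumptions for this sketch: vectors and matrices are Python lists, %Wi has shape [hidden_size, input_size] and %Ri has shape [hidden_size, hidden_size] (so X_t·Wi^T is computed as a matrix-vector product with Wi), the scan runs forward over axis 0 of X, and all function names are hypothetical.

```python
import math
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]


def matvec(m: Matrix, v: Vector) -> Vector:
    """m @ v, which equals the row vector v times m^T."""
    return [sum(w * x for w, x in zip(row, v)) for row in m]


def rnn_cell(h_prev: Vector, x_t: Vector, Wi: Matrix, Ri: Matrix,
             Wbi: Vector, Rbi: Vector) -> Tuple[Vector, Vector]:
    """Body graph rnn-cell-1: returns (%Ht, %Accumulate)."""
    t1 = matvec(Wi, x_t)        # %X_t * (%Wi^T)
    t2 = matvec(Ri, h_prev)     # %H_tminus1 * (%Ri^T)
    h_t = [math.tanh(a + b + wb + rb) for a, b, wb, rb in zip(t1, t2, Wbi, Rbi)]
    return h_t, list(h_t)       # new state, Identity copy for the scan output


def rnn_encoding(H_0: Vector, X: List[Vector], Wi: Matrix, Ri: Matrix,
                 Wbi: Vector, Rbi: Vector) -> Tuple[List[Vector], Vector]:
    """Outer graph: one state variable (H) and one scan output (Y) over axis 0 of X."""
    h, Y = H_0, []
    for x_t in X:
        h, accumulate = rnn_cell(h, x_t, Wi, Ri, Wbi, Rbi)
        Y.append(accumulate)
    return Y, h                 # mirrors "return %Y, %Y_h"


Wi = [[1.0], [0.5]]             # hidden_size = 2, input_size = 1
Ri = [[0.5, 0.0], [0.0, 0.5]]
zeros = [0.0, 0.0]
Y, Y_h = rnn_encoding(zeros, [[1.0], [0.0]], Wi, Ri, zeros, zeros)
print(len(Y), Y[-1] == Y_h)     # 2 True
```

Y stacks every H_t along the sequence axis, while Y_h is only the final state, which is why the last entry of Y equals Y_h.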

**Attributes**

* **body** (required):
  The graph run each iteration. It has N+M inputs: (loop state
  variables..., scan_input_elts...). It has N+K outputs: (loop state
  variables..., scan_output_elts...). Each scan_output is created by
  concatenating the value of the specified scan_output_elt value at
  the end of each iteration of the loop. It is an error if the
  dimensions of these values change across loop iterations.
* **num_scan_inputs** (required):
  An attribute specifying the number of scan_inputs M.
* **scan_input_axes**:
  An optional list of M flags. The i-th element of the list specifies
  the axis to be scanned (the sequence axis) for the i-th scan_input.
  If omitted, 0 will be used as the scan axis for every scan_input.
  - Negative value for an axis means counting dimensions from the back.
  - Accepted range is [-r, r-1] where r = rank(input).
* **scan_input_directions**:
  An optional list of M flags. The i-th element of the list specifies
  the direction to be scanned for the i-th scan_input tensor: 0
  indicates forward direction and 1 indicates reverse direction. If
  omitted, all scan_input tensors will be scanned in the forward
  direction.
* **scan_output_axes**:
  An optional list of K flags. The i-th element of the list specifies
  the axis for the i-th scan_output. The scan outputs are accumulated
  along the specified axis. If omitted, 0 will be used as the scan
  + axis for every scan_output.
  - axis for every scan_output. Negative value for an axis means
  - counting dimensions from the back. Accepted range is [-r, r-1].
* **scan_output_directions**:
  An optional list of K flags, one for each scan_output. The i-th
  element of the list specifies whether the i-th scan_output should be
  constructed by appending or prepending a new value in each
  iteration: 0 indicates appending and 1 indicates prepending. If
  omitted, all scan_output tensors will be produced by appending a
  value in each iteration.
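The negative-axis rule shown in the attribute diff ("Accepted range is [-r, r-1]") can be sketched as a normalization step. `normalize_axis` is a hypothetical helper, and `rank` stands for r, the rank of the tensor the axis refers to.

```python
def normalize_axis(axis: int, rank: int) -> int:
    """Map an axis in [-rank, rank - 1] to the equivalent non-negative axis."""
    if not -rank <= axis <= rank - 1:
        raise ValueError(f"axis {axis} is outside [{-rank}, {rank - 1}]")
    # A negative axis counts dimensions from the back: -1 is the last axis.
    return axis + rank if axis < 0 else axis


print([normalize_axis(a, 3) for a in (-3, -1, 0, 2)])  # [0, 2, 0, 2]
```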

**Inputs**

Between 1 and 2147483647 inputs.

* **initial_state_and_scan_inputs** (variadic) - **V**:
  Initial values of the loop's N state variables followed by M
  scan_inputs

**Outputs**

Between 1 and 2147483647 outputs.

* **final_state_and_scan_outputs** (variadic) - **V**:
  Final values of the loop's N state variables followed by K
  scan_outputs

**Type Constraints**

* **V** in (
  tensor(bool),
  tensor(complex128),
  tensor(complex64),
  tensor(double),
  tensor(float),
  tensor(float16),
  tensor(int16),
  tensor(int32),
  tensor(int64),
  tensor(int8),
  tensor(string),
  tensor(uint16),
  tensor(uint32),
  tensor(uint64),
  tensor(uint8)
  ):
  All Tensor types