Where
Where - 16
Version
name: Where (GitHub)
domain: main
since_version: 16
function: False
support_level: SupportType.COMMON
shape inference: True
This version of the operator has been available since version 16.
Summary
Return elements, either from X or Y, depending on condition. Where behaves like [numpy.where](https://docs.scipy.org/doc/numpy/reference/generated/numpy.where.html) with three parameters.
This operator supports multidirectional (i.e., Numpy-style) broadcasting; for more details please check Broadcasting in ONNX.
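As a rough illustration of the broadcasting behaviour described above, the sketch below uses numpy.where as a stand-in for the operator, on the basis that the summary says Where behaves like it. The array names, shapes, and values are made up for this example and are not part of the specification.

```python
import numpy as np

# Hypothetical inputs chosen only to show broadcasting; the shapes differ on purpose.
condition = np.array([[True], [False]])          # shape (2, 1)
x = np.array([1.0, 2.0, 3.0], dtype=np.float32)  # shape (3,)
y = np.array(0.0, dtype=np.float32)              # scalar tensor, shape ()

# The three inputs are broadcast against each other to a common shape (2, 3):
# row 0 takes its values from x, row 1 takes its values from y.
z = np.where(condition, x, y)
print(z.shape)
print(z)
```

None of the inputs has the output shape on its own here; each one is stretched along the dimensions where its size is 1 or missing.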
History - Version 16 adds bfloat16 to the allowed types for the second and third inputs (X and Y).
Inputs
condition (heterogeneous) - B: When True (nonzero), yield X, otherwise yield Y
X (heterogeneous) - T: values selected at indices where condition is True
Y (heterogeneous) - T: values selected at indices where condition is False
Outputs
output (heterogeneous) - T: Tensor of shape equal to the broadcasted shape of condition, X, and Y.
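The output shape can be worked out from the three input shapes alone. The following is a minimal sketch of NumPy-style shape broadcasting, written in plain Python; `broadcast_shape` is a hypothetical helper name used for illustration, not an ONNX API.

```python
from itertools import zip_longest


def broadcast_shape(*shapes):
    """Return the NumPy-style broadcast of several shapes, or raise ValueError."""
    result = []
    # Align the shapes on their trailing dimensions; a missing leading
    # dimension is treated as size 1.
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        non_one = {d for d in dims if d != 1}
        if len(non_one) > 1:
            raise ValueError(f"shapes {shapes} are not broadcastable")
        result.append(non_one.pop() if non_one else 1)
    return tuple(reversed(result))


# Shapes of condition, X, and Y respectively (example values).
print(broadcast_shape((2, 1), (3,), ()))  # prints (2, 3)
```

When the three shapes cannot be reconciled, for example (2, 3) against (4,), the helper raises an error instead of returning a shape.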
Type Constraints
B in ( tensor(bool) ): Constrain to boolean tensors.
T in ( tensor(bfloat16), tensor(bool), tensor(complex128), tensor(complex64), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8) ): Constrain input and output types to all tensor types (including bfloat16).
Examples
default
import numpy as np
import onnx
from onnx.backend.test.case.node import expect  # ONNX backend test-suite helper

node = onnx.helper.make_node(
    "Where",
    inputs=["condition", "x", "y"],
    outputs=["z"],
)

# Select from x where condition is True, otherwise from y.
condition = np.array([[1, 0], [1, 1]], dtype=bool)
x = np.array([[1, 2], [3, 4]], dtype=np.float32)
y = np.array([[9, 8], [7, 6]], dtype=np.float32)
z = np.where(condition, x, y)  # expected output [[1, 8], [3, 4]]
expect(node, inputs=[condition, x, y], outputs=[z], name="test_where_example")
_long
import numpy as np
import onnx
from onnx.backend.test.case.node import expect  # ONNX backend test-suite helper

node = onnx.helper.make_node(
    "Where",
    inputs=["condition", "x", "y"],
    outputs=["z"],
)

# Same selection as the default example, with int64 values instead of float32.
condition = np.array([[1, 0], [1, 1]], dtype=bool)
x = np.array([[1, 2], [3, 4]], dtype=np.int64)
y = np.array([[9, 8], [7, 6]], dtype=np.int64)
z = np.where(condition, x, y)  # expected output [[1, 8], [3, 4]]
expect(
    node, inputs=[condition, x, y], outputs=[z], name="test_where_long_example"
)
Where - 9
Version
name: Where (GitHub)
domain: main
since_version: 9
function: False
support_level: SupportType.COMMON
shape inference: True
This version of the operator has been available since version 9.
Summary
Return elements, either from X or Y, depending on condition. Where behaves like [numpy.where](https://docs.scipy.org/doc/numpy/reference/generated/numpy.where.html) with three parameters.
This operator supports multidirectional (i.e., Numpy-style) broadcasting; for more details please check Broadcasting in ONNX.
Inputs
condition (heterogeneous) - B: When True (nonzero), yield X, otherwise yield Y
X (heterogeneous) - T: values selected at indices where condition is True
Y (heterogeneous) - T: values selected at indices where condition is False
Outputs
output (heterogeneous) - T: Tensor of shape equal to the broadcasted shape of condition, X, and Y.
Type Constraints
B in ( tensor(bool) ): Constrain to boolean tensors.
T in ( tensor(bool), tensor(complex128), tensor(complex64), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8) ): Constrain input and output types to all tensor types.
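Comparing the two constraint lists, the only difference is that Where-16 also accepts bfloat16 for T. The sketch below records that as a small lookup. `allowed_t_types` is a hypothetical helper, not part of the onnx package, and it assumes that the two versions documented on this page are the only ones.

```python
# Element types allowed for T in Where-9, copied from the constraint list above.
T_TYPES_V9 = frozenset({
    "bool", "complex128", "complex64", "double", "float", "float16",
    "int16", "int32", "int64", "int8", "string",
    "uint16", "uint32", "uint64", "uint8",
})

# Where-16 allows the same element types plus bfloat16.
T_TYPES_V16 = T_TYPES_V9 | {"bfloat16"}


def allowed_t_types(opset_version):
    """Hypothetical helper: element types accepted for X and Y at a given opset."""
    if opset_version >= 16:
        return T_TYPES_V16
    if opset_version >= 9:
        return T_TYPES_V9
    raise ValueError("Where is not available before opset 9")


print("bfloat16" in allowed_t_types(16))  # prints True
print("bfloat16" in allowed_t_types(9))   # prints False
```

The condition input is unaffected: in both versions B is constrained to tensor(bool).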