Where

Where - 16

Version

  • name: Where (GitHub)

  • domain: main

  • since_version: 16

  • function: False

  • support_level: SupportType.COMMON

  • shape inference: True

This version of the operator has been available since version 16.

Summary

Returns elements from either X or Y, depending on condition. Where behaves like [numpy.where](https://docs.scipy.org/doc/numpy/reference/generated/numpy.where.html) called with three arguments.

This operator supports multidirectional (i.e., Numpy-style) broadcasting; for more details please check Broadcasting in ONNX.
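As a minimal sketch of the broadcasting behaviour, the snippet below uses numpy.where as a stand-in for the operator. This is an assumption that rests on the summary's statement that the two behave alike; the shapes and values are illustrative only and are not taken from the ONNX test suite.

```python
import numpy as np

# Illustrative operands: condition is (3, 1), x is (1, 4), y is a scalar.
condition = np.array([[True], [False], [True]])
x = np.arange(4, dtype=np.float32).reshape(1, 4)
y = np.float32(-1.0)

# All three operands are broadcast against each other to shape (3, 4).
# Rows where condition is True take the values of x; the other row takes y.
z = np.where(condition, x, y)
print(z.shape)
print(z)
```

None of the three operands has the final shape on its own; the result shape comes from combining all of them.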

History: Version 16 adds bfloat16 to the allowed types for the second and third inputs (X and Y).

Inputs

  • condition (heterogeneous) - B: When True (nonzero), yield X, otherwise yield Y

  • X (heterogeneous) - T: values selected at indices where condition is True

  • Y (heterogeneous) - T: values selected at indices where condition is False

Outputs

  • output (heterogeneous) - T: Tensor of shape equal to the broadcasted shape of condition, X, and Y.
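The output shape can be worked out from the three input shapes alone. The helper below is a hypothetical illustration (the name `broadcast_shape` is not part of ONNX) that assumes the usual NumPy-style rule: shapes are aligned on their trailing dimensions, and a dimension of size 1 stretches to match the others.

```python
from itertools import zip_longest


def broadcast_shape(*shapes):
    """Hypothetical helper: NumPy-style broadcast of several shapes."""
    result = []
    # Walk the shapes from the last dimension backwards; a shape that
    # runs out of dimensions is treated as having size 1 there.
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"incompatible dimensions: {dims}")
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))


# Shapes of condition, X, and Y (illustrative values).
output_shape = broadcast_shape((3, 1), (1, 4), ())
print(output_shape)
```

If two inputs disagree on a dimension and neither has size 1 there, the helper raises an error instead of returning a shape.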

Type Constraints

  • B in ( tensor(bool) ): Constrain to boolean tensors.

  • T in ( tensor(bfloat16), tensor(bool), tensor(complex128), tensor(complex64), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8) ): Constrain input and output types to all tensor types (including bfloat16).
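Because X, Y, and the output are all bound to the same type T, the result keeps the element type of the inputs. The sketch below illustrates this with numpy.where standing in for the operator (an assumption, as in the summary); the NumPy string arrays are only an approximation of tensor(string).

```python
import numpy as np

condition = np.array([True, False, True])

# X and Y share one element type; the result has that same type.
x_int = np.array([1, 2, 3], dtype=np.int64)
y_int = np.array([10, 20, 30], dtype=np.int64)
z_int = np.where(condition, x_int, y_int)
print(z_int, z_int.dtype)

# Strings are also part of T; NumPy unicode arrays stand in for them here.
x_str = np.array(["a", "b", "c"])
y_str = np.array(["x", "y", "z"])
z_str = np.where(condition, x_str, y_str)
print(z_str)
```

One difference to keep in mind: NumPy promotes mismatched X and Y dtypes silently, whereas the constraint above ties both inputs to a single T, so a model would need to make the types match first (for example with a Cast node).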

Examples

default

import numpy as np
import onnx
from onnx.backend.test.case.node import expect  # ONNX backend test helper

node = onnx.helper.make_node(
    "Where",
    inputs=["condition", "x", "y"],
    outputs=["z"],
)

condition = np.array([[1, 0], [1, 1]], dtype=bool)
x = np.array([[1, 2], [3, 4]], dtype=np.float32)
y = np.array([[9, 8], [7, 6]], dtype=np.float32)
z = np.where(condition, x, y)  # expected output [[1, 8], [3, 4]]
expect(node, inputs=[condition, x, y], outputs=[z], name="test_where_example")

_long

import numpy as np
import onnx
from onnx.backend.test.case.node import expect  # ONNX backend test helper

node = onnx.helper.make_node(
    "Where",
    inputs=["condition", "x", "y"],
    outputs=["z"],
)

condition = np.array([[1, 0], [1, 1]], dtype=bool)
x = np.array([[1, 2], [3, 4]], dtype=np.int64)
y = np.array([[9, 8], [7, 6]], dtype=np.int64)
z = np.where(condition, x, y)  # expected output [[1, 8], [3, 4]]
expect(
    node, inputs=[condition, x, y], outputs=[z], name="test_where_long_example"
)

Where - 9

Version

  • name: Where (GitHub)

  • domain: main

  • since_version: 9

  • function: False

  • support_level: SupportType.COMMON

  • shape inference: True

This version of the operator has been available since version 9.

Summary

Returns elements from either X or Y, depending on condition. Where behaves like [numpy.where](https://docs.scipy.org/doc/numpy/reference/generated/numpy.where.html) called with three arguments.

This operator supports multidirectional (i.e., Numpy-style) broadcasting; for more details please check Broadcasting in ONNX.

Inputs

  • condition (heterogeneous) - B: When True (nonzero), yield X, otherwise yield Y

  • X (heterogeneous) - T: values selected at indices where condition is True

  • Y (heterogeneous) - T: values selected at indices where condition is False

Outputs

  • output (heterogeneous) - T: Tensor of shape equal to the broadcasted shape of condition, X, and Y.

Type Constraints

  • B in ( tensor(bool) ): Constrain to boolean tensors.

  • T in ( tensor(bool), tensor(complex128), tensor(complex64), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8) ): Constrain input and output types to all tensor types.
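For a quick comparison of the two versions, the type lists for T can be diffed directly. The sets below are copied from the Type Constraints sections on this page; the variable names are illustrative.

```python
# Element types allowed for T in version 9.
T_V9 = {
    "bool", "complex128", "complex64", "double", "float", "float16",
    "int16", "int32", "int64", "int8", "string",
    "uint16", "uint32", "uint64", "uint8",
}

# Element types allowed for T in version 16.
T_V16 = {
    "bfloat16", "bool", "complex128", "complex64", "double", "float",
    "float16", "int16", "int32", "int64", "int8", "string",
    "uint16", "uint32", "uint64", "uint8",
}

added = sorted(T_V16 - T_V9)
removed = sorted(T_V9 - T_V16)
print("added in 16:", added)
print("removed in 16:", removed)
```

The constraint on B (tensor(bool)) is the same in both versions.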