JAX model#
This page answers common “how do I…” questions about converting a
JAX model to ONNX with yobx.tensorflow.to_onnx().
How to convert a JAX model#
Write the model as a callable and pass a representative dummy input to
yobx.tensorflow.to_onnx():
<<<
import jax
import numpy as np
from yobx.tensorflow import to_onnx
# Fixed PRNG seed so the example always draws the same weights.
key = jax.random.PRNGKey(42)
k1, k2 = jax.random.split(key)

# Parameters of a two-layer MLP: 8 inputs -> 16 hidden units -> 4 outputs.
W1 = jax.random.normal(k1, (8, 16), dtype=np.float32)
b1 = np.zeros(16, dtype=np.float32)
W2 = jax.random.normal(k2, (16, 4), dtype=np.float32)
b2 = np.zeros(4, dtype=np.float32)


def jax_mlp(x):
    """Apply the two-layer MLP (ReLU on the hidden layer) to ``x``."""
    h = jax.nn.relu(x @ W1 + b1)
    return h @ W2 + b2


# Representative dummy input: 10 rows of 8 float32 features.
X = np.random.default_rng(0).standard_normal((10, 8)).astype(np.float32)

# Convert the callable, passing the dummy input as a one-element tuple.
onx = to_onnx(jax_mlp, (X,))
print(f"nodes={len(onx.graph.node)}")
>>>
/home/runner/work/xadupre.github.io/xadupre.github.io/yet-another-onnx-builder/yobx/tensorflow/tensorflow_helper.py:123: DeprecationWarning: The `native_serialization` parameter is deprecated and will be removed in a future version of JAX.
tf_fn = jax2tf.convert(
/home/runner/work/xadupre.github.io/xadupre.github.io/yet-another-onnx-builder/yobx/tensorflow/tensorflow_helper.py:123: DeprecationWarning: The `enable_xla` parameter is deprecated and will be removed in a future version of JAX.
tf_fn = jax2tf.convert(
nodes=3
How to verify ONNX outputs against JAX#
Run the ONNX model with onnxruntime and compare against JAX outputs:
<<<
import jax
import numpy as np
import onnxruntime
from yobx.tensorflow import to_onnx
# Fixed PRNG seed so the example always draws the same weights.
key = jax.random.PRNGKey(42)
k1, k2 = jax.random.split(key)

# Parameters of a two-layer MLP: 8 inputs -> 16 hidden units -> 4 outputs.
W1 = jax.random.normal(k1, (8, 16), dtype=np.float32)
b1 = np.zeros(16, dtype=np.float32)
W2 = jax.random.normal(k2, (16, 4), dtype=np.float32)
b2 = np.zeros(4, dtype=np.float32)


def jax_mlp(x):
    """Apply the two-layer MLP (ReLU on the hidden layer) to ``x``."""
    h = jax.nn.relu(x @ W1 + b1)
    return h @ W2 + b2


# Representative dummy input: 10 rows of 8 float32 features.
X = np.random.default_rng(0).standard_normal((10, 8)).astype(np.float32)
onx = to_onnx(jax_mlp, (X,))

# Load the serialized model into onnxruntime, restricted to the CPU provider.
sess = onnxruntime.InferenceSession(
    onx.SerializeToString(), providers=["CPUExecutionProvider"]
)

# Feed X under the name of the model's first input; the unpacking below
# expects exactly one output.
input_name = sess.get_inputs()[0].name
(got,) = sess.run(None, {input_name: X})

# Reference result computed directly with JAX, compared within atol=1e-2.
expected = np.asarray(jax_mlp(X))
np.testing.assert_allclose(expected, got, atol=1e-2)
print("Outputs match ✓")
>>>
/home/runner/work/xadupre.github.io/xadupre.github.io/yet-another-onnx-builder/yobx/tensorflow/tensorflow_helper.py:123: DeprecationWarning: The `native_serialization` parameter is deprecated and will be removed in a future version of JAX.
tf_fn = jax2tf.convert(
/home/runner/work/xadupre.github.io/xadupre.github.io/yet-another-onnx-builder/yobx/tensorflow/tensorflow_helper.py:123: DeprecationWarning: The `enable_xla` parameter is deprecated and will be removed in a future version of JAX.
tf_fn = jax2tf.convert(
Outputs match ✓
How to export with a dynamic batch dimension#
Use dynamic_shapes to name the batch dimension:
<<<
import jax
import numpy as np
from yobx.tensorflow import to_onnx
# Fixed PRNG seed so the example always draws the same weights.
key = jax.random.PRNGKey(42)
k1, k2 = jax.random.split(key)

# Parameters of a two-layer MLP: 8 inputs -> 16 hidden units -> 4 outputs.
W1 = jax.random.normal(k1, (8, 16), dtype=np.float32)
b1 = np.zeros(16, dtype=np.float32)
W2 = jax.random.normal(k2, (16, 4), dtype=np.float32)
b2 = np.zeros(4, dtype=np.float32)


def jax_mlp(x):
    """Apply the two-layer MLP (ReLU on the hidden layer) to ``x``."""
    h = jax.nn.relu(x @ W1 + b1)
    return h @ W2 + b2


# Representative dummy input: 10 rows of 8 float32 features.
X = np.random.default_rng(0).standard_normal((10, 8)).astype(np.float32)

# One dynamic_shapes entry per input; {0: "batch"} marks axis 0 as dynamic.
onx_dyn = to_onnx(jax_mlp, (X,), dynamic_shapes=({0: "batch"},))

# Inspect the first dimension of the first graph input. A non-empty
# dim_param means the dimension is symbolic rather than a fixed size.
batch_dim = onx_dyn.graph.input[0].type.tensor_type.shape.dim[0]
assert batch_dim.dim_param
print("dynamic batch dimension:", batch_dim.dim_param)
>>>
/home/runner/work/xadupre.github.io/xadupre.github.io/yet-another-onnx-builder/yobx/tensorflow/tensorflow_helper.py:123: DeprecationWarning: The `native_serialization` parameter is deprecated and will be removed in a future version of JAX.
tf_fn = jax2tf.convert(
/home/runner/work/xadupre.github.io/xadupre.github.io/yet-another-onnx-builder/yobx/tensorflow/tensorflow_helper.py:123: DeprecationWarning: The `enable_xla` parameter is deprecated and will be removed in a future version of JAX.
tf_fn = jax2tf.convert(
dynamic batch dimension: dim
How to use jax_to_concrete_function explicitly#
If needed, convert the JAX function to a TensorFlow concrete function first, then pass that concrete function to to_onnx():
<<<
import jax
import numpy as np
from yobx.tensorflow import to_onnx
from yobx.tensorflow.tensorflow_helper import jax_to_concrete_function
# Fixed PRNG seed so the example always draws the same weights.
key = jax.random.PRNGKey(42)
k1, k2 = jax.random.split(key)

# Parameters of a two-layer MLP: 8 inputs -> 16 hidden units -> 4 outputs.
W1 = jax.random.normal(k1, (8, 16), dtype=np.float32)
b1 = np.zeros(16, dtype=np.float32)
W2 = jax.random.normal(k2, (16, 4), dtype=np.float32)
b2 = np.zeros(4, dtype=np.float32)


def jax_mlp(x):
    """Apply the two-layer MLP (ReLU on the hidden layer) to ``x``."""
    h = jax.nn.relu(x @ W1 + b1)
    return h @ W2 + b2


# Representative dummy input: 10 rows of 8 float32 features.
X = np.random.default_rng(0).standard_normal((10, 8)).astype(np.float32)

# Define the dynamic-shape spec once so both steps receive the same value.
dynamic_shapes = ({0: "batch"},)

# Step 1: turn the JAX callable into a TensorFlow concrete function.
cf = jax_to_concrete_function(jax_mlp, (X,), dynamic_shapes=dynamic_shapes)

# Step 2: convert the concrete function to ONNX.
onx = to_onnx(cf, (X,), dynamic_shapes=dynamic_shapes)
print(f"nodes={len(onx.graph.node)}")
>>>
/home/runner/work/xadupre.github.io/xadupre.github.io/yet-another-onnx-builder/yobx/tensorflow/tensorflow_helper.py:123: DeprecationWarning: The `native_serialization` parameter is deprecated and will be removed in a future version of JAX.
tf_fn = jax2tf.convert(
/home/runner/work/xadupre.github.io/xadupre.github.io/yet-another-onnx-builder/yobx/tensorflow/tensorflow_helper.py:123: DeprecationWarning: The `enable_xla` parameter is deprecated and will be removed in a future version of JAX.
tf_fn = jax2tf.convert(
nodes=3
See also
Converting a JAX function to ONNX — full runnable JAX to ONNX gallery example.