.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples_tensorflow/plot_jax_to_onnx.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_tensorflow_plot_jax_to_onnx.py:


.. _l-plot-jax-to-onnx:

Converting a JAX function to ONNX
==================================

:func:`yobx.tensorflow.to_onnx` can also convert :epkg:`JAX` functions to
ONNX. Under the hood it uses :func:`jax.experimental.jax2tf.convert` to lower
the JAX computation to a :class:`tensorflow.ConcreteFunction` and then applies
the same TF→ONNX conversion pipeline used for Keras models.

Alternatively, :func:`yobx.tensorflow.tensorflow_helper.jax_to_concrete_function`
can be called explicitly to obtain the intermediate
:class:`~tensorflow.ConcreteFunction` before passing it to
:func:`~yobx.tensorflow.to_onnx`.

The workflow is:

1. **Write** a plain JAX function (or wrap a :mod:`flax`/:mod:`equinox`
   model in a function).
2. **Convert** it by calling :func:`yobx.tensorflow.to_onnx` with a
   representative *dummy input*. The converter detects that the callable is
   a JAX function and automatically routes it through
   :func:`~yobx.tensorflow.tensorflow_helper.jax_to_concrete_function`.
3. **Run** the exported ONNX model with any ONNX runtime. This example uses
   :epkg:`onnxruntime`.
4. **Verify** that the ONNX outputs match JAX's own outputs.

.. GENERATED FROM PYTHON SOURCE LINES 30-40

.. code-block:: Python


    import jax
    import jax.numpy as jnp
    import numpy as np
    import onnxruntime
    from yobx.doc import plot_dot
    from yobx.helpers import max_diff
    from yobx.helpers.onnx_helper import pretty_onnx
    from yobx.tensorflow import to_onnx
    from yobx.tensorflow.tensorflow_helper import jax_to_concrete_function

.. GENERATED FROM PYTHON SOURCE LINES 41-47

1. Simple element-wise function
--------------------------------

We start with the simplest possible JAX function: an element-wise ``sin``
applied to a float32 matrix. :func:`to_onnx` detects that the callable is a
JAX function and converts it transparently. The resulting graph contains a
single ``Sin`` node.

.. GENERATED FROM PYTHON SOURCE LINES 47-62

.. code-block:: Python


    rng = np.random.default_rng(0)
    X = rng.standard_normal((5, 4)).astype(np.float32)


    def jax_sin(x):
        return jnp.sin(x)


    onx_sin = to_onnx(jax_sin, (X,))
    print("Opset :", onx_sin.opset_import[0].version)
    print("Number of nodes :", len(onx_sin.graph.node))
    print("Node op-types :", [n.op_type for n in onx_sin.graph.node])

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Opset : 21
    Number of nodes : 1
    Node op-types : ['Sin']

.. GENERATED FROM PYTHON SOURCE LINES 63-67

Run and compare
~~~~~~~~~~~~~~~~

Verify that the ONNX model reproduces the JAX output.

.. GENERATED FROM PYTHON SOURCE LINES 67-80

.. code-block:: Python


    ref_sin = onnxruntime.InferenceSession(
        onx_sin.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    input_name = ref_sin.get_inputs()[0].name
    (result_sin,) = ref_sin.run(None, {input_name: X})

    expected_sin = np.asarray(jax_sin(X))
    print("\nJAX output (first row):", expected_sin[0])
    print("ONNX output (first row):", result_sin[0])
    assert np.allclose(expected_sin, result_sin, atol=1e-5), "Mismatch!"
    print("Outputs match ✓ - ", max_diff(expected_sin, result_sin))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none


    JAX output (first row): [ 0.12539922 -0.13172095 0.59753436 0.10470783]
    ONNX output (first row): [ 0.12539922 -0.13172095 0.5975344 0.10470783]
    Outputs match ✓ - {'abs': 5.960464477539063e-08, 'rel': 1.1654808964571742e-07, 'sum': 1.7881393432617188e-07, 'n': 20.0, 'dnan': 0.0, 'argm': (0, 2)}

.. GENERATED FROM PYTHON SOURCE LINES 81-86

2. Multi-layer MLP in JAX
--------------------------

A slightly more complex function: a two-layer MLP with a ReLU activation
between the two layers. The weight matrices (JAX arrays) and the biases
(NumPy arrays) are captured in a closure rather than passed as inputs.

.. GENERATED FROM PYTHON SOURCE LINES 86-108

.. code-block:: Python


    key = jax.random.PRNGKey(42)
    k1, k2 = jax.random.split(key)
    W1 = jax.random.normal(k1, (8, 16), dtype=np.float32)
    b1 = np.zeros(16, dtype=np.float32)
    W2 = jax.random.normal(k2, (16, 4), dtype=np.float32)
    b2 = np.zeros(4, dtype=np.float32)


    def jax_mlp(x):
        h = jax.nn.relu(x @ W1 + b1)
        return h @ W2 + b2


    X_mlp = rng.standard_normal((10, 8)).astype(np.float32)
    onx_mlp = to_onnx(jax_mlp, (X_mlp,))

    op_types = [n.op_type for n in onx_mlp.graph.node]
    print("\nOp-types in the MLP graph:", op_types)
    assert "MatMul" in op_types

.. rst-class:: sphx-glr-script-out

 .. code-block:: none


    Op-types in the MLP graph: ['MatMul', 'Relu', 'MatMul']

.. GENERATED FROM PYTHON SOURCE LINES 109-110

Display the model. The two weight matrices are stored as initializers.
Both biases are all zeros, and no ``Add`` node appears in the graph.

.. GENERATED FROM PYTHON SOURCE LINES 110-112

.. code-block:: Python


    print(pretty_onnx(onx_mlp))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    opset: domain='' version=21
    input: name='X:0' type=dtype('float32') shape=['dim', 8]
    init: name='%6' type=float32 shape=(8, 16)
    init: name='%8' type=float32 shape=(16, 4)
    MatMul(X:0, %6) -> _onx_matmul_jax2tf_arg_0:0
    Relu(_onx_matmul_jax2tf_arg_0:0) -> _onx_max_add_matmul_jax2tf_arg_0:0
    MatMul(_onx_max_add_matmul_jax2tf_arg_0:0, %8) -> Identity:0
    output: name='Identity:0' type='NOTENSOR' shape=None

.. GENERATED FROM PYTHON SOURCE LINES 113-114

Verify the predictions of the ONNX model on the batch ``X_mlp`` used for
the conversion.

.. GENERATED FROM PYTHON SOURCE LINES 114-125

.. code-block:: Python


    ref_mlp = onnxruntime.InferenceSession(
        onx_mlp.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    input_name_mlp = ref_mlp.get_inputs()[0].name
    (result_mlp,) = ref_mlp.run(None, {input_name_mlp: X_mlp})

    expected_mlp = np.asarray(jax_mlp(X_mlp))
    np.testing.assert_allclose(expected_mlp, result_mlp, atol=1e-2)
    print("MLP predictions match ✓ - ", max_diff(expected_mlp, result_mlp))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    MLP predictions match ✓ - {'abs': 1.9073486328125e-06, 'rel': 1.1398523971026608e-05, 'sum': 1.7836689949035645e-05, 'n': 40.0, 'dnan': 0.0, 'argm': (2, 3)}

.. GENERATED FROM PYTHON SOURCE LINES 126-131

3. Dynamic batch dimension
---------------------------

By default :func:`to_onnx` marks axis 0 as a dynamic (symbolic) batch
dimension. The ``shape=['dim', 8]`` input of the MLP graph above shows this.
The code below also requests the dimension explicitly through
``dynamic_shapes``. It then checks the converted model against JAX for batch
sizes 1, 7 and 20. In this run, the printed ``dim_param`` is ``dim`` rather
than the requested name ``batch``. The assertion only checks that the
dimension is symbolic.

.. GENERATED FROM PYTHON SOURCE LINES 131-150

.. code-block:: Python


    onx_dyn = to_onnx(jax_mlp, (X_mlp,), dynamic_shapes=({0: "batch"},))

    input_shape = onx_dyn.graph.input[0].type.tensor_type.shape
    batch_dim = input_shape.dim[0]
    print("\nBatch dimension param :", batch_dim.dim_param)
    assert batch_dim.dim_param, "Expected a named dynamic dimension"

    ref_dyn = onnxruntime.InferenceSession(
        onx_dyn.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    input_name_dyn = ref_dyn.get_inputs()[0].name

    for n in (1, 7, 20):
        X_batch = rng.standard_normal((n, 8)).astype(np.float32)
        (out,) = ref_dyn.run(None, {input_name_dyn: X_batch})
        expected = np.asarray(jax_mlp(X_batch))
        np.testing.assert_allclose(expected, out, atol=1e-2)
        print(f"Dynamic-batch model verified for batch sizes {n} ✓ - ", max_diff(expected, out))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none


    Batch dimension param : dim
    Dynamic-batch model verified for batch sizes 1 ✓ - {'abs': 0.0, 'rel': 0.0, 'sum': 0.0, 'n': 4.0, 'dnan': 0.0, 'argm': (0, 0)}
    Dynamic-batch model verified for batch sizes 7 ✓ - {'abs': 1.9073486328125e-06, 'rel': 1.142880618713509e-05, 'sum': 1.0987743735313416e-05, 'n': 28.0, 'dnan': 0.0, 'argm': (1, 2)}
    Dynamic-batch model verified for batch sizes 20 ✓ - {'abs': 1.9073486328125e-06, 'rel': 4.7222896414374055e-05, 'sum': 3.582797944545746e-05, 'n': 80.0, 'dnan': 0.0, 'argm': (7, 2)}

.. GENERATED FROM PYTHON SOURCE LINES 151-157

4. Explicit jax_to_concrete_function
---------------------------------------

:func:`~yobx.tensorflow.tensorflow_helper.jax_to_concrete_function` can be
called directly when you want to inspect or reuse the intermediate
:class:`~tensorflow.ConcreteFunction` before exporting it to ONNX. Here it
converts a row-wise softmax. The same ``dynamic_shapes`` value is passed to
both calls.

.. GENERATED FROM PYTHON SOURCE LINES 157-178

.. code-block:: Python


    def jax_softmax(x):
        return jax.nn.softmax(x, axis=-1)


    X_cls = rng.standard_normal((6, 10)).astype(np.float32)

    cf = jax_to_concrete_function(jax_softmax, (X_cls,), dynamic_shapes=({0: "batch"},))
    onx_cls = to_onnx(cf, (X_cls,), dynamic_shapes=({0: "batch"},))

    ref_cls = onnxruntime.InferenceSession(
        onx_cls.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    input_name_cls = ref_cls.get_inputs()[0].name
    (result_cls,) = ref_cls.run(None, {input_name_cls: X_cls})

    expected_cls = np.asarray(jax_softmax(X_cls))
    assert np.allclose(expected_cls, result_cls, atol=1e-5), "Softmax mismatch!"
    print("Explicit jax_to_concrete_function verified ✓ - ", max_diff(expected_cls, result_cls))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Explicit jax_to_concrete_function verified ✓ - {'abs': 2.9802322387695312e-08, 'rel': 1.7332412235332218e-07, 'sum': 4.4051557779312134e-07, 'n': 60.0, 'dnan': 0.0, 'argm': (2, 7)}

.. GENERATED FROM PYTHON SOURCE LINES 179-182

5. Visualize the ONNX graph
----------------------------

:func:`~yobx.doc.plot_dot` renders the graph of the MLP model converted in
section 2.

.. GENERATED FROM PYTHON SOURCE LINES 182-183

.. code-block:: Python

    plot_dot(onx_mlp)

.. image-sg:: /auto_examples_tensorflow/images/sphx_glr_plot_jax_to_onnx_001.png
   :alt: plot jax to onnx
   :srcset: /auto_examples_tensorflow/images/sphx_glr_plot_jax_to_onnx_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 6.067 seconds)


.. _sphx_glr_download_auto_examples_tensorflow_plot_jax_to_onnx.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_jax_to_onnx.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_jax_to_onnx.py `

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_jax_to_onnx.zip `


.. include:: plot_jax_to_onnx.recommendations


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_