Xop API#
Most converting libraries use onnx to create ONNX graphs. The API is quite verbose and that is why most of them implement a second API wrapping the first one. These APIs are not necessarily meant to be used by users to create ONNX graphs as they are specialized for the training framework they are developed for.
The API described below is similar to the one implemented in sklearn-onnx but does not depend on it. It could easily be moved to a separate package. Xop is the contraction of ONNX Operators.
Short Example#
Let's say we need to create a graph computing the square loss between two float tensors X and Y.
<<<
import numpy
from numpy.testing import assert_almost_equal
from mlprodict.plotting.text_plot import onnx_simple_text_plot
from mlprodict.onnxrt import OnnxInference
from mlprodict.npy.xop import loadop
# This line creates one class for each of the operators Sub and Mul.
# It fails if the operators are misspelled.
OnnxSub, OnnxMul = loadop('Sub', 'Mul')
# Inputs are defined by their name as strings.
diff = OnnxSub('X', 'Y')
error = OnnxMul(diff, diff)
# Then we create the ONNX graph defining 'X' and 'Y' as float.
onx = error.to_onnx(numpy.float32, numpy.float32)
# We check it does what it should.
X = numpy.array([4, 5], dtype=numpy.float32)
Y = numpy.array([4.3, 5.7], dtype=numpy.float32)
sess = OnnxInference(onx)
name = sess.output_names
result = sess.run({'X': X, 'Y': Y})
assert_almost_equal((X - Y) ** 2, result[name[0]])
# Finally, we show the content of the graph.
print(onnx_simple_text_plot(onx))
>>>
opset: domain='' version=14
input: name='X' type=dtype('float32') shape=()
input: name='Y' type=dtype('float32') shape=()
Sub(X, Y) -> out_sub_0
Mul(out_sub_0, out_sub_0) -> out_mul_0
output: name='out_mul_0' type=dtype('float32') shape=()
In the previous example, a string such as 'X' refers to an input of the graph. Every class Onnx* such as OnnxSub or OnnxMul follows the signature implied by the ONNX specifications (ONNX Operators). The API supports the operators listed in Supported ONNX operators.
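As a quick sketch of the name check mentioned above (output not shown), asking loadop for a misspelled operator fails instead of returning a class; the exact exception type is not specified here, so it is caught broadly.
<<<
from mlprodict.npy.xop import loadop
# 'Subb' is not a valid ONNX operator name: loadop is expected
# to raise an error (the exact exception type may vary).
try:
    loadop('Subb')
except Exception as e:
    print(type(e).__name__, e)
>>>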
Initializers#
Every numpy array defined as an input of an operator is automatically converted into an initializer.
<<<
import numpy
from numpy.testing import assert_almost_equal
from mlprodict.plotting.text_plot import onnx_simple_text_plot
from mlprodict.onnxrt import OnnxInference
from mlprodict.npy.xop import loadop
OnnxSub = loadop('Sub')
# 'X' is an input, the second argument is a constant
# stored as an initializer in the graph.
diff = OnnxSub('X', numpy.array([1], dtype=numpy.float32))
# Then we create the ONNX graph defining 'X' as float.
onx = diff.to_onnx(numpy.float32, numpy.float32)
# We check it does what it should.
X = numpy.array([4, 5], dtype=numpy.float32)
sess = OnnxInference(onx)
name = sess.output_names
result = sess.run({'X': X})
assert_almost_equal(X - 1, result[name[0]])
# Finally, we show the content of the graph.
print(onnx_simple_text_plot(onx))
>>>
opset: domain='' version=14
input: name='X' type=dtype('float32') shape=()
init: name='init' type=dtype('float32') shape=(0,) -- array([1.], dtype=float32)
Sub(X, init) -> out_sub_0
output: name='out_sub_0' type=dtype('float32') shape=()
There are as many initializers as there are numpy arrays defined in the graph.
<<<
import numpy
from numpy.testing import assert_almost_equal
from mlprodict.plotting.text_plot import onnx_simple_text_plot
from mlprodict.onnxrt import OnnxInference
from mlprodict.npy.xop import loadop
OnnxSub = loadop('Sub')
diff = OnnxSub('X', numpy.array([1], dtype=numpy.float32))
diff2 = OnnxSub(diff, numpy.array([2], dtype=numpy.float32))
onx = diff2.to_onnx(numpy.float32, numpy.float32)
print(onnx_simple_text_plot(onx))
>>>
opset: domain='' version=14
input: name='X' type=dtype('float32') shape=()
init: name='init' type=dtype('float32') shape=(0,) -- array([1.], dtype=float32)
init: name='init_1' type=dtype('float32') shape=(0,) -- array([2.], dtype=float32)
Sub(X, init) -> out_sub_0
Sub(out_sub_0, init_1) -> out_sub_0_1
output: name='out_sub_0_1' type=dtype('float32') shape=()
However, the conversion into ONNX then applies function onnx_optimisations to remove duplicated initializers. It also removes unnecessary nodes such as Identity nodes or unused nodes.
<<<
import numpy
from numpy.testing import assert_almost_equal
from mlprodict.plotting.text_plot import onnx_simple_text_plot
from mlprodict.onnxrt import OnnxInference
from mlprodict.npy.xop import loadop
OnnxSub = loadop('Sub')
diff = OnnxSub('X', numpy.array([1], dtype=numpy.float32))
diff2 = OnnxSub(diff, numpy.array([1], dtype=numpy.float32))
onx = diff2.to_onnx(numpy.float32, numpy.float32)
print(onnx_simple_text_plot(onx))
>>>
opset: domain='' version=14
input: name='X' type=dtype('float32') shape=()
init: name='init' type=dtype('float32') shape=(0,) -- array([1.], dtype=float32)
Sub(X, init) -> out_sub_0
Sub(out_sub_0, init) -> out_sub_0_1
output: name='out_sub_0_1' type=dtype('float32') shape=()
Attributes#
Some operators need attributes, such as operator Transpose. They are defined as named arguments.
<<<
import numpy
from numpy.testing import assert_almost_equal
from mlprodict.plotting.text_plot import onnx_simple_text_plot
from mlprodict.onnxrt import OnnxInference
from mlprodict.npy.xop import loadop
OnnxMatMul, OnnxTranspose = loadop('MatMul', 'Transpose')
# Named attribute perm defines the permutation.
result = OnnxMatMul('X', OnnxTranspose('X', perm=[1, 0]))
onx = result.to_onnx(numpy.float32, numpy.float32)
print(onnx_simple_text_plot(onx))
# discrepancies?
X = numpy.array([[4, 5]], dtype=numpy.float32)
sess = OnnxInference(onx)
name = sess.output_names
result = sess.run({'X': X.copy()})
assert_almost_equal(X @ X.T, result[name[0]])
>>>
opset: domain='' version=13
input: name='X' type=dtype('float32') shape=()
Transpose(X, perm=[1,0]) -> out_tra_0
MatMul(X, out_tra_0) -> out_mat_0
output: name='out_mat_0' type=dtype('float32') shape=()
Operator Cast converts every element of an array into another type. ONNX types and numpy types are different, but the API is able to replace one with the corresponding type.
<<<
import numpy
from numpy.testing import assert_almost_equal
from mlprodict.plotting.text_plot import onnx_simple_text_plot
from mlprodict.onnxrt import OnnxInference
from mlprodict.npy.xop import loadop
OnnxCast = loadop('Cast')
result = OnnxCast('X', to=numpy.int64)
onx = result.to_onnx(numpy.float32, numpy.int64)
print(onnx_simple_text_plot(onx))
# discrepancies?
X = numpy.array([[4, 5]], dtype=numpy.float32)
sess = OnnxInference(onx)
name = sess.output_names
result = sess.run({'X': X})
assert_almost_equal(X.astype(numpy.int64), result[name[0]])
>>>
opset: domain='' version=13
input: name='X' type=dtype('float32') shape=()
Cast(X, to=7) -> out_cas_0
output: name='out_cas_0' type=dtype('int64') shape=()
Implicit use of ONNX operators#
ONNX defines standard matrix operators associated with the Python operators +, -, *, /, @. The API implicitly replaces them with the corresponding ONNX operators. In the following example, operator OnnxMatMul is replaced by operator @. The final ONNX graph looks the same.
<<<
import numpy
from numpy.testing import assert_almost_equal
from mlprodict.plotting.text_plot import onnx_simple_text_plot
from mlprodict.onnxrt import OnnxInference
from mlprodict.npy.xop import loadop
OnnxIdentity, OnnxTranspose = loadop('Identity', 'Transpose')
# @ is implicitly replaced by OnnxMatMul
result = OnnxIdentity('X') @ OnnxTranspose('X', perm=[1, 0])
onx = result.to_onnx(numpy.float32, numpy.float32)
print(onnx_simple_text_plot(onx))
# discrepancies?
X = numpy.array([[4, 5]], dtype=numpy.float32)
sess = OnnxInference(onx)
name = sess.output_names
result = sess.run({'X': X.copy()})
assert_almost_equal(X @ X.T, result[name[0]])
>>>
opset: domain='' version=16
input: name='X' type=dtype('float32') shape=()
Transpose(X, perm=[1,0]) -> out_tra_0
MatMul(X, out_tra_0) -> out_mat_0
output: name='out_mat_0' type=dtype('float32') shape=()
Operator @ only applies to instances of class OnnxOperator, not to strings. OnnxOperator is the base class of every class Identity, Transpose, … Operator Identity is inserted to wrap input 'X' and makes it possible to use the standard operators +, -, *, /, @, >, >=, ==, !=, <, <=, and, or.
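As a minimal sketch of this mechanism (output not shown), every operation below is written with the Python syntax and implicitly replaced by its ONNX counterpart:
<<<
import numpy
from mlprodict.npy.xop import loadop
OnnxIdentity = loadop('Identity')
# Identity wraps the input name 'X' into an OnnxOperator
# so that the Python operators apply.
x = OnnxIdentity('X')
# +, * and > are implicitly replaced by Add, Mul and Greater.
cond = (x + x) * x > numpy.array([1], dtype=numpy.float32)
print(cond)
>>>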
Operators with multiple outputs#
Operator TopK returns two results. Accessing one of them requires the use of []. The following example extracts the two greatest elements per row, uses their positions to select the corresponding weights in another matrix, multiplies them, and returns the average per row.
<<<
import numpy
from numpy.testing import assert_almost_equal
from mlprodict.plotting.text_plot import onnx_simple_text_plot
from mlprodict.npy.xop import loadop
from mlprodict.onnxrt import OnnxInference
OnnxReduceMean, OnnxTopK, OnnxGatherElements = loadop(
'ReduceMean', 'TopK', 'GatherElements')
# TopK returns two outputs: topk[0] holds the values, topk[1] the indices.
topk = OnnxTopK('X', numpy.array([2], dtype=numpy.int64), axis=1)
dist = OnnxGatherElements('W', topk[1], axis=1)
result = OnnxReduceMean(dist * topk[0], axes=[1])
onx = result.to_onnx(numpy.float32, numpy.float32)
print(onnx_simple_text_plot(onx))
# discrepancies?
X = numpy.array([[4, 5, 6], [7, 0, 1]], dtype=numpy.float32)
W = numpy.array([[1, 0.5, 0.6], [0.5, 0.2, 0.3]], dtype=numpy.float32)
sess = OnnxInference(onx)
name = sess.output_names[0]
result = sess.run({'X': X, 'W': W})
print('\nResults:')
print(result[name])
>>>
opset: domain='' version=13
input: name='W' type=dtype('float32') shape=()
input: name='X' type=dtype('float32') shape=()
init: name='init' type=dtype('int64') shape=(0,) -- array([2])
TopK(X, init, axis=1) -> out_top_0, out_top_1
GatherElements(W, out_top_1, axis=1) -> out_gat_0
Mul(out_gat_0, out_top_0) -> out_mul_0
ReduceMean(out_mul_0, axes=[1]) -> out_red_0
output: name='out_red_0' type=dtype('float32') shape=()
Results:
[[3.05]
[1.9 ]]
Sub Estimators#
It is a common need to insert an ONNX graph into another one. It is not a simple merge: there are operations before and after, and the ONNX graph may have been produced by another library. That is the purpose of class OnnxSubOnnx.
<<<
import numpy
from numpy.testing import assert_almost_equal
from mlprodict.plotting.text_plot import onnx_simple_text_plot
from mlprodict.npy.xop_convert import OnnxSubOnnx
from mlprodict.npy.xop import loadop
from mlprodict.onnxrt import OnnxInference
OnnxIdentity = loadop('Identity')
X = numpy.array([[-1.5, -0.5, 0.5, 1.5]], dtype=numpy.float32)
# Let's create a first ONNX graph which implements
# a Relu function.
vx = OnnxIdentity('X')
sign = vx > numpy.array([0], dtype=numpy.float32)
sign_float = sign.astype(numpy.float32)
relu = vx * sign_float
print('-- Relu graph')
onx_relu = relu.to_onnx(numpy.float32, numpy.float32)
print("\n-- Relu results")
print(onnx_simple_text_plot(onx_relu))
sess = OnnxInference(onx_relu)
name = sess.output_names[0]
result = sess.run({'X': X})
print('\n-- Results:')
print(result[name])
# Then the second graph including the first one.
x_1 = OnnxIdentity('X') + numpy.array([1], dtype=numpy.float32)
# Class OnnxSubOnnx takes a graph as input and applies it on the
# given inputs.
result = OnnxSubOnnx(onx_relu, x_1)
onx = result.to_onnx(numpy.float32, numpy.float32)
print('\n-- Whole graph')
print(onnx_simple_text_plot(onx))
# Expected results?
sess = OnnxInference(onx)
name = sess.output_names[0]
result = sess.run({'X': X})
print('\n-- Whole results:')
print(result[name])
>>>
-- Relu graph
-- Relu results
opset: domain='' version=16
input: name='X' type=dtype('float32') shape=()
init: name='init' type=dtype('float32') shape=(0,) -- array([0.], dtype=float32)
Greater(X, init) -> out_gre_0
Cast(out_gre_0, to=1) -> out_cas_0
Mul(X, out_cas_0) -> out_mul_0
output: name='out_mul_0' type=dtype('float32') shape=()
-- Results:
[[-0. -0. 0.5 1.5]]
-- Whole graph
opset: domain='' version=16
input: name='X' type=dtype('float32') shape=()
init: name='init' type=dtype('float32') shape=(0,) -- array([1.], dtype=float32)
init: name='init_1' type=dtype('float32') shape=(0,) -- array([0.], dtype=float32)
Add(X, init) -> out_add_0
Greater(out_add_0, init_1) -> out_gre_0
Cast(out_gre_0, to=1) -> out_cas_0
Mul(out_add_0, out_cas_0) -> out_sub_0
output: name='out_sub_0' type=dtype('float32') shape=()
-- Whole results:
[[-0. 0.5 1.5 2.5]]
This mechanism is used to plug any model from scikit-learn, converted into ONNX, into a bigger graph. That is the purpose of class OnnxSubEstimator. The class automatically calls the appropriate converter, sklearn-onnx for scikit-learn models. The next example averages the probabilities of two classifiers for a binary classification.
<<<
import numpy
from numpy.testing import assert_almost_equal
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from mlprodict.plotting.text_plot import onnx_simple_text_plot
from mlprodict.npy.xop_convert import OnnxSubEstimator
from mlprodict.onnxrt import OnnxInference
from mlprodict.npy.xop import loadop
# machine learning part
X, y = make_classification(1000, n_classes=2, n_features=5, n_redundant=0)
X = X.astype(numpy.float32)
X_train, X_test, y_train, y_test = train_test_split(X, y)
# we train two models not on the same machine
lr1 = LogisticRegression().fit(X_train[:, :2], y_train)
lr2 = LogisticRegression().fit(X_train[:, 2:], y_train)
# score?
p1 = lr1.predict_proba(X_test[:, :2])
print("score1", roc_auc_score(y_test, p1[:, 1]))
p2 = lr2.predict_proba(X_test[:, 2:])
print("score2", roc_auc_score(y_test, p2[:, 1]))
# OnnxGraph
OnnxIdentity, OnnxGather = loadop('Identity', 'Gather')
x1 = OnnxGather('X', numpy.array([0, 1], dtype=numpy.int64), axis=1)
x2 = OnnxGather('X', numpy.array([2, 3, 4], dtype=numpy.int64), axis=1)
# Class OnnxSubEstimator inserts the model into the ONNX graph.
p1 = OnnxSubEstimator(lr1, x1, initial_types=X_train[:, :2])
p2 = OnnxSubEstimator(lr2, x2, initial_types=X_train[:, 2:])
result = ((OnnxIdentity(p1[1]) + OnnxIdentity(p2[1])) /
numpy.array([2], dtype=numpy.float32))
# Then the second graph including the first one.
onx = result.to_onnx(numpy.float32, numpy.float32)
print('\n-- Whole graph')
print(onnx_simple_text_plot(onx))
# Expected results?
sess = OnnxInference(onx)
name = sess.output_names[0]
result = sess.run({'X': X_test})[name]
print("\nscore3", roc_auc_score(y_test, result[:, 1]))
>>>
score1 0.9757051282051282
score2 0.46942307692307705
-- Whole graph
opset: domain='' version=16
input: name='X' type=dtype('float32') shape=()
init: name='init' type=dtype('int64') shape=(0,) -- array([2, 3, 4])
init: name='init_1' type=dtype('int64') shape=(0,) -- array([0, 1])
init: name='init_2' type=dtype('float32') shape=(0,) -- array([2.], dtype=float32)
Gather(X, init, axis=1) -> out_gat_0
LinearClassifier(out_gat_0) -> label, probability_tensor
Normalizer(probability_tensor) -> probabilities
Gather(X, init_1, axis=1) -> out_gat_0_1
LinearClassifier(out_gat_0_1) -> label_1, probability_tensor_1
Normalizer(probability_tensor_1) -> probabilities_1
Add(probabilities_1, probabilities) -> out_add_0
Div(out_add_0, init_2) -> out_div_0
output: name='out_div_0' type=dtype('float32') shape=()
score3 0.975
Inputs, outputs#
The following code does not specify which type it applies to, float32 or float64; it could be a tensor of any numerical type.
<<<
from mlprodict.npy.xop import loadop
OnnxSub, OnnxMul = loadop('Sub', 'Mul')
diff = OnnxSub('X', 'Y')
error = OnnxMul(diff, diff)
print(error)
>>>
OnnxMul(2 in) -> ?
That is why this information must be specified when the object is converted into ONNX. Method to_onnx takes it as arguments: to_onnx(<input type>, <output type>).
<<<
import numpy
from mlprodict.plotting.text_plot import onnx_simple_text_plot
from mlprodict.npy.xop import loadop
OnnxSub, OnnxMul = loadop('Sub', 'Mul')
diff = OnnxSub('X', 'Y')
error = OnnxMul(diff, diff)
# First numpy.float32 is for the input.
# Second numpy.float32 is for the output.
onx = error.to_onnx(numpy.float32, numpy.float32)
print(onnx_simple_text_plot(onx))
>>>
opset: domain='' version=14
input: name='X' type=dtype('float32') shape=()
input: name='Y' type=dtype('float32') shape=()
Sub(X, Y) -> out_sub_0
Mul(out_sub_0, out_sub_0) -> out_mul_0
output: name='out_mul_0' type=dtype('float32') shape=()
Wrong types are allowed at conversion time; however, the runtime executing the graph may raise an exception saying the graph cannot be executed.
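A sketch of that situation follows (output not shown); whether an exception is raised, and its message, depends on the runtime:
<<<
import numpy
from mlprodict.onnxrt import OnnxInference
from mlprodict.npy.xop import loadop
OnnxSub, OnnxMul = loadop('Sub', 'Mul')
diff = OnnxSub('X', 'Y')
error = OnnxMul(diff, diff)
# The graph is declared with float32 inputs...
onx = error.to_onnx(numpy.float32, numpy.float32)
sess = OnnxInference(onx)
# ...but receives float64 arrays (numpy's default type).
try:
    print(sess.run({'X': numpy.array([4., 5.]),
                    'Y': numpy.array([4.3, 5.7])}))
except Exception as e:
    print(type(e).__name__, e)
>>>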
Optional output type#
Most of the time, the output type can be guessed from the signature of every operator involved in the graph. The second argument, output_type, is then optional.
<<<
import numpy
from mlprodict.plotting.text_plot import onnx_simple_text_plot
from mlprodict.npy.xop import loadop
OnnxSub, OnnxMul = loadop('Sub', 'Mul')
diff = OnnxSub('X', 'Y')
error = OnnxMul(diff, diff)
onx = error.to_onnx(numpy.float32)
print(onnx_simple_text_plot(onx))
>>>
opset: domain='' version=14
input: name='X' type=dtype('float32') shape=()
input: name='Y' type=dtype('float32') shape=()
Sub(X, Y) -> out_sub_0
Mul(out_sub_0, out_sub_0) -> out_mul_0
output: name='out_mul_0' type=dtype('float32') shape=()
Multiple inputs and multiple types#
The previous syntax assumes all inputs or outputs share the same type. That is usually the case, but not always. The order of inputs is not very explicit, which is why the different types are specified with a dictionary using input names as keys.
<<<
import numpy
from mlprodict.plotting.text_plot import onnx_simple_text_plot
from mlprodict.npy.xop_variable import Variable
from mlprodict.npy.xop import loadop
OnnxMul, OnnxReshape, OnnxReduceSum = loadop(
'Mul', 'Reshape', 'ReduceSum')
diff = OnnxReshape('X', 'Y')
diff2 = OnnxMul(diff, diff)
sumd = OnnxReduceSum(diff2, numpy.array([1], dtype=numpy.int64))
onx = sumd.to_onnx({'X': numpy.float32, 'Y': numpy.int64},
numpy.float32)
print(onnx_simple_text_plot(onx))
>>>
opset: domain='' version=14
input: name='X' type=dtype('float32') shape=()
input: name='Y' type=dtype('int64') shape=()
init: name='init' type=dtype('int64') shape=(0,) -- array([1])
Reshape(X, Y) -> out_res_0
Mul(out_res_0, out_res_0) -> out_mul_0
ReduceSum(out_mul_0, init) -> out_red_0
output: name='out_red_0' type=dtype('float32') shape=()
Specifying output types is trickier. Types must still be specified by name, but output names are unknown: they are decided during the conversion unless the user forces them. That is where argument output_names comes in. It forces method to_onnx to keep the chosen names when the model is converted into ONNX, so we can be sure to give the proper type to the proper output. The two outputs come from two different objects; the conversion is started by calling to_onnx on one of them, and the other one is added with argument other_outputs.
<<<
import numpy
from mlprodict.plotting.text_plot import onnx_simple_text_plot
from mlprodict.npy.xop import loadop
OnnxMul, OnnxReshape, OnnxReduceSum, OnnxShape = loadop(
'Mul', 'Reshape', 'ReduceSum', 'Shape')
diff = OnnxReshape('X', 'Y')
diff2 = OnnxMul(diff, diff)
sumd = OnnxReduceSum(diff2, numpy.array([1], dtype=numpy.int64),
output_names=['Z'])
shape = OnnxShape(sumd, output_names=['S'])
onx = sumd.to_onnx({'X': numpy.float32, 'Y': numpy.int64},
{'Z': numpy.float32, 'S': numpy.int64},
other_outputs=[shape])
print(onnx_simple_text_plot(onx))
>>>
opset: domain='' version=15
input: name='X' type=dtype('float32') shape=()
input: name='Y' type=dtype('int64') shape=()
init: name='init' type=dtype('int64') shape=(0,) -- array([1])
Reshape(X, Y) -> out_res_0
Mul(out_res_0, out_res_0) -> out_mul_0
ReduceSum(out_mul_0, init) -> Z
Shape(Z) -> S
output: name='Z' type=dtype('float32') shape=()
output: name='S' type=dtype('int64') shape=()
ONNX runtimes are usually more efficient when input and output shapes are known, or at least partially known. Shapes can be specified through a list of Variable.
<<<
import numpy
from mlprodict.plotting.text_plot import onnx_simple_text_plot
from mlprodict.npy.xop_variable import Variable
from mlprodict.npy.xop import loadop
OnnxMul, OnnxReshape, OnnxReduceSum, OnnxShape = loadop(
'Mul', 'Reshape', 'ReduceSum', 'Shape')
diff = OnnxReshape('X', 'Y')
diff2 = OnnxMul(diff, diff)
sumd = OnnxReduceSum(diff2, numpy.array([1], dtype=numpy.int64),
output_names=['Z'])
shape = OnnxShape(sumd, output_names=['S'])
onx = sumd.to_onnx(
[Variable('X', numpy.float32, [None, 2]),
Variable('Y', numpy.int64, [2])],
[Variable('Z', numpy.float32, [None, 1]),
Variable('S', numpy.int64, [2])],
other_outputs=[shape])
print(onnx_simple_text_plot(onx))
>>>
opset: domain='' version=15
input: name='X' type=dtype('float32') shape=(0, 2)
input: name='Y' type=dtype('int64') shape=(2,)
init: name='init' type=dtype('int64') shape=(0,) -- array([1])
Reshape(X, Y) -> out_res_0
Mul(out_res_0, out_res_0) -> out_mul_0
ReduceSum(out_mul_0, init) -> Z
Shape(Z) -> S
output: name='Z' type=dtype('float32') shape=(0, 1)
output: name='S' type=dtype('int64') shape=(2,)
Opsets#
ONNX is versioned. The assumption is that every old ONNX graph must remain valid even when new versions of the language are released. By default, the latest supported version is used. First check the latest version installed:
<<<
from onnx.defs import onnx_opset_version
print("onnx_opset_version() ->", onnx_opset_version())
>>>
onnx_opset_version() -> 16
But the library does not always support the latest version right away. The value below is the default opset used if none is given.
<<<
import pprint
from mlprodict import __max_supported_opset__, __max_supported_opsets__
print(__max_supported_opset__)
pprint.pprint(__max_supported_opsets__)
>>>
15
{'': 15, 'ai.onnx.ml': 2}
The following example shows how to force the opset to 12 instead of the default version. It must be specified in two places: in every operator with argument op_version, and when calling to_onnx with argument target_opset.
<<<
import numpy
from numpy.testing import assert_almost_equal
from mlprodict.plotting.text_plot import onnx_simple_text_plot
from mlprodict.onnxrt import OnnxInference
from mlprodict.npy.xop import loadop
opset = 12
OnnxSub, OnnxMul = loadop('Sub', 'Mul')
diff = OnnxSub('X', 'Y', op_version=opset)
error = OnnxMul(diff, diff, op_version=opset)
onx = error.to_onnx(numpy.float32, numpy.float32,
target_opset=opset)
print(onnx_simple_text_plot(onx))
>>>
opset: domain='' version=12
input: name='X' type=dtype('float32') shape=()
input: name='Y' type=dtype('float32') shape=()
Sub(X, Y) -> out_sub_0
Mul(out_sub_0, out_sub_0) -> out_mul_0
output: name='out_mul_0' type=dtype('float32') shape=()
It can also be done by using the specific class corresponding to the most recent version below the considered opset.
<<<
import numpy
from numpy.testing import assert_almost_equal
from mlprodict.plotting.text_plot import onnx_simple_text_plot
from mlprodict.onnxrt import OnnxInference
from mlprodict.npy.xop import loadop
opset = 12
OnnxSub_7, OnnxMul_7 = loadop('Sub_7', 'Mul_7')
diff = OnnxSub_7('X', 'Y')
error = OnnxMul_7(diff, diff)
onx = error.to_onnx(numpy.float32, numpy.float32,
target_opset=opset)
print(onnx_simple_text_plot(onx))
>>>
opset: domain='' version=7
input: name='X' type=dtype('float32') shape=()
input: name='Y' type=dtype('float32') shape=()
Sub(X, Y) -> out_sub_0
Mul(out_sub_0, out_sub_0) -> out_mul_0
output: name='out_mul_0' type=dtype('float32') shape=()
There is one unique opset per domain. The opsets associated with the other domains can be specified as a dictionary.
<<<
import numpy
from numpy.testing import assert_almost_equal
from mlprodict.plotting.text_plot import onnx_simple_text_plot
from mlprodict.onnxrt import OnnxInference
from mlprodict.npy.xop import loadop
opset = 12
OnnxSub_7, OnnxMul_7 = loadop('Sub_7', 'Mul_7')
diff = OnnxSub_7('X', 'Y')
error = OnnxMul_7(diff, diff)
onx = error.to_onnx(numpy.float32, numpy.float32,
target_opset={'': opset, 'ai.onnx.ml': 1})
print(onnx_simple_text_plot(onx))
>>>
opset: domain='' version=7
input: name='X' type=dtype('float32') shape=()
input: name='Y' type=dtype('float32') shape=()
Sub(X, Y) -> out_sub_0
Mul(out_sub_0, out_sub_0) -> out_mul_0
output: name='out_mul_0' type=dtype('float32') shape=()
A last option shortens the expression, using operator [].
<<<
import numpy
from numpy.testing import assert_almost_equal
from mlprodict.plotting.text_plot import onnx_simple_text_plot
from mlprodict.onnxrt import OnnxInference
from mlprodict.npy.xop import loadop
opset = 12
OnnxSub, OnnxMul = loadop('Sub', 'Mul')
diff = OnnxSub[opset]('X', 'Y')
error = OnnxMul[opset](diff, diff)
onx = error.to_onnx(numpy.float32, numpy.float32,
target_opset=opset)
print(onnx_simple_text_plot(onx))
>>>
opset: domain='' version=7
input: name='X' type=dtype('float32') shape=()
input: name='Y' type=dtype('float32') shape=()
Sub(X, Y) -> out_sub_0
Mul(out_sub_0, out_sub_0) -> out_mul_0
output: name='out_mul_0' type=dtype('float32') shape=()
Usually, code written with one opset is likely to run the same way with the next one. However, the signature of an operator may change: an attribute may become an input. The code then has to differ depending on the opset, see for example function OnnxSqueezeApi11.
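Squeeze is a classical case: attribute axes became an input in opset 13. The sketch below (output not shown) relies on the versioned classes introduced above and assumes loadop accepts 'Squeeze_11' the same way it accepts 'Sub_7':
<<<
import numpy
from mlprodict.npy.xop import loadop
OnnxSqueeze_11, OnnxSqueeze = loadop('Squeeze_11', 'Squeeze')
# Until opset 12, axes is an attribute...
sq_old = OnnxSqueeze_11('X', axes=[0])
# ...from opset 13 on, it is a second input.
sq_new = OnnxSqueeze('X', numpy.array([0], dtype=numpy.int64))
print(sq_old, sq_new)
>>>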
Subgraphs#
Three operators hold a subgraph as an attribute: If, Loop, Scan. The first one executes one graph or another based on a condition. The other two run loops. These operators are not so easy to deal with. The unit tests in test_xop.py may provide more examples.
<<<
import numpy
from mlprodict.plotting.text_plot import onnx_simple_text_plot
from mlprodict.onnxrt import OnnxInference
from mlprodict.npy.xop_variable import Variable
from mlprodict.npy.xop import loadop
(OnnxSub, OnnxIdentity, OnnxReduceSumSquare, OnnxScan,
OnnxAdd) = loadop('Sub', 'Identity',
'ReduceSumSquare', 'Scan', 'Add')
# Building of the subgraph.
diff = OnnxSub('next_in', 'next')
id_next = OnnxIdentity('next_in', output_names=['next_out'])
flat = OnnxReduceSumSquare(
diff, axes=[1], output_names=['scan_out'], keepdims=0)
scan_body = id_next.to_onnx(
[Variable('next_in', numpy.float32, (None, None)),
Variable('next', numpy.float32, (None, ))],
outputs=[Variable('next_out', numpy.float32, (None, None)),
Variable('scan_out', numpy.float32, (None, ))],
other_outputs=[flat])
output_names = [o.name for o in scan_body.graph.output]
cop = OnnxAdd('input', 'input')
# Subgraph as a graph attribute.
node = OnnxScan(cop, cop, output_names=['S1', 'S2'],
num_scan_inputs=1,
body=(scan_body.graph, [id_next, flat]))
cop2 = OnnxIdentity(node[1], output_names=['cdist'])
model_def = cop2.to_onnx(numpy.float32, numpy.float32)
x = numpy.array([1, 2, 4, 5, 5, 4]).astype(
numpy.float32).reshape((3, 2))
sess = OnnxInference(model_def)
res = sess.run({'input': x})
print(res)
print("\n-- Graph:")
print(onnx_simple_text_plot(model_def, recursive=True))
>>>
{'cdist': array([[ 0., 72., 80.],
[72., 0., 8.],
[80., 8., 0.]], dtype=float32)}
-- Graph:
opset: domain='' version=16
input: name='input' type=dtype('float32') shape=()
Add(input, input) -> out_add_0
Scan(out_add_0, out_add_0, num_scan_inputs=1) -> S1, cdist
output: name='cdist' type=dtype('float32') shape=()
----- subgraph ---- Scan - _scan - att.body=
input: name='next_in' type=dtype('float32') shape=(0, 0)
input: name='next' type=dtype('float32') shape=(0,)
Identity(next_in) -> next_out
Sub(next_in, next) -> out_sub_0
ReduceSumSquare(out_sub_0, axes=[1], keepdims=0) -> scan_out
output: name='next_out' type=dtype('float32') shape=(0, 0)
output: name='scan_out' type=dtype('float32') shape=(0,)
Function or Graph#
There are two ways to export an ONNX graph: as a full graph with typed inputs and outputs, or as a function with named inputs. The first one works as described in the previous examples. The second one is enabled by parameters function_name and function_domain. They trigger the conversion to a function, as shown in the following example.
<<<
from mlprodict.npy.xop import loadop
OnnxAbs, OnnxAdd = loadop("Abs", "Add")
ov = OnnxAbs('X')
ad = OnnxAdd('X', ov, output_names=['Y'])
proto = ad.to_onnx(function_name='AddAbs')
print(proto)
>>>
name: "AddAbs"
input: "X"
output: "Y"
node {
input: "X"
output: "out_abs_0"
name: "_abs"
op_type: "Abs"
domain: ""
}
node {
input: "X"
input: "out_abs_0"
output: "Y"
name: "_add"
op_type: "Add"
domain: ""
}
opset_import {
domain: ""
version: 14
}
domain: "mlprodict"
Input and output types are not defined and the function is valid for whatever type the code of the function works with. This function can now be used inside a bigger graph with class OnnxOperatorFunction.
<<<
import numpy
from mlprodict.npy.xop import loadop, OnnxOperatorFunction
from mlprodict.plotting.text_plot import onnx_simple_text_plot
OnnxAbs, OnnxAdd, OnnxDiv = loadop("Abs", "Add", "Div")
# the function
ov = OnnxAbs('X')
ad = OnnxAdd('X', ov, output_names=['Y'])
proto = ad.to_onnx(function_name='AddAbs')
# used in graph with operator OnnxOperatorFunction
op = OnnxDiv(OnnxOperatorFunction(proto, 'X'),
numpy.array([2], dtype=numpy.float32),
output_names=['Y'])
# display
onx = op.to_onnx(numpy.float32, numpy.float32)
print(onnx_simple_text_plot(onx))
>>>
opset: domain='mlprodict' version=1
opset: domain='' version=14
input: name='X' type=dtype('float32') shape=()
init: name='init' type=dtype('float32') shape=(0,) -- array([2.], dtype=float32)
AddAbs[mlprodict](X) -> out_fun_0
Div(out_fun_0, init) -> Y
output: name='Y' type=dtype('float32') shape=()
----- function name=AddAbs domain=mlprodict
opset: domain='' version=14
input: 'X'
Abs(X) -> out_abs_0
Add(X, out_abs_0) -> Y
output: name='Y' type=? shape=?
The same syntax can be simplified with an implicit conversion of an ONNX graph: calling ad('X') applies the graph to a new input. 'A' is the input of the function, 'X' is the tensor the function is applied to.
<<<
import numpy
from mlprodict.npy.xop import loadop, OnnxOperatorFunction
from mlprodict.plotting.text_plot import onnx_simple_text_plot
OnnxAbs, OnnxAdd, OnnxDiv = loadop("Abs", "Add", "Div")
# the function
ov = OnnxAbs('A')
ad = OnnxAdd('A', ov)
# used in graph
op = OnnxDiv(ad('X'), numpy.array([2], dtype=numpy.float32),
output_names=['Y'])
# display
onx = op.to_onnx(numpy.float32, numpy.float32)
print(onnx_simple_text_plot(onx))
>>>
opset: domain='mlprodict' version=1
opset: domain='' version=14
input: name='X' type=dtype('float32') shape=()
init: name='init' type=dtype('float32') shape=(0,) -- array([2.], dtype=float32)
AddAbs[mlprodict](X) -> out_fun_0
Div(out_fun_0, init) -> Y
output: name='Y' type=dtype('float32') shape=()
----- function name=AddAbs domain=mlprodict
opset: domain='' version=14
input: 'A'
Abs(A) -> out_abs_0
Add(A, out_abs_0) -> out_add_0
output: name='out_add_0' type=? shape=?
Eager evaluation#
It is not easy to check that an ONNX graph returns the expected result if the only check happens at the end. It is very useful to verify that the computation goes through the expected transformations all along the graph. That can be done with method OnnxOperator.f. The method independently runs every node in the graph after it is converted into ONNX.
<<<
import numpy
from mlprodict.plotting.text_plot import onnx_simple_text_plot
from mlprodict.npy.xop import loadop
X = numpy.array([[4, 5, 6], [7, 0, 1]], dtype=numpy.float32)
W = numpy.array([[1, 0.5, 0.6], [0.5, 0.2, 0.3]], dtype=numpy.float32)
OnnxReduceMean, OnnxTopK, OnnxGatherElements = loadop(
'ReduceMean', 'TopK', 'GatherElements')
topk = OnnxTopK('X', numpy.array([2], dtype=numpy.int64), axis=1)
dist = OnnxGatherElements('W', topk[1], axis=1)
print(dist.f({'X': X, 'W': W}))
# It is possible to simplify this expression into:
print("expected order:", dist.find_named_inputs())
print(dist.f(W, X))
result = OnnxReduceMean(dist * topk[0], axes=[1])
onx = result.to_onnx(numpy.float32, numpy.float32)
print(onnx_simple_text_plot(onx))
>>>
{'output0': array([[0.6, 0.5],
[0.5, 0.3]], dtype=float32)}
expected order: ['W', 'X']
[[0.6 0.5]
[0.5 0.3]]
opset: domain='' version=13
input: name='W' type=dtype('float32') shape=()
input: name='X' type=dtype('float32') shape=()
init: name='init' type=dtype('int64') shape=(0,) -- array([2])
TopK(X, init, axis=1) -> out_top_0, out_top_1
GatherElements(W, out_top_1, axis=1) -> out_gat_0
Mul(out_gat_0, out_top_0) -> out_mul_0
ReduceMean(out_mul_0, axes=[1]) -> out_red_0
output: name='out_red_0' type=dtype('float32') shape=()