{"cells": [{"cell_type": "markdown", "metadata": {}, "source": ["# Precision loss due to float32 conversion with ONNX\n", "\n", "The notebook studies the loss of precision while converting a non-continuous model into float32. It studies the conversion of a [GradientBoostingClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.GradientBoostingClassifier.html) and then of a [DecisionTreeRegressor](https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeRegressor.html), for which a runtime supporting float64 was implemented."]}, {"cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [{"data": {"text/html": ["
\n", ""], "text/plain": [""]}, "execution_count": 2, "metadata": {}, "output_type": "execute_result"}], "source": ["from jyquickhelper import add_notebook_menu\n", "add_notebook_menu()"]}, {"cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": ["%matplotlib inline"]}, {"cell_type": "markdown", "metadata": {}, "source": ["## GradientBoostingClassifier\n", "\n", "We simply train such a model on the Iris dataset."]}, {"cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": ["from sklearn.datasets import load_iris\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.ensemble import GradientBoostingClassifier"]}, {"cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [{"data": {"text/plain": ["GradientBoostingClassifier(n_estimators=20)"]}, "execution_count": 5, "metadata": {}, "output_type": "execute_result"}], "source": ["iris = load_iris()\n", "X, y = iris.data, iris.target\n", "X_train, X_test, y_train, _ = train_test_split(\n", " X, y, random_state=1, shuffle=True)\n", "clr = GradientBoostingClassifier(n_estimators=20)\n", "clr.fit(X_train, y_train)"]}, {"cell_type": "markdown", "metadata": {}, "source": ["We are interested in the probability of the last class."]}, {"cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [{"data": {"text/plain": ["array([0.03010582, 0.03267555, 0.03267424, 0.03010582, 0.94383517,\n", " 0.02866979, 0.94572751, 0.03010582, 0.03010582, 0.94383517,\n", " 0.03267555, 0.03010582, 0.94696795, 0.0317053 , 0.03267555,\n", " 0.03010582, 0.03267555, 0.03267555, 0.03010582, 0.03010582,\n", " 0.03267555, 0.03267555, 0.94577389, 0.03010582, 0.91161635,\n", " 0.03267555, 0.03010582, 0.03010582, 0.03267424, 0.94282974,\n", " 0.03267424, 0.94696795, 0.03267555, 0.94696795, 0.9387834 ,\n", " 0.03010582, 0.03267555, 0.03010582])"]}, "execution_count": 6, "metadata": {}, "output_type": "execute_result"}], "source": ["exp = 
clr.predict_proba(X_test)[:, 2]\n", "exp"]}, {"cell_type": "markdown", "metadata": {}, "source": ["## Conversion to ONNX and comparison to original outputs"]}, {"cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": ["import numpy\n", "from mlprodict.onnxrt import OnnxInference\n", "from mlprodict.onnx_conv import to_onnx"]}, {"cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [{"data": {"text/plain": ["{'output_label': array([0, 1, 1, 0, 2, 1, 2, 0, 0, 2, 1, 0, 2, 1, 1, 0, 1, 1, 0, 0, 1, 1,\n", " 2, 0, 2, 1, 0, 0, 1, 2, 1, 2, 1, 2, 2, 0, 1, 0], dtype=int64),\n", " 'output_probability': [{0: 0.94445217, 1: 0.025442092, 2: 0.030105816},\n", " {0: 0.02932842, 1: 0.9379961, 2: 0.032675553},\n", " {0: 0.029367255, 1: 0.93795854, 2: 0.032674246},\n", " {0: 0.94445217, 1: 0.025442092, 2: 0.030105816},\n", " {0: 0.026494453, 1: 0.02967037, 2: 0.9438352},\n", " {0: 0.027988827, 1: 0.94334143, 2: 0.028669795},\n", " {0: 0.026551371, 1: 0.027721122, 2: 0.9457275},\n", " {0: 0.94445217, 1: 0.025442092, 2: 0.030105816},\n", " {0: 0.94445217, 1: 0.025442092, 2: 0.030105816},\n", " {0: 0.026494453, 1: 0.02967037, 2: 0.9438352},\n", " {0: 0.02932842, 1: 0.9379961, 2: 0.032675553},\n", " {0: 0.94445217, 1: 0.025442092, 2: 0.030105816},\n", " {0: 0.026586197, 1: 0.026445853, 2: 0.946968},\n", " {0: 0.027929045, 1: 0.9403657, 2: 0.0317053},\n", " {0: 0.02932842, 1: 0.9379961, 2: 0.032675553},\n", " {0: 0.94445217, 1: 0.025442092, 2: 0.030105816},\n", " {0: 0.02932842, 1: 0.9379961, 2: 0.032675553},\n", " {0: 0.02932842, 1: 0.9379961, 2: 0.032675553},\n", " {0: 0.94445217, 1: 0.025442092, 2: 0.030105816},\n", " {0: 0.94445217, 1: 0.025442092, 2: 0.030105816},\n", " {0: 0.02932842, 1: 0.9379961, 2: 0.032675553},\n", " {0: 0.02932842, 1: 0.9379961, 2: 0.032675553},\n", " {0: 0.026503632, 1: 0.027722482, 2: 0.9457739},\n", " {0: 0.94445217, 1: 0.025442092, 2: 0.030105816},\n", " {0: 0.041209597, 1: 0.04717405, 2: 0.9116163},\n", " {0: 
0.02932842, 1: 0.9379961, 2: 0.032675553},\n", " {0: 0.94445217, 1: 0.025442092, 2: 0.030105816},\n", " {0: 0.94445217, 1: 0.025442092, 2: 0.030105816},\n", " {0: 0.029367255, 1: 0.93795854, 2: 0.032674246},\n", " {0: 0.027969029, 1: 0.029201236, 2: 0.9428297},\n", " {0: 0.029367255, 1: 0.93795854, 2: 0.032674246},\n", " {0: 0.026586197, 1: 0.026445853, 2: 0.946968},\n", " {0: 0.02932842, 1: 0.9379961, 2: 0.032675553},\n", " {0: 0.026586197, 1: 0.026445853, 2: 0.946968},\n", " {0: 0.027941188, 1: 0.033275396, 2: 0.9387834},\n", " {0: 0.94445217, 1: 0.025442092, 2: 0.030105816},\n", " {0: 0.02932842, 1: 0.9379961, 2: 0.032675553},\n", " {0: 0.94445217, 1: 0.025442092, 2: 0.030105816}]}"]}, "execution_count": 8, "metadata": {}, "output_type": "execute_result"}], "source": ["model_def = to_onnx(clr, X_train.astype(numpy.float32))\n", "oinf = OnnxInference(model_def)\n", "inputs = {'X': X_test.astype(numpy.float32)}\n", "outputs = oinf.run(inputs)\n", "outputs"]}, {"cell_type": "markdown", "metadata": {}, "source": ["Let's extract the probability of the last class."]}, {"cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [{"data": {"text/plain": ["array([0.03010582, 0.03267555, 0.03267425, 0.03010582, 0.9438352 ,\n", " 0.0286698 , 0.9457275 , 0.03010582, 0.03010582, 0.9438352 ,\n", " 0.03267555, 0.03010582, 0.946968 , 0.0317053 , 0.03267555,\n", " 0.03010582, 0.03267555, 0.03267555, 0.03010582, 0.03010582,\n", " 0.03267555, 0.03267555, 0.9457739 , 0.03010582, 0.9116163 ,\n", " 0.03267555, 0.03010582, 0.03010582, 0.03267425, 0.9428297 ,\n", " 0.03267425, 0.946968 , 0.03267555, 0.946968 , 0.9387834 ,\n", " 0.03010582, 0.03267555, 0.03010582], dtype=float32)"]}, "execution_count": 9, "metadata": {}, "output_type": "execute_result"}], "source": ["def output_fct(res):\n", " val = res['output_probability'].values\n", " return val[:, 2]\n", "\n", "output_fct(outputs)"]}, {"cell_type": "markdown", "metadata": {}, "source": ["Let's compare both 
predictions."]}, {"cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [{"data": {"text/plain": ["array([1.35649712e-09, 1.35649712e-09, 1.35649712e-09, 1.35649712e-09,\n", " 1.35649712e-09, 1.35649712e-09, 1.35649712e-09, 1.35649712e-09,\n", " 1.35649712e-09, 1.35649712e-09, 1.40241483e-09, 1.40403427e-09,\n", " 1.40403427e-09, 1.40403427e-09, 4.08553857e-09, 7.87733068e-09,\n", " 8.05985446e-09, 8.05985446e-09, 8.05985446e-09, 8.05985446e-09,\n", " 8.05985446e-09, 8.05985446e-09, 8.05985446e-09, 8.05985446e-09,\n", " 8.05985446e-09, 8.05985446e-09, 8.05985446e-09, 8.05985446e-09,\n", " 8.05985446e-09, 9.19990018e-09, 9.34906490e-09, 1.80944041e-08,\n", " 2.73915506e-08, 2.81494498e-08, 2.81494498e-08, 6.50696940e-08,\n", " 6.50696940e-08, 6.50696940e-08])"]}, "execution_count": 10, "metadata": {}, "output_type": "execute_result"}], "source": ["diff = numpy.sort(numpy.abs(output_fct(outputs) - exp))\n", "diff"]}, {"cell_type": "markdown", "metadata": {}, "source": ["The highest difference remains small, about 6.5e-08, which is the order of magnitude of float32 rounding errors for probabilities close to 1."]}, {"cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [{"data": {"text/plain": ["6.506969396635753e-08"]}, "execution_count": 11, "metadata": {}, "output_type": "execute_result"}], "source": ["max(diff)"]}, {"cell_type": "markdown", "metadata": {}, "source": ["## Why this difference?\n", "\n", "The function *astype_range* returns two arrays of single floats (float32) which bound the true value of the original features stored as double floats. 
"]}, {"cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [{"data": {"text/plain": ["(array([[5.7999997 , 3.9999995 , 1.1999999 , 0.19999999],\n", " [5.0999994 , 2.4999998 , 2.9999998 , 1.0999999 ],\n", " [6.5999994 , 2.9999998 , 4.3999996 , 1.3999999 ],\n", " [5.3999996 , 3.8999996 , 1.2999998 , 0.39999998],\n", " [7.899999 , 3.7999995 , 6.3999996 , 1.9999998 ]], dtype=float32),\n", " array([[5.8000007 , 4.0000005 , 1.2000002 , 0.20000002],\n", " [5.1000004 , 2.5000002 , 3.0000002 , 1.1000001 ],\n", " [6.6000004 , 3.0000002 , 4.4000006 , 1.4000001 ],\n", " [5.4000006 , 3.9000006 , 1.3000001 , 0.40000004],\n", " [7.900001 , 3.8000004 , 6.4000006 , 2.0000002 ]], dtype=float32))"]}, "execution_count": 12, "metadata": {}, "output_type": "execute_result"}], "source": ["from mlprodict.onnx_tools.model_checker import astype_range\n", "astype_range(X_test[:5])"]}, {"cell_type": "markdown", "metadata": {}, "source": ["If a decision tree uses a threshold *t* which verifies ``float32(t) != t``, it cannot be converted into single floats without discrepancies. The interval ``[float32(t - |t|*1e-7), float32(t + |t|*1e-7)]`` approximately contains all double values converted to the same *float32* as *t*, so a feature *x* in this interval may verify ``x < t`` in double precision while ``float32(x) >= float32(t)``: the converted tree then takes a different decision. It is usually not an issue for continuous machine learned models, as small errors tend to compensate. For non-continuous models, there might be some outliers. 
The next function considers the interval of every input feature and randomly chooses one of its extremities."]}, {"cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": ["from mlprodict.onnx_tools.model_checker import onnx_shaker"]}, {"cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [{"data": {"text/plain": ["(38, 100)"]}, "execution_count": 14, "metadata": {}, "output_type": "execute_result"}], "source": ["n = 100\n", "shaked = onnx_shaker(oinf, inputs, dtype=numpy.float32, n=n,\n", " output_fct=output_fct)\n", "shaked.shape"]}, {"cell_type": "markdown", "metadata": {}, "source": ["The function draws 100 input vectors, randomly choosing one extremity for each feature. It then sorts every row of predictions: the first column is the lower bound, the last column is the upper bound."]}, {"cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [{"data": {"text/plain": ["array([0. , 0. , 0. , 0. , 0. ,\n", " 0. , 0. , 0. , 0. , 0. ,\n", " 0. , 0. , 0. , 0. , 0. ,\n", " 0. , 0. , 0. , 0. , 0. ,\n", " 0. , 0. , 0.02333647, 0. , 0. ,\n", " 0. , 0. , 0. , 0. , 0. ,\n", " 0. , 0. , 0. , 0. , 0. ,\n", " 0. , 0. , 0. ], dtype=float32)"]}, "execution_count": 15, "metadata": {}, "output_type": "execute_result"}], "source": ["diff2 = shaked[:, n-1] - shaked[:, 0]\n", "diff2"]}, {"cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [{"data": {"text/plain": ["0.02333647"]}, "execution_count": 16, "metadata": {}, "output_type": "execute_result"}], "source": ["max(diff2)"]}, {"cell_type": "markdown", "metadata": {}, "source": ["The maximum difference obtained this way is much bigger than the one observed with the original inputs. At least one feature of one observation is really close to a threshold, and rounding it to float32 changes the prediction."]}, {"cell_type": "markdown", "metadata": {}, "source": ["## Bigger dataset"]}, {"cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [{"data": {"text/plain": ["GradientBoostingClassifier()"]}, "execution_count": 17, "metadata": {}, "output_type": "execute_result"}], "source": ["from sklearn.datasets import load_breast_cancer\n", "\n", "data = load_breast_cancer()\n", "X, y = data.data, data.target\n", "X_train, X_test, y_train, _ = train_test_split(\n", " X, y, random_state=1, shuffle=True)\n", "clr = GradientBoostingClassifier()\n", "clr.fit(X_train, y_train)"]}, {"cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": ["model_def = to_onnx(clr, X_train.astype(numpy.float32))\n", "oinf = OnnxInference(model_def)\n", "inputs = {'X': X_test.astype(numpy.float32)}"]}, {"cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [{"data": {"text/plain": ["(143, 100)"]}, "execution_count": 19, "metadata": {}, "output_type": "execute_result"}], "source": ["def output_fct1(res):\n", " val = res['output_probability'].values\n", " return val[:, 1]\n", "\n", "n = 100\n", "shaked = onnx_shaker(oinf, inputs, dtype=numpy.float32, n=n,\n", " output_fct=output_fct1, force=1)\n", "shaked.shape"]}, {"cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [{"data": {"image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYcAAAEXCAYAAABGeIg9AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8GearUAAAgAElEQVR4nO3de5xdZX3v8c93ZnIDwjUBIQQSNaLBWqQ5iKdqOUIloMdQemzDUaGIpWnB3rQKXlpqtV7aY70USKmlyEUjVdFYUeCgorVGCEVRxEi4mUCAAEISIITM/vWP59kza/bsmVlrJ2H2rPm+X6/9mr2u+7fWnr1/+7msZykiMDMzK+oZ7wDMzKz7ODmYmdkwTg5mZjaMk4OZmQ3j5GBmZsM4OZiZ2TBODpOEpPMkXT7ecVQh6fck/UeF9e+RdFx+/m5Jny4s+y1J6yRtkfRSSYdJukXSZkl/vCviN5AUkp4/3nFYdU4ONZG/SH8s6UlJD0i6UNLe4x3XeImIv42ItxZm/T1wdkTsERG3AO8Evh0RMyPik+MTpTVJmpcTSV8dXqcOnBxqQNLbgY8AfwHsBRwNHApcJ2nqsxhHN3/gDgVuG2W6tC4/TrOdwslhgpO0J/DXwNsi4hsR8UxE3AP8DukL8E2F1adL+nyuSvkvSb9a2M+7JN2Xl62RdGye3yPpHEl3SnpE0pWS9s3Lmr/CzpD0C+Cbkr4h6eyWGH8k6eT8/IWSrpP0aH6d3ymst5+klZI2SboReN4Yx/5mSffmuN7Tsuw8SZdLmiZpC9AL/CgfxzeB/wX8Y65mekFe7+8l/ULSg5KWS5qR93WMpPX5HD0A/GvJ83Ja3t/Dxfgk9eZqrzvz+b5Z0twS5+dEST/N29wn6R0jnJceSe/N5+YhSZdK2qtMbG329dpc/bYpV8udN8Z78heSNki6X9JbKuzrO/nvY/k9ebmk50n6Zj6/D0u6QoXScCf/s+1eZ7TjmdQiwo8J/AAWA9uBvjbLPgN8Lj8/D3gG+D/AFOAdwN35+WHAOuCgvO484Hn5+Z8Cq4CDgWnAPxX2OQ8I4FJgd2AGcCrwvUIMC4HH8ra759c5HegDjgQeBg7P664ArszrvRi4D/iPEY57IbAFeFXe98fyeTiucLyXF9YP4PmF6W8Dby1MfxxYCewLzAS+CnwoLzsm7/sj+bVmlDwv/5zX/VXgaeBFeflfAD/O5115+X4lzs8G4JX5+T7AkSOcm7cAa4HnAnsAXwIuKxNbm30dA/wK6YfkS4AHgZNG+V98ML93uwOfLZ730fZViKuvsL/nA7+Zz+9s0hf7x/OyHf2fHfZ58aPl/RzvAPzYwTcwlQweGGHZh4Hr8vPzgFWFZT3NL5v8IXwIOA6Y0rKP24FjC9MHkpJMX+GD9tzC8pnAE8ChefqDwMX5+e8C323Z/z8Bf0X6Zf8M8MLCsr9l5OTwl8CKwvTuwDY6SA6kL+gnml8ued7Lgbvz82PyvqdXPC8HF5bfCCzNz9cAS9oc04jnJz//BfAHwJ5j/E9cD/xRYfqwsrGV+H/7OPAPIyy7GPhwYfoFred9pH1R4ksbOAm4JT/f0f9ZJ4cxHq5WmvgeBmapfT34gXl507rmk4hoAOtJv7zWkn5tnQc8JGmFpIPyqocCV0l6TNJjpA9eP3DACPvdDHwNWJpnLQWuKOzrZc195f29EXgO6ZdhX3FfwL2jHPdBLa/7BPDIKOuPZjawG3BzIa5v5PlNGyNia2G6zHl5oPD8SdKveIC5wJ1t4hjt/AD8NnAicK+kG0apEjmIoefuXtK5LRPbEJJeJulbkjZKehxYBswa5XVHfP8q7gtJ++f/xfskbQIub66/E/5nbQxODhPf90nVAicXZ0raHTiB9CuyaW5heQ+p2H0/QER8NiJeQfpgBakKBdKH/YSI2LvwmB4R9xX22zq07+eAU/KX1wzgW4V93dCyrz0i4g+BjaSqm7mF/RwyynFvaDm
e3UhVM514GHiKVH3TjGuviCh+YbYeY5nzMpJ1tG9PGe38EBE3RcQSYH/gy6QquHbuJ72PTYeQzu2DJWJr9VlSddvciNgLWE4qabUz5D1h+Ps32r7aDQ/9oTz/JRGxJ6mUPPDaHf7PehjqkpwcJriIeJzUIP0pSYslTZE0D/g3UsngssLqvybp5FzK+FNSUlml1Of/1ZKmAVtJX5T9eZvlwAclHQogabakJWOEdTXpA/t+4PO5lALw78ALlBqSp+TH/5D0oojoJ9WNnydpN0kLgdNGeY0vAK+T9AqlHlnvp8P/5xzfPwP/IGn/fJxzJB0/ymadnJemTwN/I2mBkpdI2o9Rzo+kqZLeKGmviHgG2MTge9Tqc8CfSZovaQ9S9dznI2J7yfiKZgKPRsRWSUcB/3eUda8Efk/Swpys/6rCvjYCDVI7SXH9LaTG4zmkthoAduB/tt3rWBtODjUQER8F3k3qy78J+AHp19OxEfF0YdWvkOq1fwm8GTg5f9FMI7VPPEyqbtg/7w/gE6Rfe9dK2kxq6HvZGPE8TfqiP470a7E5fzPwGlJV0/35tZqNvABnk6o3HgAuAf51lNe4DTgr739DPqb1o8U1hneRGnFX5SqM/0+qqx9J5fNS8DHSF+m1pPfrX4AZJc7Pm4F7cnzLGNoTrehi0o+C75A6HWwF3lYytlZ/BLw/H+NfMnJphYj4Oqkd4Zukc/nNsvuKiCdJ7VPfy9VBR5N+9BwJPE6qqvxSYV8d/c+O8DrWhiJcyjIzs6FccjAzs2GcHMzMbBgnBzMzG8bJwczMhnFymCQ0wYdOlnSJpA/swPZbJNWu+2LxfVUaD+p94x1TO5L+UGnMqi1KY2hN6P/HycDJwSaFfDHZXeMdx64UEcsi4m929n61g8NcS5pC6r77mvw+dHole7t9D9zDI0/PkvQ9pQH3HpP0fUm/Xlh+mtJAh5uUBlP8aKfHVXdODlZrz/YH3180bR0ATKfDIdIr2kIaeHA2aXDCjwBfLbwvu5EuAJ1FuvbhWNIglNbCyWECk3S6pK8WptdKurIwvU7SEYVNjpN0h6RfSjpfkgrrvkXS7XnZNc2rS/OykLRspG1bYhptKOsLJX2hsO5HJF2frxJuDov9bqXhme+R9MZRjv338/E+qjTM90GFZSHpLEl3AHcU5jWrXy7Jx/A1peGefyDpeYXtX6M0BPTjki5QGsforcOCYGBo8C8oDQ++iXSF8EE5pkdzjL9fWH9I9VjzuAvT90h6h6Rb8+t/XtL0wvLRhsQe2HfhfL5dadjuDZJOL6y7n6Sv5l/QN0n6gEa+61674bRHHBa8JaYXkAYabG7femEckvbK22/M+3uv0vAuaJRhuyVdRhqi46s5rndGxNaIWJOvehfpqul9SKPtEhEXRsR3I2JbHk7jCuDXW2MyPCrrRH6QhgB4jJTkDyQNdHZfYdkvgZ48HaThGfYmfaA2AovzspNIV7S+iDRA23uB/yy8zojbtolptOGSdwN+DvweaTTYh8mjgzI4LPbH8na/QRop9bC8/BLgA/n5q/O2R+Z1PwV8pyXe60hfCDMK855f2NejwFH5eK8gj/BK+kW5iTRWVR/wJ6QRPd86wvGel5eflN+HGcANwAWkX8tH5PN1bOtxFI57fWH6HtIoqQfl+G8HluVlYw2JXTxHzfP5ftKw7CeSBtjbJy9fkR+7kYY/X8fII+DOY/hw2iMOC15y+2Lcl5Ku3p+Z1/05cEZeNuKw3YXzdVyb17yVNJJuAP88ymfoyxRGkvWjcG7GOwA/dvANTB/qI0lDLlyUv1heSLonwMrCegG8ojB9JXBOfv715ocxT/fkL5JDx9q2TTwjDpecp48ifTHfC5xSWK/5ZbZ7y+u8Lz8vfvH9C/DRwnp75NeYV4j31S1xtX6Jfrqw7ETgZ/n5qcD3C8uUz/FoyaGYmOaSfq3OLMz7EHBJ63EUjrs1ObypMP1RYHl+PuqQ2AxPDk8x9Av5IdJdApvDox9
WWPYBqiWHEYcFL7l9kL74e0ljfC0sLPsD0i1c28UyMGx34XwNSw552XTgFOC0EZafThpyZdZ4foa79eFqpYnvBtIXwavy82+TfnX/Rp4uGmmY5kOBT2hwiONHSV+Kc0ps22rU4ZIj4kbgrrz/1nF6fhlp6O2me0m/oFsNGZI6IraQhusuxruudaMWIx1P61DgwdhjNhVf6yDS4HKbC/PubYltLKViY/QhzQEeiaGD7TX31W549LHOV6syw4KXMQuY2mZfc2D0YbvHEqmK6XPAOSrc9TDv9yTS2EwnRMTDbXcwyTk5THzN5PDK/PwGRk4OI1kH/EEMHeJ4RkT8ZwfxjDqUtaSzSFUE9wPvbNl2H6WhxpsOyeu1GjIkdd5mP9Kd45o6HTRsA6lKrLlvFadHUHyt+4F9Jc0szDukENsTpKqcpudQ3lhDYpfVHB69eFxzR1gX2p/LnTUs+MOkEkfrvprna9Rhu0eIrdUUCqOwSlpMGoX3f0fEjyvGO2k4OUx8N5DuhzwjItYD3yXVTe8H3FJyH8uBcyUdDgMNhG/oMJ4Rh0vOjZMfIH3A3wy8U0MbzAH+Wml46lcCryMNPd7qs8Dpko5QGrL5b4EfRLp39o76GvArkk5S6uFyFhW+wCNiHfCfwIckTZf0EuAMBm949EPgREn7SnoOqY2mrLGGxC4bY+vw6C8kVaeNpN0w1ztlWPAcy5Wk/5mZ+f/mz0klBBhl2O7sQYZ+8R+tPIy7pBmS3kUqzfwgL3816b347VyKtRE4OUxwEfFz0ofnu3l6E6na5nv5g1dmH1eRuvytyEX3n5BuFNSJtsMl5y/ay4GPRMSPIuIO0hDLl+UveEjVKb8k/Sq9gtQQ+7M28V4PvA/4IunX9PMYvPPcDslVDG8g1fU/QmqsXU2qFy/rFFI9+/3AVaTbfF6Xl10G/IhUV34t8PkKsY01JHYVZwN7kc75ZaQv+7bHGO2Hud6Zw4K/jVSiugv4D1LyvzgvG23Ybkgli/fmuN5BKpWeT3rv7iO1J702Ipol0Pfl474693DaIunrHcZdax6y27qCpGNI93weqwrnWZW7VK4H3hgR3xrveHYVSR8BnhMRo91gySYRlxzMWkg6XtLeuUTzblId96pxDmunkvRCpTvQSemubGeQSjlmQOpdYGZDvZxUtTEV+ClwUkQ8Nb4h7XQzSVVJB5G6uP4/0rUGZoCrlczMrA1XK5mZ2TC1qFaaNWtWzJs3b7zDMDObUG6++eaHI2J2u2W1SA7z5s1j9erV4x2GmdmEImnEq+xdrWRmZsM4OZiZ2TBODmZmNoyTg5mZDePkYGZmwzg5mJnZME4OZmY2jJNDjW3c/DTX3PbA2CuambVwcqixf7t5Hcsuv5lt2xvjHYqZTTBODjX2zPYgAhoeXNHMKnJyqLH+nBT6G04OZlaNk0ONNXJS6HfJwcwqcnKosWZSaLjkYGYVOTnUWDMpODeYWVVODjXWcJuDmXXIyaHG+nMPVvdWMrOqnBxqzCUHM+uUk0ONNZOCk4OZVeXkUGPNkoNrlcysKieHGhuoVnJ2MLOKnBxqzNVKZtYpJ4cac28lM+uUk0ONNZOCk4OZVeXkUGPuympmnXJyqLFmUmj4dg5mVpGTQ425t5KZdcrJocbcW8nMOuXkUGPNnBAuOZhZRU4ONdZwycHMOuTkUGP9bnMwsw45OdSYeyuZWaecHGrMF8GZWaecHGqsWWJwtZKZVVUqOUhaLGmNpLWSzmmzXJI+mZffKunIsbaV9HeSfpbXv0rS3oVl5+b110g6fkcPcrJqJoWGG6TNrKIxk4OkXuB84ARgIXCKpIUtq50ALMiPM4ELS2x7HfDiiHgJ8HPg3LzNQmApcDiwGLgg78cqcm8lM+tUmZLDUcDaiLgrIrYBK4AlLessAS6NZBWwt6QDR9s2Iq6NiO15+1XAwYV9rYiIpyPibmBt3o9VNFBycG4ws4rKJIc5wLrC9Po8r8w6ZbYFeAvw9Qqvh6Q
zJa2WtHrjxo0lDmPyaSYFN0ibWVVlkoPazGv9thlpnTG3lfQeYDtwRYXXIyIuiohFEbFo9uzZbTYxVyuZWaf6SqyzHphbmD4YuL/kOlNH21bSacDrgGNjcIyHMq9nJQxc5+CSg5lVVKbkcBOwQNJ8SVNJjcUrW9ZZCZyaey0dDTweERtG21bSYuBdwOsj4smWfS2VNE3SfFIj9407cIyTlu/nYGadGrPkEBHbJZ0NXAP0AhdHxG2SluXly4GrgRNJjcdPAqePtm3e9T8C04DrJAGsiohled9XAj8lVTedFRH9O+2IJ5HBksM4B2JmE06ZaiUi4mpSAijOW154HsBZZbfN858/yut9EPhgmdhsZA1f52BmHfIV0jXWzAm+QtrMqnJyqDHf7MfMOuXkUGPNpOCb/ZhZVU4ONRburWRmHXJyqLHBm/2McyBmNuE4OdRYfx6y272VzKwqJ4caa/g2oWbWISeHGvPwGWbWKSeHGvNFcGbWKSeHGhsclXWcAzGzCcfJocb63eZgZh1ycqixRi4x+CI4M6vKyaHGPGS3mXXKyaHGXK1kZp1ycqipiKCZE9xbycyqcnKoqWJVknODmVXl5FBTxaoktzmYWVVODjVVbGbwFdJmVpWTQ00VSwsuOZhZVU4ONVWsVnLJwcyqcnKoqWIPpYaHzzCzipwcaqpYk+TrHMysKieHmhrSldVtDmZWkZNDTRXbGVxyMLOqnBxqyhfBmdmOcHKoKVcrmdmOcHKoqWJNkq9zMLOqnBxqqt9tDma2A5wcasrVSma2I5wcaqrhK6TNbAc4OdTU0K6s4xiImU1ITg415WolM9sRTg41VRxPyb2VzKyqUslB0mJJayStlXROm+WS9Mm8/FZJR461raQ3SLpNUkPSosL8eZKekvTD/Fi+owc5GXlUVjPbEX1jrSCpFzgf+E1gPXCTpJUR8dPCaicAC/LjZcCFwMvG2PYnwMnAP7V52Tsj4ojOD8uGXiHt5GBm1ZQpORwFrI2IuyJiG7ACWNKyzhLg0khWAXtLOnC0bSPi9ohYs9OOxIYI3ybUzHZAmeQwB1hXmF6f55VZp8y27cyXdIukGyS9ssT61qKZEKb0yr2VzKyyMauVALWZ1/p1M9I6ZbZttQE4JCIekfRrwJclHR4Rm4a8oHQmcCbAIYccMsYuJ59mm8OU3p4hpQgzszLKlBzWA3ML0wcD95dcp8y2Q0TE0xHxSH5+M3An8II2610UEYsiYtHs2bNLHMbk0uyt1NcjVyuZWWVlksNNwAJJ8yVNBZYCK1vWWQmcmnstHQ08HhEbSm47hKTZuSEbSc8lNXLfVemobKARekpvj5ODmVU2ZrVSRGyXdDZwDdALXBwRt0lalpcvB64GTgTWAk8Cp4+2LYCk3wI+BcwGvibphxFxPPAq4P2StgP9wLKIeHRnHvRkUKxWcm8lM6uqTJsDEXE1KQEU5y0vPA/grLLb5vlXAVe1mf9F4Itl4rKRNa+KntLnaiUzq85XSNfUYG+lHlxwMLOqnBxqaqDNoafH93Mws8qcHGqqWZPU1+tqJTOrzsmhporVSh6V1cyqcnKoqWa10tTeHpwbzKwqJ4eaapYc+nrlNgczq8zJoaYG2xxcrWRm1Tk51FQzIUx1ycHMOuDkUFP9Hj7DzHaAk0NNDbY5+CI4M6vOyaGmYuAiOF/nYGbVOTnUVPE6B7c5mFlVTg411bz725Q+ubeSmVXm5FBTjYaH7Dazzjk51NTQ+zngW4WaWSVODjXVLC309ShPj2c0ZjbRODnUVLFaCXCPJTOrxMmhpvob6e/UvvQWu93BzKpwcqipZptD70C1kpODmZXn5FBTEUGPoFcpObhaycyqcHKoqf5G0CPR0yw5NMY5IDObUJwcaqo/gp4e0avBaTOzspwcaqrRCHoltzmYWUecHGqqv5Eao6VmtZKTg5mV5+RQU40IpMHeSq5WMrMqnBxqqhFBb4/cW8nMOuLkUFP9uc2h2VvJBQczq8L
JoaYazd5K+R12ycHMqnByqKlGI10A1yO3OZhZdU4ONdWfr5DucW8lM+uAk0NNNRrNaiWXHMysOieHmurPvZUGSw7jHJCZTShODjU10FspD5/hK6TNrAonh5qKYOhFcG5zMLMKSiUHSYslrZG0VtI5bZZL0ifz8lslHTnWtpLeIOk2SQ1Ji1r2d25ef42k43fkACer/kauVnKbg5l1YMzkIKkXOB84AVgInCJpYctqJwAL8uNM4MIS2/4EOBn4TsvrLQSWAocDi4EL8n6sgtRbafAK6XByMLMKypQcjgLWRsRdEbENWAEsaVlnCXBpJKuAvSUdONq2EXF7RKxp83pLgBUR8XRE3A2szfuxChqNoQ3S/W6QNrMKyiSHOcC6wvT6PK/MOmW27eT1kHSmpNWSVm/cuHGMXU4+zbGVenyFtJl1oExyUJt5rd80I61TZttOXo+IuCgiFkXEotmzZ4+xy8mnP0CFaiX3VjKzKvpKrLMemFuYPhi4v+Q6U0ts28nr2RjSzX7wzX7MrCNlSg43AQskzZc0ldRYvLJlnZXAqbnX0tHA4xGxoeS2rVYCSyVNkzSf1Mh9Y4VjMtr0VnK1kplVMGbJISK2SzobuAboBS6OiNskLcvLlwNXAyeSGo+fBE4fbVsASb8FfAqYDXxN0g8j4vi87yuBnwLbgbMion+nHvUk0Oyt1ONqJTPrQJlqJSLialICKM5bXngewFllt83zrwKuGmGbDwIfLBObtRcR9Pb0FG72M84BmdmE4iuka2qwWmlw2sysLCeHmuoPhozK6ovgzKwKJ4eaavZW8s1+zKwTTg411YjWK6SdHMysPCeHmupvRLoIztc5mFkHnBxqqhHpfg69vtmPmXXAyaGmmr2Vcm5wm4OZVeLkUFONlt5KDbc5mFkFTg411Yigp3gnOJcczKwCJ4eaGryHtEsOZladk0NNNRoxtFrJucHMKnByqKn+aJYc8rSzg5lV4ORQU80G6R5f52BmHXByqKlGIzdI+wppM+uAk0NN9efhM9zmYGadcHKoqf5GutlP8yI4VyuZWRVODjXVyFdIu1rJzDrh5FBTjWDoRXBODmZWgZNDTfVHus5BuWrJN/sxsyqcHGqqka+QhnTDHw+fYWZVODnUVLO3EqTurP0estvMKnByqKGIIGLwFqE9Pe6tZGbVODnUULPtuZkcUsnBycHMynNyqKFmIujN725Pj1xyMLNKnBxqqJkImuMq9UgestvMKnFyqKGBkkOzWqnHvZXMrBonhxoaKDkUu7K6t5KZVeDkUEONnAia1Uq9Pb4IzsyqcXKooWYVUm8edK/HvZXMrCInhxoa7K3kK6TNrDNODjUULb2VenvcW8nMqnFyqKH+lgbp3h75Zj9mVomTQw21dmXtEa5WMrNKSiUHSYslrZG0VtI5bZZL0ifz8lslHTnWtpL2lXSdpDvy333y/HmSnpL0w/xYvjMOdDJp7a3ki+DMrKoxk4OkXuB84ARgIXCKpIUtq50ALMiPM4ELS2x7DnB9RCwArs/TTXdGxBH5sazTg5usBnor5Xe3t8e9lcysmjIlh6OAtRFxV0RsA1YAS1rWWQJcGskqYG9JB46x7RLgM/n5Z4CTdvBYLGt3EZzHVjKzKsokhznAusL0+jyvzDqjbXtARGwAyH/3L6w3X9Itkm6Q9Mp2QUk6U9JqSas3btxY4jAmj2YVkhukzaxTZZKD2sxr/aoZaZ0y27baABwSES8F/hz4rKQ9h+0k4qKIWBQRi2bPnj3GLieXwWqlQoO0s4OZVVAmOawH5hamDwbuL7nOaNs+mKueyH8fAoiIpyPikfz8ZuBO4AVlDsaS/paSg4fsNrOqyiSHm4AFkuZLmgosBVa2rLMSODX3WjoaeDxXFY227UrgtPz8NOArAJJm54ZsJD2X1Mh9V8dHOAk188DQ24Q6OZhZeX1jrRAR2yWdDVwD9AIXR8Rtkpbl5cuBq4ETgbXAk8Dpo22bd/1h4EpJZwC/AN6
Q578KeL+k7UA/sCwiHt0pRztJDJYc0rRLDmZW1ZjJASAiriYlgOK85YXnAZxVdts8/xHg2Dbzvwh8sUxc1l7/sJv9DF77YNZtrr3tAX7jsNlM6+sd71CswFdI11DDN/uxCeKujVs487Kbuea2B8c7FGvh5FBDbUdldZuDdaFfPrkt/X1i2zhHYq2cHGqomQdywSFf5+DkYN1n09btAGze+sw4R2KtnBxqqBEt1Uq+Qtq61OaB5LB9nCOxVk4ONdRarSTfQ9q6VLPEsMnJoes4OdRQa2+l3h48Kqt1pc2uVupaTg41FK3VSu6tZF2qmRRcrdR9nBxqqFmF5FFZrdu55NC9nBxqaOAK6fzu+mY/1q3cIN29nBxqqNEyKqurlaxbuVqpezk51NDwe0jLw2dYV/J1Dt3LyaGGmiUHqdBbySUH60LNEsMT2/p9FX+XcXKoodZqJQ+fYd2qWGLY4qqlruLkUEPN3kq9vtmPdblNTz3DblPTaKybXLXUVZwcaqjR0lvJN/uxbhQRbHl6OwftPQNwo3S3cXKooba9lZwcrMs8sa2fRlBIDi45dBMnhxoaGD6j0FvJtUrWbZrJYM7e0/O0Sw7dxMmhhgaqlTR4Jzhf52DdppkMDtorlxyedsmhmzg51FDrqKyuVrJu1Cw5uM2hOzk51FB/zgPurWTdrHkBnJNDd3JyqKHmqKwq9FZywcG6TTMZzNpjKlN7e9yVtcs4OdTQ8OEzcLWSdZ1mtdLM6VOYOb3PJYcu4+RQQ/2tV0jnvx6Z1bpJMxnsOaPPyaELOTnUUGtvpWYJwj2WrJts3voMvT1ixpReZk6f4uscuoyTQw01CwitJQdXLVk32bx1OzOn9yHJJYcu5ORQQwM3+0k5YaAE4YKDdZNmcgBycnDJoZs4OdRQIwJp6JDd4Gol6y6btz7DzGlTAHK1kksO3cTJoYb6GzHQzgCDJQdXK1k32fRUa8nByaGbODnUUH/EQDsDDLY9uLeSdZNNW59h5vTBksOWp7f7B0wXcXKooYjB9gYoJAdXK1kX2bx1O3vmkkPz75anXXroFk4ONdRarSR3ZbUutHnrM0OqlZrzrDs4OdRQf6OlWknNaqXxishsqOaNfpAHTPMAAAdGSURBVIrVSuDxlbqJk0MNNSIGqpLAvZWs+zRv9DO85ODk0C2cHGqoEe17K7lB2rpFcVyl4l9XK3WPUslB0mJJayStlXROm+WS9Mm8/FZJR461raR9JV0n6Y78d5/CsnPz+mskHb+jBznZ9DcG2xmgkBxccrAu0SwhuOTQvcZMDpJ6gfOBE4CFwCmSFrasdgKwID/OBC4sse05wPURsQC4Pk+Tly8FDgcWAxfk/VhJjUYMVCXBYG8ldxO0bjFYcnCDdLfqK7HOUcDaiLgLQNIKYAnw08I6S4BLI91IYJWkvSUdCMwbZdslwDF5+88A3wbeleeviIingbslrc0xfL/zw2zvZw9s4m2fvWVn73bcPbBpKzOnDb61zcbp0y+5iam9rkm08bWtv8GDm7YCsPduUwHYM1crfeL6O7j0+/eOW2wT0TGHzeY9r239vb7jyiSHOcC6wvR64GUl1pkzxrYHRMQGgIjYIGn/wr5WtdnXEJLOJJVSOOSQQ0ocxnDT+3pZcMAeHW3bzRYcsAdHzdt3YPro+fty8kvnsHV7/zhGZZb09fSw/8xpHDprd35lzl4ATJ/Sy9te/Xzu3LhlnKObeA7Yc/ou2W+Z5KA281rrJ0Zap8y2nbweEXERcBHAokWLOqovmTdrdy544691sumEsv+e0/nY7x4x3mGYjertrzlsvEOwgjJ1DOuBuYXpg4H7S64z2rYP5qon8t+HKryemZntQmWSw03AAknzJU0lNRavbFlnJXBq7rV0NPB4rjIabduVwGn5+WnAVwrzl0qaJmk+qZH7xg6Pz8zMOjBmtVJEbJd0NnAN0AtcHBG3SVqWly8HrgZOBNYCTwKnj7Zt3vWHgSslnQH8AnhD3uY
2SVeSGq23A2dFhCvLzcyeRYoa9H1ftGhRrF69erzDMDObUCTdHBGL2i1zv0YzMxvGycHMzIZxcjAzs2GcHMzMbJhaNEhL2gjsyDX3s4CHd1I4u9JEiRMc667iWHeNyRrroRExu92CWiSHHSVp9Ugt9t1kosQJjnVXcay7hmMdztVKZmY2jJODmZkN4+SQXDTeAZQ0UeIEx7qrONZdw7G2cJuDmZkN45KDmZkN4+RgZmbDTOrkIGmxpDWS1ko6Z7zjKZI0V9K3JN0u6TZJf5Ln7yvpOkl35L/7jHeskO4XLukWSf+ep7syToB8G9svSPpZPr8v78Z4Jf1Zfu9/IulzkqZ3S5ySLpb0kKSfFOaNGJukc/PnbI2k47sg1r/L7/+tkq6StHe3xlpY9g5JIWnWsxHrpE0OknqB84ETgIXAKZJ2/o1YO7cdeHtEvAg4Gjgrx3cOcH1ELACuz9Pd4E+A2wvT3RonwCeAb0TEC4FfJcXdVfFKmgP8MbAoIl5MGvJ+Kd0T5yXA4pZ5bWPL/7dLgcPzNhfkz9+z5RKGx3od8OKIeAnwc+Bc6NpYkTQX+E3S7Q2a83ZprJM2OQBHAWsj4q6I2AasAJaMc0wDImJDRPxXfr6Z9AU2hxTjZ/JqnwFOGp8IB0k6GHgt8OnC7K6LE0DSnsCrgH8BiIhtEfEY3RlvHzBDUh+wG+mOiF0RZ0R8B3i0ZfZIsS0BVkTE0xFxN+m+L0c9K4HSPtaIuDYitufJVaQ7TnZlrNk/AO9k6C2Td2mskzk5zAHWFabX53ldR9I84KXAD4AD8l32yH/3H7/IBnyc9I/bKMzrxjgBngtsBP41V4N9WtLudFm8EXEf8PekX4obSHdXvJYui7PFSLF1+2ftLcDX8/Oui1XS64H7IuJHLYt2aayTOTmozbyu69craQ/gi8CfRsSm8Y6nlaTXAQ9FxM3jHUtJfcCRwIUR8VLgCbqryguAXF+/BJgPHATsLulN4xtVx7r2sybpPaQq3Cuas9qsNm6xStoNeA/wl+0Wt5m302KdzMlhPTC3MH0wqdjeNSRNISWGKyLiS3n2g5IOzMsPBB4ar/iyXwdeL+keUtXcqyVdTvfF2bQeWB8RP8jTXyAli26L9zjg7ojYGBHPAF8C/ifdF2fRSLF15WdN0mnA64A3xuAFX90W6/NIPxB+lD9jBwP/Jek57OJYJ3NyuAlYIGm+pKmkhp2V4xzTAEki1YvfHhEfKyxaCZyWn58GfOXZjq0oIs6NiIMjYh7pHH4zIt5El8XZFBEPAOskHZZnHUu6X3m3xfsL4GhJu+X/hWNJ7U7dFmfRSLGtBJZKmiZpPrAAuHEc4hsgaTHwLuD1EfFkYVFXxRoRP46I/SNiXv6MrQeOzP/HuzbWiJi0D+BEUk+FO4H3jHc8LbG9glREvBX4YX6cCOxH6glyR/6773jHWoj5GODf8/NujvMIYHU+t18G9unGeIG/Bn4G/AS4DJjWLXECnyO1hTxD+sI6Y7TYSFUjdwJrgBO6INa1pPr65mdrebfG2rL8HmDWsxGrh88wM7NhJnO1kpmZjcDJwczMhnFyMDOzYZwczMxsGCcHMzMbxsnBzMyGcXIwM7Nh/huLSB1mitb4DgAAAABJRU5ErkJggg==\n", "text/plain": ["
"]}, "metadata": {"needs_background": "light"}, "output_type": "display_data"}], "source": ["import matplotlib.pyplot as plt\n", "plt.plot(shaked[:, n-1] - shaked[:, 0])\n", "plt.title(\"Observed differences on a dataset\\nwhen exploring rounding to float32\");"]}, {"cell_type": "markdown", "metadata": {}, "source": ["## DecisionTreeRegressor\n", "\n", "This model is much simpler than the previous one as it contains only one tree. We study it on the [Boston](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_boston.html#sklearn.datasets.load_boston) dataset."]}, {"cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": ["from sklearn.datasets import load_boston\n", "data = load_boston()\n", "X, y = data.data, data.target\n", "X_train, X_test, y_train, y_test = train_test_split(X, y, shuffle=True, random_state=2)"]}, {"cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [{"data": {"text/plain": ["DecisionTreeRegressor()"]}, "execution_count": 22, "metadata": {}, "output_type": "execute_result"}], "source": ["from sklearn.tree import DecisionTreeRegressor\n", "clr = DecisionTreeRegressor()\n", "clr.fit(X_train, y_train)"]}, {"cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": ["ypred = clr.predict(X_test)"]}, {"cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": ["model_onnx = to_onnx(clr, X_train.astype(numpy.float32))"]}, {"cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": ["oinf = OnnxInference(model_onnx)\n", "opred = oinf.run({'X': X_test.astype(numpy.float32)})['variable']"]}, {"cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [{"data": {"text/plain": ["array([1.52587891e-06, 1.52587891e-06, 1.52587891e-06, 1.52587891e-06,\n", " 1.52587891e-06])"]}, "execution_count": 26, "metadata": {}, "output_type": "execute_result"}], "source": ["numpy.sort(numpy.abs(ypred - opred))[-5:]"]}, 
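{"cell_type": "markdown", "metadata": {}, "source": ["As an illustration, the next cell sketches a hypothetical helper, `count_float32_unsafe_thresholds` (it is not part of *mlprodict*), which counts the split thresholds *t* of the fitted tree verifying ``float32(t) != t``. It only relies on the attributes `tree_.feature` and `tree_.threshold` of scikit-learn trees, where leaves have a negative feature index. Such thresholds are candidates for discrepancies, not guaranteed ones: a prediction only changes when a feature falls close enough to the threshold."]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["import numpy\n", "\n", "def count_float32_unsafe_thresholds(model):\n", "    # keep split nodes only, leaves have a negative feature index\n", "    tree = model.tree_\n", "    thresholds = tree.threshold[tree.feature >= 0]\n", "    # round trip double -> float32 -> double\n", "    rounded = thresholds.astype(numpy.float32).astype(numpy.float64)\n", "    # (number of thresholds with float32(t) != t, number of splits)\n", "    return int(numpy.sum(rounded != thresholds)), thresholds.shape[0]\n", "\n", "count_float32_unsafe_thresholds(clr)"]}, 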
{"cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [{"data": {"text/plain": ["4.680610146230323e-06"]}, "execution_count": 27, "metadata": {}, "output_type": "execute_result"}], "source": ["numpy.max(numpy.abs(ypred - opred) / ypred) * 100"]}, {"cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["highest relative error: 4.68e-06%\n"]}], "source": ["print(\"highest relative error: {0:1.3}%\".format((numpy.max(numpy.abs(ypred - opred) / ypred) * 100)))"]}, {"cell_type": "markdown", "metadata": {}, "source": ["The relative error remains very small. Let's reuse function *onnx_shaker* to check whether rounding the inputs to float32 can lead to bigger differences."]}, {"cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [{"data": {"text/plain": ["(127, 1000)"]}, "execution_count": 29, "metadata": {}, "output_type": "execute_result"}], "source": ["def output_fct_reg(res):\n", " val = res['variable']\n", " return val\n", "\n", "n = 1000\n", "shaked = onnx_shaker(oinf, {'X': X_test.astype(numpy.float32)},\n", " dtype=numpy.float32, n=n,\n", " output_fct=output_fct_reg, force=1)\n", "shaked.shape"]}, {"cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [{"data": {"image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXAAAAEmCAYAAAB7zsvVAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8GearUAAAgAElEQVR4nO2deZhkVXn/P9/qnoEZGFlHZNhGBVE0iGZEVNQJaERc4GfUYISgaFDjHlxwDUbc8jNGo0ZFNCggSAgq7hIU3NFB3BARxIEZGGBYhl2nu+rNH+fc7ts1Vd3VVcXUrfL7eZ56uu49d3nPXb791nvOeY8iAmOMMcNHbdAGGGOM6Q4LuDHGDCkWcGOMGVIs4MYYM6RYwI0xZkixgBtjzJBiAe8CSSdIOm3QdswHSS+Q9P15bL9a0pPy9zdLOrlU9v8krZF0p6RHSNpb0iWS7pD0qnvDfjOcSDpF0omDtmNUsYC3IIvdryTdLel6SR+TtO2g7RoUEfHuiHhxadX7gVdExNYRcQnwBuCCiFgSEf8xGCtHC0krJTXyP8k7JV0r6R19OO68/pFvTiRdIOnFc285HOfZHFjAm5B0HPA+4PXANsABwB7AeZIWbkY7xjfXubpgD+DSWZY7puL1HDTX5X+SWwMHAi+SdPigjTIVIiL8yR/gPsCdwHOb1m8N3Agck5dPAM4GPg/cAfwMeHhp+zcC1+ayy4GD8/oacDzwe+Bm4Cxg+1y2HAjgRcA1wHeBb5A83bItvwCelb8/GDgPuCWf57ml7XYAzgVuB34CvBP4/ix1Pwq4Otv1FmA18KRSfU8DtsjXJ4C7cj2+DdSBP+ayB+Xt3p/rcQPwcWBRPtZKYG2+RtcDp3Z4XY7Ox7sJeEvJ7jHgzXnfO4CLgd06uD6HAr/J+1wLvK7NdakBb83X5kbgs8A2ndjW4lhPAy7J92QNcMIs264E1jatOwt4c2n5scBPgdvy38eWyl4AXJXr9wfg+cBD8n2q53u1IW+7Ta7X+lzPtwK10nG+n+/nrflYT53F7keQ3oc7SO/HmcCJuWw74Cv5PLfm77vmsnc1PUcfyes/lK/V7fnePr50rv2BVbnsBuADpbIDgB8CG0jvzMrZzjOsn4EbUKUPcAgwCYy3KPsMcEb+fgIwATwbWAC8Lj/YC4C98wO3LG+7HHhg/v4a4MfAriSR+0TpmIUYfBbYClgE/D3wg5IN++QHcou8zRrghcA48EiSgDw0b3sm6YXfCngYSaRaCng+7p3AE/KxP5CvwwwBL20fwJ6l5QuAF5eWP0j657E9sAT4MvCeXLYyH/t9+VyLOrwun8zbPhz4E/CQXP564Ff5uiuX79DB9VlHFgOSsDyyzbU5BrgSeADpH/k5wKmd2NbiWCuBvyD9U9iXJDqHz7Lt2tLyXvkeHpSXtyeJ4FG5fs/Ly0Xdbwf2ztvuXKr3C5qfA9Iz96V8r5YDvwNeVNp+AvgH0j/LlwHXAWph80LSP4DXkt6FZ+d9CwHfAfgbYHE+138DX2z3HOV1R+b9xoHjSP/0t8xlPwKOyt+3Bg7I33chOQKH5mv95Ly8tN15hvUzcAOq9MkPy/Vtyt4LnJe/nwD8uFRWKwQB2JPkqT0JWNB0jMvI3nhe3jk/4OMlMXhAqXwJydPdIy+/C/h0/v63wPeajv8J4J/zizYBPLhU9u7mF7dU9nbgzNLyVsBGuhBwkojeRf6nldc9BvhD/r4yH3vLeV6XXUvlPwGOyN8vBw5rUae21yd/vwZ4CXCfOZ6J84F/LC3v3altHTxvHwT+vU3ZSqBB+od9ez7POcDCXH4U8JOmfX5EEtyt8n5/Q/7lU9rmBeXnID8rfwL2Ka17CalNo9j+ylLZ4mzL/VrY/ASaxJ3kBZ/Ypo77Abe2eo5muWa3kn/tkn6lvgPYsWmbN5L/yZbWfRM4utPzDMvHMfCZ3ATs2CYuu3MuL1hTfImIBikssCwiriR5lCcAN0o
6U9KyvOkewBckbZC0gSRcdWCnNse9A/gqcERedQRweulYjy6OlY/3fOB+wFKSwEwdi+QZtWNZ03nvInks3bCU9JJfXLLrG3l9wfqI+GNpuZPrcn3p+90kjwtgN1L4pJnZrg8kcTsUuFrShZIe06Y+y5h57a4mXdtObJuBpEdL+o6k9ZJuA14K7NjmvJBi4NtGxH2AbYF7SL8EW9lV2LZLvn9/m4+/TtJXJT24zTl2ZNpznnGcVvWLiLvz11Z1XAZcG1klS8cCQNJiSZ+QdLWk20kCvK2ksTa2Iek4SZdJui3fw22YvmYvIoXsfivpp5KentfvATyn6d4fSHqHRwoL+Ex+RPJGnlVeKWkr4Kkkb6xgt1J5jfTz/zqAiPhcRBxIepCCFC6AJJJPzS9l8dkyIq4tHbf88AOcATwvC8wi4DulY13YdKytI+JlpBjjZNlGYPdZ6r2uqT6LST9bu+EmktA8tGTXNpEa4gqa69jJdWnHGuCBbda3uz5ExE8j4jDgvsAXSeGmVlxHuo8Fu5Ou7Q0d2NbM50ihpd0iYhtS24A62TEibsv7P6ONXYVt1+btvxkRTyaJ1m9JYR7Y9NrfRPpF0VzHTq59M+uAXSSV61R+7o4j/YJ5dP6n9IS8vth+hm2SHk/ypp8LbBcR25Li/QKIiCsi4nmke/g+4Oz8rq4heeDle79VRLy31XmGGQt4ifySvAP4sKRDJC2QtJwUq1tLanAr+EtJz8re+mtIwv/j3Cf6IElbkBpK7iF5k5Be2HdJ2gNA0lJJh81h1tdIL9e/AJ/P3j6kBqAHSToq27lA0qMkPSQi6qSf2ydkr2cfUkNbO84Gni7pwNzT5l/o8tnI9n0S+HdJ98313EXSU2bZrZvrUnAy8E5Jeymxr6QdmOX6SFoo6fmStomICVKIot7m+GcAr5V0f0lbk0JRn4+IyQ7tK7MEuCUi/ihpf+DvOt0xn/sIpnv7fC3X7+8kjUv6W1Jbxlck7STpmVnM/kRq3yjqdwOwa9GjKj8rZ5Gu/5J8D/6J1Gg9X35E+uf2qmzTs0gNjeX63wNskLQ9KdxX5gZSW0N5+0mSQzIu6e2kjgbFNTlS0tL8zG3Iq+vZ9mdIeoqkMUlbKnXL3LXNeYaXQcdwqvgh/TT7Nelhu4EUO92uVH4CM3uhXEJuBCM1Tv0kr7+FJCRFg2aN9HJcnst/D7w7ly0neQatGlA/lcse1bR+b1KIZT0p5PFtYL9ctjSfu9NeKEeT4sJte6GUtp2rEXNLktBdlc9/GfCqXLaSTXtXzOu6MDPmPkbqNfGHvO9Pme7Z0PL6kEIG3yDFU2/P+xzY5rrUSG0Ea/JxTiuehblsa3GsZ5NCCnfke/OR8nVt2nYlKQZ+Z/7cnOtSvu4Hknpm3Jb/HpjX7wxcmNdvyDbtk8sW5uPcAtyU122X67U+1/PtNPVCabJtxv1vKltBeh+KXiifZ7oRc1m25U5SQ+lLyteP1Fbyu3xf/iPf20/le7SONN5gNdPP5Wmk9qY7Sf/YDi/Z8eh8DW7J9foqsHur8wxab3r5KFfIGGPMkOEQijHGDCkWcGOMGVIs4MYYM6RYwI0xZkixgI8gOXtd225SKqWKHSYkXSpp5Rzb7J7r33ZwiDGjggV8BIk0YOUq2Pz5mCWFpLuyiN4s6fzcR7lnIuKhEXHBHNtck+vfrl93R2g6jeudSmld7yktP7+XYzed5wJJf8zHvUnSOZJGbsSguXewgJt7g4dHGnm5N3AK8BFJzYM2Kk3+J1Ckcr0GeEZpXZHOoF/pcF+Rz7MnaYj6+/twzBn0yc6hPf+oYgEfEiS9UNKXS8tXSjqrtLxG0n75e0jaU9KxpPwfb8ge3pdLh9xP0i+Vckx8XtKWbc77QEnfzt70TZJOV4eTW0TETRFxKimD3ZvyCEkkbSPpU5LWKU1UcGI55CHpH5TyX9wh6TeSHpnXl2cJ2l/SKkm3S7pB0gfy+uW5/uN5eZmkcyXdkq/ZP5TOc4K
ksyR9Np/rUkkr5rgPKyWtlfRGSdcD/yWpJul4Sb/P1+msPNKw2OcAST9Uysvxi3ZhoIjYQBrWv19p3wdLOi/bf7mk55bKdpD05XwNfpqv4/dL5SHp5ZKuAK7I654u6efZlh9K2re0/Rvz/bgjn+vg2a51Lntmvm4b8q+Jh5TKVudj/hK4yyJ+LzDokUT+dPYhDf3dQPqnuzNpRN+1pbJbmR49NzVSjuQBn9h0rNWk0ZnLSGlJLwNe2ua8e5LScW5BGt35XeCDs9i5ySg9UmrRSXIeaZJIfYKUNe++2ZaX5LLnkPJwPIqU82JPprMxrmZ6FF67VKLLmTm670LgP0mjQ/cjjcor8rOfQEp3cChp1N97KGWZbLpexXlXMr90uB2nNiXln/lf4Et5uZOUwWeSkoftk7f9ftO9OC/f40V5/xtJoxTHSKNvV2ebZ0uD3O5aP4iUefLJ+R6/gZR6d2Hpuv2clGdnUfN19acPujBoA/yZx81KL9gjSTkxTiIJ34PzC35uabtOBPzI0vK/Ah/v0IbDgUtmKW85zJqU0e75pCx+fyq/0KRc1t/J378JvLrNsVczLaTtUokuzzaMZ+GoA0tK5e8BTsnfTwD+t1S2D3DPHOddyfzS4XaS2vRu0rD3yIJXDPnuJGXw3qWyE9lUwA8qLX8MeGfT8S4HnsjsaZDbXeu3AWeVlmukf74rS9ftmEG/N6P8cQhluLiQJCBPyN8vIL18T8zL86HTFKj3VUqJe61SCtDTmD0FaqtjLCB577eQEnMtIKU5LVJ9foLkiUP79LDNtEslWmYZKXnUHaV1bVOlkq7Dlh381J9POtw9mDu16asiZSfcl5SXZNfSceeTMrj8vdW6PYDjmo63G3OnQW53rWektI2UVGoNM69vK5tMn7CADxeFgD8+f7+QuQW812Q378nH2DdSCtAj6TAFaonDSGGHn5Be6D+RvLki1ed9IuKhedt26WFnEO1TiZa5Dthe0pLSum5Tpc44fdPybOlw50ptWq7Tr0he9Eclic5SBu9aOsRuzcdssnUN8K6m4y2OiDPy+VumQZ7lWs9IaZtt3o2Z19fJlu5FLODDxYXAX5HCD2uB75GmgduBlAGuFb2mzlxCnj9R0i6kKcw6QtL2Sl3uPgq8LyJujoh1wLeAf5N0n9wA+EBJT8y7nQy8TtJfKrGncprZpmO3SyU6RUSsIc0I8x6llKL7krzJ0+kvs6XDnSu1aTOfIQnlM5lfyuAHk6bgm41PAi9VmlhCkraS9DSlNLJt0yDPcq3PAp4m6eD8K+s40j/nH87z+pkusYAPERHxO5KYfi8v305K2fqDaN/v+VPAPvkn8xe7OO07SHH320gpOc/pYJ9fSLqT1KD1YuC1EfH2Uvnfk9Ka/obU+Ho2OaQQEf9Nmjruc6SUpF8kNcI1cwhwaT7Ph0jTmP2xxXbPI8XFrwO+QJpS7bwO6jAfPkSaqOFbku4gNWg+Gqb+iRxGmni5SNf6etq8exGxkZRK9W059PPXpDaP60jhnqLxFOAVpBlqismhzyAJaEsiYhVpbsuPkK77laR0seRjvpfUSHo96Z/Im3NZy2sdEZeTfpF9OO/3DFJ3y41zXC/TJ5xO1pgRQdL7SHNVHj1oW8zmwR64MUNK7iO+bw6H7E8KD31h0HaZzYc71hszvCwhhU2WkboA/hvwpYFaZDYrDqEYY8yQ4hCKMcYMKRbwCpFzV+w5aDu6RT1mPtQcaXCHlfJ9lfRxSW8btE2tkPSynOvkTqU8K0P9PP45YAE3lSFKaXBHlYh4aUS8s9/HVVMSry72XwB8APjrfB9u7qNtM/LPS9pR0g+UEn9tkPQjSY8rlR8t6WKl5FlrJf2rE2G1xgJuBs7mfjktBi3ZiZTw69LNcK47gWNI6QC2I/Vt/3LpviwmDevfkdSf/mDgdZvBrqHDAn4vo3mkgc08SdIVkm6VVAypLrY9RinN6q2SvlkeoZi9r5e227fJprbpTyV9TNLZpW3
fpzQpgzSdSvXNSqllV2uWyQ2U0sJeqZQK9VxN59Zol+q0HGo4Jdfhq0rpTS+S9MDS/n+tlPL0Nkn/KelCSS9uY8cJks6WdJpSPpcXaPY0szNCQUW9S8urJb1ObdLxSnq9Uqrc6yQd02TL1LFL1/M4STfmfV5Y2nbWdLFNfDf/3aAUAnlMvs9vlXR1Pv5nJW3T4vo8iJTUqtj/2y222Sbvvz4f762SarmsbcphSaeS0hd8Odv1hmIQUB7ZKdKozu3IA7Yi4mMR8b2I2JjTEZwOPK7ZJoOzEd7bH+afBvYrwLakh349cEguO5w0cu4hpO6fbwV+WDpP231b2DRb+tPFwO9II/QeTxpht2suW0nKv/GBvN8TSelE987lp5AzHwIH5X0fmbf9MPDdJnunUp2W1pWzKN4C7J/rezpwZi7bEbgdeFYuezUpM9+L29T3hFx+eL4Pi5g9zexUPUr1XltaXk2bdLykUYs3AA8jpYP9XIt6ndh0Pf+FlODrUFJCre1y+azpYpvquJxSGt287hjSM/MAUrKyc2jKjDjH/mW7P0vqorgkb/s74EW5bNaUw5SyOTad85ekzI4BfHKWd+iLwHsH/S5X8TNwA/4cPswvDeyBpeWzgOPz968XL0xeruWXfY+59m1hT9v0p3l5f5J4Xg08r7RdIThbNZ3nbfl7WZw+Bfxrabut8zmWl+w9qMmuZqE7uVR2KPDb/P3vgR+VyorET7MJePmfx1xpZqfqUap3s4C3TMcLfLosNqQsfrMJ+D3MFM0bgQPoIF1sUx2Xs6kAnw/8Y2l57/J97mD/IInzGGmI/j6lspcAF7SxZUbKYdoIeC7bkpTu4Og25S8E1tKUytaf9HEIZfMwnzSw7dK87gF8SNNpQG8hCddsqVFbpohl9vSnRMRPSDlWRBLoMrdGxF2l5atJnmgzzalG7yRNZDCfVKPt6rOsvG+kN30ts1M+VydpZueiI9soXYM23BwRky2O1Wm62NmYcQ/y93HyfZ4HO5Jy1zQfaxfoLeVwpHDKGcDxkh5eLpN0OCk/y1Mj4qZ52vxngQV889BNGthm1pBmrSmnAl0UEd1kfpst/SmSXk76OXwdaZaVMttpZtrW3fN2zTSnGt2KlDWxH6lG11FKo5pj/e2y+7U611xpZu8ihS0K7jdP28ppXXefx75lOk0XW9DqWs64B9mWSVKIZz7cRPLcm49VXK+5Ug53cp8XUMqaKekQUvbEZ0RKs2taYAHfPHSTBraZj5PmlXwoTDUqPadLe9qmP80NWieSXsKjSPNp7te0/zskLZT0eODpwH+3OMfngBdK2k8pRem7gYsiYnWXNpf5KvAXkg5X6rnwcuYhsjF3mtmfA4cqpcO9H6nNoFPOIjWS7iNpMWn2nHkT808Xux5oMDN18BnAayXdX9LWpHvw+SaPv1NbziI9M0vyc/NPJE8b5k45PCOlsdIcoQfmZ2iRpDeSfhVclMsPIt2Lv8m/Bk0bLOCbgeguDWzzMb5A6m51Zv6Z+mvgqV2a1DL9aRbD00i5u38REVeQUoqemkUYUujgVpJ3dzqp8e63Lew9nzTl1v+QvNIHktoAeib/nH4OKfZ8M6mBbxWzpFJtwWxpZk8FfkGK3X4L+Pw8bPs68EHg26QGxE16dMyDjtPFRsTdpDS8P8ihsQNI8fhTSY2KfyDl+X5ll7a8kvTL5Crg+6R/0J/OZXOlHH4P8NZs1+tIv+4+Srp315LaN54WEcUvubflen8t91y5U9LXu7R7pHEuFNMxSrOpnxYRc4UrNiu5O9ta4PkR8Z1B23NvIaeLNU3YAzdDidIMN9vmXwZvJsVcfzxgs/qKnC7WzIFHpJlh5TGkn/HFzD6HR8Q9gzWp7zhdrJkVh1CMMWZIcQjFGGOGlM0aQtlxxx1j+fLlm/OUxhgz9Fx88cU3RcTS5vWbVcCXL1/OqlWrNucpjTFm6JHUckSvQyjGGDOkWMCNMWZIsYAbY8yQYgE3xpghxQJujDF
DigXcGGOGFAu4McYMKRZwY0zHrLvtHs6/bL7zQZh7Cwu4MaZjzrjoGl522s8GbYbJWMCNMR3zp3qDjfUGToJXDSzgxpiOaTSScDes35XAAm6M6Zh6o/hrBa8CFnBjTMc0Imb8NYPFAm6M6ZjC87YHXg0s4MaYjpnMwj1pAa8EFnBjTMdMNWJawCuBBdwY0zH1HPuuOwZeCSzgxpiOsQdeLSzgxpiOsQdeLSzgxpiOcS+UamEBN8Z0zFQ/8MaADTGABdwYMw+mPHCHUCqBBdwY0zHTIRS74FXAAm6M6ZhpAR+wIQawgBtj5kE9R07ciFkNLODGmI6ZTidrAa8CFnBjTMe4G2G1sIAbYzrGA3mqhQXcGNMxHkpfLeYUcEmflnSjpF+X1m0v6TxJV+S/2927ZhpjqsCUB24BrwSdeOCnAIc0rTseOD8i9gLOz8vGmBHHMfBqMaeAR8R3gVuaVh8GfCZ//wxweJ/tMsZUEI/ErBbdxsB3ioh1APnvfftnkjGmqtgDrxb3eiOmpGMlrZK0av369ff26Ywx9yKe1LhadCvgN0jaGSD/vbHdhhFxUkSsiIgVS5cu7fJ0xpgq4KH01aJbAT8XODp/Pxr4Un/MMcZUmYaH0leKTroRngH8CNhb0lpJLwLeCzxZ0hXAk/OyMWbEqXsofaUYn2uDiHhem6KD+2yLMabiuBGzWngkpjGmYyzg1cICbozpGI/ErBYWcGNMxzQ8kKdSWMCNMR1TDyezqhIWcGNMx3gofbWwgBtjOsbpZKuFBdwY0zFuxKwWFnBjTMcUwj1pAa8EFnBjTMd4JGa1sIAbYzoiIkq5UAZri0lYwI0xHVGOmtgDrwYWcGNMR5QbLt2IWQ0s4MaYjih73RbwamABN8Z0RFm0HUKpBhZwY0xH1O2BVw4LuDGmI+p1C3jVsIAbYzrCHnj1sIAbYzqinP/EyayqgQXcGNMRZdF2MqtqYAE3xnRE3R545bCAG2M6olEaPu+h9NXAAm6M6QiHUKqHBdwY0xEOoVQPC7gxpiOcC6V6WMCNMR1hAa8eFnBjTEfMSGblEEol6EnAJb1W0qWSfi3pDElb9sswY0y1mJHMyh54JehawCXtArwKWBERDwPGgCP6ZZgxplp4KH316DWEMg4skjQOLAau690kY0wVaTidbOXoWsAj4lrg/cA1wDrgtoj4VvN2ko6VtErSqvXr13dvqTFmoLgRs3r0EkLZDjgMuD+wDNhK0pHN20XESRGxIiJWLF26tHtLjTEDZUYIxfpdCXoJoTwJ+ENErI+ICeAc4LH9McsYUzVmeuAeS18FehHwa4ADJC2WJOBg4LL+mGWMqRqFgC8Yk0MoFaGXGPhFwNnAz4Bf5WOd1Ce7jDEVo2i4XDBWww54NRjvZeeI+Gfgn/tkizGmwhQZCBeM1TyQpyJ4JKYxpiOmQyg1h1AqggXcGNMRRQhl4ZjcD7wiWMCNMR0x5YGP2wOvChZwY0xHTHvgFvCqYAE3xnTEZB69s9AeeGWwgBtjOqJe6kboXijVwAJujOmIIpnVwrGa08lWBAu4MaYjpjzwcdkDrwgWcGNMRzQaHolZNSzgxpiO8ECe6mEBN8Z0RJFCdqEbMSuDBdwY0xFTjZjjbsSsChZwY0xHTJZ6oUxawCuBBdwY0xGNUi8Ue+DVwAJujOmIGY2YjoFXAgu4MaYj3AuleljAjTEd0YhAgrGa08lWBQv4iBIR/H79nYM2w4wQ9UYwJjEmz4lZFSzgI8rPrrmVg//tQq688Y5Bm2JGhHoEYzVRq4lGJCfBDBYL+Ihyy10TANx698SALTGjQr2eBHy8prRsL3zgWMBHlIk8A+3EpJNWmP5QjxxCKQTcHvjAsYCPKFMCbi/J9IlGI6jVRE3KywM2yFjAR5WJnLjCHrjpF0UMfKw2vWwGiwV8RJnMHvik3STTJ+oNqGnaA3cMfPBYwEeUInRSeOLG9EqjEYzVmIqBezj94LGAjyhF6KSIhRvTK/UIxms
1N2JWiJ4EXNK2ks6W9FtJl0l6TL8MM71RhE4m7YGbPpEaMSk1YvrZGjTjPe7/IeAbEfFsSQuBxX2wyfSBInSy0R646ROTeSRm0Q/cKWUHT9cCLuk+wBOAFwBExEZgY3/MMr1ShE4mLeCmT9QjdyP0QJ7K0EsI5QHAeuC/JF0i6WRJWzVvJOlYSaskrVq/fn0PpzPzYaofuEMopk80SrlQACe0qgC9CPg48EjgYxHxCOAu4PjmjSLipIhYERErli5d2sPpzHwoYt8T7kZo+kS9UfQDtwdeFXoR8LXA2oi4KC+fTRJ0UwEKz9uNmKZfNCJSP/CaPfCq0LWAR8T1wBpJe+dVBwO/6YtVpmemQyj2wE1/mPLApwbyDNgg03MvlFcCp+ceKFcBL+zdJNMPim6EjoGbflEPZg6ldwhl4PQk4BHxc2BFn2wxfWTjZDES026S6Q+N7IHX3IhZGTwSc0SZHshjATf9YbLRSP3Ax9wPvCpYwEeU6V4ofslMf2g0mDES0yGUwWMBH1E2ekIH02em08k6hFIVLOAjynQ6Wb9kpj/UG6kb4Zg98MpgAR9RnAvF9JtGaVJjcDKrKmABH1GcC8X0m3ojGK95TswqYQEfUYrQiUdimn5RhFDciFkdLOAjSuGBO4Ri+kXDjZiVwwI+ojgXiuk3k3lW+ql84H62Bo4FfERxLhTTb4p0sh6JWR0s4CNK0XjpgTymXzT3A7dvMHgs4CPKdAjFb5npD41GGoU5lczKHvjAsYCPKA6hmH6T0sl6UuMqYQEfUdyN0PSbFEKpeUaeCmEBH1GKHCjuRmj6RaPJA3cIZfBYwEeUial0sn7JTH+oR57U2EPpK4MFfEQpGjEdAzf9ol5v6gduAR84FvARJCKm4pMWcNMvCg/ckxpXBwv4CFKeB9NekukXm05q7Gdr0FjAR5DC6x6vyR646RuNSCGUmnuhVAYL+AhSNFwuWjjGRD0I/9Q1faA5naxDKIPHAj6CFF0HFy8cAxxGMb0TETSCphl5BmyUsYCPIsWM9IsXjqdldyU0PVL4AGlGnmKdn6tBYwEfQQrB3nJB8sCLPuHGdEvhFJQbMe0YDB4L+AjSHELxzPSmVwofoCZPqVYlemqmX7EAAA0zSURBVBZwSWOSLpH0lX4YZHqn8IwcAzf9ohDrsRpIoiaPxKwC/fDAXw1c1ofjmD5RdB1clEMoG+2Bmx4pugwWeVDGarIHXgF6EnBJuwJPA07ujzmmH0y4F4rpM4W3XYRPapI98ArQqwf+QeANgF28ClEI9qJCwN3fy/RI4W0XeVDGavJAngrQtYBLejpwY0RcPMd2x0paJWnV+vXruz2dmQdFo+WiBakboVPKml4pvO1iFOaYHEKpAr144I8DnilpNXAmcJCk05o3ioiTImJFRKxYunRpD6cznVLMgzkVQnF3L9MjU42YOQZeqzmEUgW6FvCIeFNE7BoRy4EjgG9HxJF9s8x0zZQHXnQjtAdueqRwAmqlEIrbVgaP+4GPINMjMQsB94tmeqPR5IGP1eSRmBVgvB8HiYgLgAv6cSzTO4VgF90IJz0S0/RIvakXypjciFkF7IGPIFP9wB1CMX2i8LZrM3qhDNIiAxbwkWR6JGb6geUQiumVQqyLboS1mpNZVQEL+AiySS4Uu0qmRzYZiekQSiWwgI8gk00hFHcjNL0y1YhZm+5G6H7gg8cCPoJMjcRcYA/c9IfpRsy0POah9JXAAj6CbBpC8YtmemOyRTIr9wMfPBbwEaQ8Jya4G6HpneYQyphHYlYCC/gIMlFvIMEW404na/rDVAjF6WQrhQV8BJmoBwtqNRbmgKV/6ppeaU5mVXMvlEpgAR9BJuoNFoyJ8bH0snlKNdMrrdLJuh/44LGAjyCT9QbjY7Wpl23CnpLpkXqrdLJ+rgaOBXwEmWgEC8ZqSGLBmDyhg+mZ5mRWtdr0RMdmcFjAR5CJyRRCAVgwVnM/cNMzxSNU7oXiRszBYwEfQSYbMRX/Hq/
J/cBNz9Szu130A6/J/cCrgAV8BNlYb7Ag90CxB276QbMHPu5+4JXAAj6CTNYbLKhNC7hzoZhemZpSrRhK70mNK4EFfASZrAcLxrOnNCYm3NpkeqTRNJS+JncjrAIW8BFkY73BePbAF47VHAM3PVN428VzZQ+8GljAR5DJekz1Qhl3N0LTB+pTM/KQ/7oXShWwgI8gE6VGzPGaGzFN7zRazInpRszBYwEfQSYawXjRC2XcIRTTO/UWs9LbAx88FvARZLLeYGExkKcmp5M1PdM8lL4mUbdjMHAs4CPIRKkRc8FYjYlJv2imN5rTyY7bA68EFvARZLJeGonpboSmD2zigdeEm1YGjwV8BNlYb0zlAvdITNMPGpukk8X9wCuABXwEKXvgKRuhXzTTG5sks3I62UrQtYBL2k3SdyRdJulSSa/up2GmeyYbpW6E9sBNHyi87amRmM6FUgnGe9h3EjguIn4maQlwsaTzIuI3fbLNdMnGyWkB90hM0w/qLfqBuxFz8HTtgUfEuoj4Wf5+B3AZsEu/DDPdM9mIqVjleM0jMU3vTDVipsfKQ+krQl9i4JKWA48ALmpRdqykVZJWrV+/vh+nM3MwUW+wYHw6hLLRHrjpkXojqAmkci8UP1eDpmcBl7Q18D/AayLi9ubyiDgpIlZExIqlS5f2ejozBxGRZ6VPL9rCMQ/kMb1Tj5gKn4D7gVeFngRc0gKSeJ8eEef0xyTTC8UsKTMaMT0rvemRRiOmGjAhNWZGJIfBDI5eeqEI+BRwWUR8oH8mmV4ougyOl/uB+6eu6ZF6qV0FphszHUYZLL144I8DjgIOkvTz/Dm0T3aZLilGXS6Y0Q/cHrjpjXrE1ChMKAm4PfCB0nU3woj4PqA5NzSblSJcUk4n24jkKZVjmMbMh0bT81OEU9y8Mlg8EnPEKGLgUyMx89RqHsxjeqEeMZXICqbnxrQHPlgs4CPGxiYPvJjc2AJueqHemBlCKTxwp5QdLBbwEWO6F4pm/HU+FNML9UazB+4YeBWwgI8YRYNluRsh4JSypifqDTbpB57WW8AHiQV8xNiYBXx6QociBu4XzXRPo2kgTxFOcUrZwWIBHzGKUMl0CKWW19sDN93T3IupCKfYAx8sFvARY6JdCMUCbnqgHkG5F2rNIZRKYAEfMSbqM7sRLnQIxfSB5n7ghQfuEMpgsYCPGEXiqoWlgTzgXiimN+pNuVA8lL4aWMBHjCJUMj4VQkkv2kaHUEwPNMfAHUKpBhbwEWMqhDKVTtaNmKZ3mtPJTjViOoQyUCzgI0bhgS8cb27E9ItmuschlGpiAR8xJps88CKE4oE8phca0TqdrB+rwWIBHzGauxFOh1DsKZnuac6F4mRW1cACPmJMTA3kmdmI6X7gphcaDWbkQql5IE8lsICPGJObTOjggTymdzZpxPRQ+kpgAR8xinSy45ukk/WLZrpnkxCKPfBKYAEfMZrTyY5PpZO1B266J6WTnV52P/BqYAEfMZqnVHMIxfSDTZJZWcArgQV8xChmoC+6fDmdrOkHzelkPaFDNbCAjxiT9QYLxoTUlE7WHXZND7RLJ9uwBz5QLOAjxkS9MZXACsrdCP2ime5J6WQdQqkaFvARY6IeU6INntTY9IfmdLI1p5OtBBbwEWOi3pgafQmpt8BYTRZw0xP1aDOpsR+rgWIBHzEmmzxwSA2aHkpveqHRwEPpK4gFfMSYaDSmGi4LFo7VHAM3PTHZaLQZSm8XfJD0JOCSDpF0uaQrJR3fL6NM90zUYxMBHx9zCMX0Rn0TD9whlCrQtYBLGgM+CjwV2Ad4nqR9+mWY6Y7JemNG2k9Iw+rdjdD0Qvt0sv5lN0jGe9h3f+DKiLgKQNKZwGHAb/phWJkPn38F5/7iun4fdiS5bsM97L7DVjPWLRyr8ZVfrmPV6lsHZJUZdjbcvbHlQJ73f+tyPvm9qwZl1lDx7mf9BY9avn1fj9mLgO8CrCktrwUe3byRpGOBYwF23333rk6
0dMkW7LXT1l3t++fGXjttzV/tfd8Z617yxAfw46tuHpBFZhR40E5LeMbDl00t77RkS17w2OXceMcfB2jVcLFowVjfj6noshVZ0nOAp0TEi/PyUcD+EfHKdvusWLEiVq1a1dX5jDHmzxVJF0fEiub1vTRirgV2Ky3vCjjOYYwxm4leBPynwF6S7i9pIXAEcG5/zDLGGDMXXcfAI2JS0iuAbwJjwKcj4tK+WWaMMWZWemnEJCK+BnytT7YYY4yZBx6JaYwxQ4oF3BhjhhQLuDHGDCkWcGOMGVK6HsjT1cmk9cDVXe6+I3BTH83Z3Ay7/TD8dbD9g2fY6zAo+/eIiKXNKzergPeCpFWtRiINC8NuPwx/HWz/4Bn2OlTNfodQjDFmSLGAG2PMkDJMAn7SoA3okWG3H4a/DrZ/8Ax7HSpl/9DEwI0xxsxkmDxwY4wxJSzgxhgzpAyFgA/b5MmSdpP0HUmXSbpU0qvz+u0lnSfpivx3u0HbOhuSxiRdIukreXlo7Je0raSzJf0234fHDJP9AJJem5+fX0s6Q9KWVa6DpE9LulHSr0vr2tor6U35nb5c0lMGY/VM2tTh/+fn6JeSviBp21LZQOtQeQEf0smTJ4HjIuIhwAHAy7PNxwPnR8RewPl5ucq8GristDxM9n8I+EZEPBh4OKkeQ2O/pF2AVwErIuJhpJTNR1DtOpwCHNK0rqW9+X04Anho3uc/87s+aE5h0zqcBzwsIvYFfge8CapRh8oLOKXJkyNiI1BMnlxZImJdRPwsf7+DJB67kOz+TN7sM8Dhg7FwbiTtCjwNOLm0eijsl3Qf4AnApwAiYmNEbGBI7C8xDiySNA4sJs14Vdk6RMR3gVuaVrez9zDgzIj4U0T8AbiS9K4PlFZ1iIhvRcRkXvwxafYxqEAdhkHAW02evMuAbJk3kpYDjwAuAnaKiHWQRB64b/s9B84HgTcAjdK6YbH/AcB64L9yCOhkSVsxPPYTEdcC7weuAdYBt0XEtxiiOmTa2Tus7/UxwNfz94HXYRgEXC3WDUXfR0lbA/8DvCYibh+0PZ0i6enAjRFx8aBt6ZJx4JHAxyLiEcBdVCvUMCc5VnwYcH9gGbCVpCMHa1VfGbr3WtJbSOHR04tVLTbbrHUYBgEfysmTJS0giffpEXFOXn2DpJ1z+c7AjYOybw4eBzxT0mpSyOogSacxPPavBdZGxEV5+WySoA+L/QBPAv4QEesjYgI4B3gsw1UHaG/vUL3Xko4Gng48P6YHzwy8DsMg4EM3ebIkkeKvl0XEB0pF5wJH5+9HA1/a3LZ1QkS8KSJ2jYjlpOv97Yg4kuGx/3pgjaS986qDgd8wJPZnrgEOkLQ4P08Hk9pShqkO0N7ec4EjJG0h6f7AXsBPBmDfnEg6BHgj8MyIuLtUNPg6RETlP8ChpNbf3wNvGbQ9Hdh7IOmn1C+Bn+fPocAOpJb4K/Lf7Qdtawd1WQl8JX8fGvuB/YBV+R58EdhumOzPdXgH8Fvg18CpwBZVrgNwBileP0HyTl80m73AW/I7fTnw1EHbP0sdriTFuot3+eNVqYOH0htjzJAyDCEUY4wxLbCAG2PMkGIBN8aYIcUCbowxQ4oF3BhjhhQLuDHGDCkWcGOMGVL+D73nz6lhMpt8AAAAAElFTkSuQmCC\n", "text/plain": ["
"]}, "metadata": {"needs_background": "light"}, "output_type": "display_data"}], "source": ["plt.plot(shaked[:, n-1] - shaked[:, 0])\n", "plt.title(\"Observed differences on a Boston dataset\\nwith a DecisionTreeRegressor\"\n", " \"\\nwhen exploring rounding to float32\");"]}, {"cell_type": "markdown", "metadata": {}, "source": ["That's consistent. This function is a way to estimate the error due to the conversion into float32 without using the expected values."]}, {"cell_type": "markdown", "metadata": {}, "source": ["## Runtime supporting float64 for DecisionTreeRegressor\n", "\n", "We proved, in a statistical way, that the conversion to float32 introduces discrepancies. But if the runtime supports float64 and not only float32, there should be no discrepancy at all. Let's verify that the error disappears when the runtime supports an operator handling float64, which is the case for the Python runtime for *DecisionTreeRegressor*."]}, {"cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": ["model_onnx64 = to_onnx(clr, X_train, rewrite_ops=True)"]}, {"cell_type": "markdown", "metadata": {}, "source": ["The option **rewrite_ops** is needed to tell the function that the operator we need is not (yet) supported by the official ONNX specification. [TreeEnsembleRegressor](https://github.com/onnx/onnx/blob/master/docs/Operators-ml.md#ai.onnx.ml.TreeEnsembleRegressor) only allows float coefficients and we need double coefficients. That's why the function rewrites the converter of this operator and selects the appropriate runtime operator **RuntimeTreeEnsembleRegressorDouble**. 
It works as if the ONNX specification had been extended with an operator *TreeEnsembleRegressorDouble* that behaves the same way but uses doubles."]}, {"cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [], "source": ["oinf64 = OnnxInference(model_onnx64)\n", "opred64 = oinf64.run({'X': X_test})['variable']"]}, {"cell_type": "markdown", "metadata": {}, "source": ["The runtime operator is accessible with the following path:"]}, {"cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [{"data": {"text/plain": [""]}, "execution_count": 33, "metadata": {}, "output_type": "execute_result"}], "source": ["oinf64.sequence_[0].ops_"]}, {"cell_type": "markdown", "metadata": {}, "source": ["It differs from the one used by the float32 model:"]}, {"cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [{"data": {"text/plain": [""]}, "execution_count": 34, "metadata": {}, "output_type": "execute_result"}], "source": ["oinf.sequence_[0].ops_"]}, {"cell_type": "markdown", "metadata": {}, "source": ["The highest absolute difference is now zero."]}, {"cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [{"data": {"text/plain": ["0.0"]}, "execution_count": 35, "metadata": {}, "output_type": "execute_result"}], "source": ["numpy.max(numpy.abs(ypred - opred64))"]}, {"cell_type": "markdown", "metadata": {}, "source": ["## Interpretation\n", "\n", "We may wonder whether the ONNX specifications should be extended to support double for every operator. However, the fact that the model predicts a very different value for one observation indicates that this prediction cannot be trusted: a very small modification of the input introduces a huge change in the output. I would use a different model. 
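To make this sensitivity concrete, here is a minimal sketch, independent of scikit-learn and ONNX, with a single hypothetical split `x <= threshold`. The threshold, the input and both leaf values are invented for the illustration, and casting the input and the threshold to the same dtype before comparing is a simplifying assumption, not necessarily what every runtime does. An input lying just above the threshold in float64 becomes equal to it once both are rounded to float32, so the prediction jumps from one leaf to the other.

```python
import numpy

# Hypothetical one-split tree: go left if x <= threshold, else go right.
# Threshold and leaf values are invented for this illustration.
threshold = 0.1
left_value, right_value = 10.0, 50.0


def predict_one(x, dtype):
    # Assumption: input and threshold are cast to the same precision,
    # then the usual 'less or equal' branching rule is applied.
    return left_value if dtype(x) <= dtype(threshold) else right_value


# Smallest float64 strictly greater than the threshold.
x = numpy.nextafter(threshold, 1.0)

pred64 = predict_one(x, numpy.float64)  # x > threshold in float64 -> right leaf
pred32 = predict_one(x, numpy.float32)  # both round to the same float32 -> left leaf
print(pred64, pred32)  # prints 50.0 10.0
```

Repeated over many splits and many trees, the same mechanism is a plausible explanation for the few observations where the float32 predictions differ strongly from the scikit-learn ones.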
We may also wonder which prediction is the best one compared to the expected value..."]}, {"cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [{"data": {"text/plain": ["26"]}, "execution_count": 36, "metadata": {}, "output_type": "execute_result"}], "source": ["i = numpy.argmax(numpy.abs(ypred - opred))\n", "i"]}, {"cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [{"data": {"text/plain": ["(50.0, 43.1, 43.1, 43.1)"]}, "execution_count": 37, "metadata": {}, "output_type": "execute_result"}], "source": ["y_test[i], ypred[i], opred[i], opred64[i]"]}, {"cell_type": "markdown", "metadata": {}, "source": ["In the end, on this kind of example, it is only a matter of luck."]}], "metadata": {"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}, "language_info": {"codemirror_mode": {"name": "ipython", "version": 3}, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.2"}}, "nbformat": 4, "nbformat_minor": 2}