Coverage for mlprodict/onnx_tools/onnx_grammar/onnx_translation.py: 100%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief One class which visits a syntax tree.
4"""
5import inspect
6import ast
7from textwrap import dedent
8import numpy
9from scipy.spatial.distance import squareform, pdist
10from .node_visitor_translator import CodeNodeVisitor
def py_make_float_array(cst, op_version=None):
    """
    Wraps a constant into a one-element :epkg:`numpy` array
    of type ``float32``.

    @param      cst         constant
    @param      op_version  unused (the other ``py_*`` helpers accept it too)
    @return                 ``float32`` array holding *cst* as its only element

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        from mlprodict.onnx_tools.onnx_grammar.onnx_translation import py_make_float_array
        print(py_make_float_array(5.5))
    """
    values = (cst, )
    return numpy.array(values, dtype=numpy.float32)
def py_pow(x, p, op_version=None):
    """
    Implements the python power operator ``**`` as a function.

    @param      x           float (or anything supporting ``**``)
    @param      p           exponent
    @param      op_version  unused
    @return                 :math:`x^p`
    """
    # two-argument builtin pow dispatches to the same __pow__ as ``x ** p``
    return pow(x, p)
def py_mul(*x, op_version=None):
    """
    Function for python operator ``*``.

    Multiplies all positional arguments from left to right.
    No argument is modified. The running product is rebound with
    ``product = product * factor`` rather than updated in place with
    ``*=``. For :epkg:`numpy` arrays, ``*=`` would overwrite the caller's
    first argument. It would also fail when the product cannot be cast
    back into the dtype of that first argument.

    @param      x           floats (at least one)
    @param      op_version  unused
    @return                 ``x[0] * x[1] * ...``
    @raises     IndexError  if no positional argument is given
    """
    product = x[0]
    for factor in x[1:]:
        product = product * factor
    return product
def py_opp(x, op_version=None):
    """
    Implements the python unary minus operator ``-`` as a function.

    @param      x           floats
    @param      op_version  unused
    @return                 the negation ``-x``
    """
    negated = -x
    return negated
def squareform_pdist(X, metric='sqeuclidean', op_version=None):
    """
    Replacements for `squareform
    <http://scipy.github.io/devdocs/generated/scipy.spatial.distance.squareform.html>`_
    and `pdist
    <http://scipy.github.io/devdocs/generated/scipy.spatial.distance.pdist.html>`_.

    @param      X           observations, one per row, given to *pdist*
    @param      metric      distance name forwarded to *pdist*
    @param      op_version  unused
    @return                 square matrix of pairwise distances
    """
    # pdist returns the condensed (upper-triangle) vector,
    # squareform expands it back to a full square matrix
    condensed = pdist(X, metric=metric)
    return squareform(condensed)
def get_default_context():
    """
    Returns a default context useful for most of the conversion
    from a function using :epkg:`numpy` into :epkg:`ONNX`.

    The context contains:

    * the helpers ``py_pow``, ``py_make_float_array``, ``py_mul`` and
      ``py_opp`` defined in this module,
    * ``cdist`` and ``squareform_pdist``, each mapped to its own name,
    * every whitelisted :epkg:`numpy` function, registered twice,
      as ``numpy.<name>`` and as ``np.<name>``.

    @return     dictionary
    """
    context = {'py_pow': py_pow, 'py_make_float_array': py_make_float_array,
               'py_mul': py_mul, 'py_opp': py_opp,
               'cdist': 'cdist', 'squareform_pdist': 'squareform_pdist'}
    # Python concatenates adjacent string literals without any separator.
    # Every fragment but the last must therefore end with a space.
    # Otherwise 'divide' and 'equal' merge into the unknown token
    # 'divideequal', and 'mod' and 'multiply' into 'modmultiply'.
    allow = set(('abs add arccos arccosh arcsin arcsinh arctan arctanh ceil cos cosh divide '
                 'equal exp floor greater invert less log matmul maximum minimum mod '
                 'multiply power sign sin sinh sqrt square subtract tan tanh transpose').split())
    for k, v in numpy.__dict__.items():
        if k not in allow:
            continue
        context['numpy.%s' % k] = v
        context['np.%s' % k] = v
    return context
def get_default_context_cpl():
    """
    Returns a default useful context to compile the converter
    returned by @see fn translate_fct2onnx.

    The context holds the ``py_*`` helpers of this module and :epkg:`numpy`.
    It also holds ``onnx_squareform_pdist`` and ``onnx_cdist`` when the
    installed :epkg:`skl2onnx` provides them, and every ``Onnx*`` class
    from ``skl2onnx.algebra.onnx_ops`` that derives from *OnnxOperator*.

    @return     dictionary
    @raises     RuntimeError    if an ``Onnx*`` entry of ``onnx_ops`` is
                                neither a class nor a function
    """
    ctx = {
        'py_make_float_array': py_make_float_array,
        'py_pow': py_pow,
        'py_mul': py_mul,
        'py_opp': py_opp,
        'numpy': numpy,
    }

    try:
        from skl2onnx.algebra.complex_functions import (  # delayed
            onnx_squareform_pdist, onnx_cdist)
    except ImportError:  # pragma: no cover
        # Too old version for skl2onnx.
        pass
    else:
        ctx['onnx_squareform_pdist'] = onnx_squareform_pdist
        ctx['onnx_cdist'] = onnx_cdist

    from skl2onnx.algebra import onnx_ops  # delayed
    from skl2onnx.algebra.onnx_operator import OnnxOperator  # delayed
    for name, obj in onnx_ops.__dict__.items():
        if not name.startswith("Onnx"):
            continue
        try:
            # issubclass raises TypeError when obj is not a class
            is_operator = issubclass(obj, OnnxOperator)
        except TypeError as e:
            if inspect.isfunction(obj):
                continue
            raise RuntimeError(  # pragma: no cover
                "Issue with {}={} (type={})".format(name, obj, type(obj))) from e
        if is_operator:
            ctx[name] = obj
    return ctx
def translate_fct2onnx(fct, context=None, cpl=False,
                       context_cpl=None, output_names=None,
                       dtype=numpy.float32,
                       verbose=0, fLOG=None):
    """
    Translates a function into :epkg:`ONNX`. The code it produces
    is using classes *OnnxAbs*, *OnnxAdd*, ...

    @param      fct             function to convert, or its source code
                                as a string
    @param      context         context of the function to convert
                                something like ``{'numpy.transpose': numpy.transpose}``,
                                if *context* is None, it receives a default value
                                returned by @see fn get_default_context
    @param      cpl             compile the function after it was
                                created
    @param      context_cpl     context used at compiling time
                                if *context_cpl* is None, it receives a default value
                                returned by @see fn get_default_context_cpl
    @param      output_names    names of the output in the :epkg:`ONNX` graph
    @param      dtype           :epkg:`numpy` float type used to produce the model
                                (not referenced by the body of this function)
    @param      verbose         integer, display more information
                                (only used when *cpl* is True)
    @param      fLOG            logging function
    @return                     code or compiled code
    @raises     TypeError       if *fct* is neither a string nor a callable
    @raises     ValueError      if *cpl* is True and *fct* is a string
                                which defines no function

    .. exref::
        :title: Convert a function into ONNX code

        The following code parses a python function and returns
        another python function which produces an :epkg:`ONNX`
        graph if executed.

        .. runpython::
            :showcode:
            :warningout: DeprecationWarning
            :process:
            :store_in_file: fct2onnx2.py

            import numpy
            from mlprodict.onnx_tools.onnx_grammar import translate_fct2onnx

            def trs(x, y):
                z = x + numpy.transpose(y, axes=[1, 0])
                return x * z

            onnx_code = translate_fct2onnx(
                trs, context={'numpy.transpose': numpy.transpose})
            print(onnx_code)

    Next example goes further and compiles the outcome.

    .. exref::
        :title: Convert a function into ONNX code and run

        The following code parses a python function and returns
        another python function which produces an :epkg:`ONNX`
        graph if executed. The example executes the function,
        creates an :epkg:`ONNX` then uses @see cl OnnxInference
        to compute *predictions*. Finally it compares
        them to the original.

        .. runpython::
            :showcode:
            :warningout: DeprecationWarning
            :process:
            :store_in_file: fct2onnx3.py

            import numpy
            from mlprodict.onnx_tools.onnx_grammar import translate_fct2onnx
            from mlprodict.plotting.text_plot import onnx_simple_text_plot
            from mlprodict.onnxrt import OnnxInference
            from mlprodict.npy.xop import loadop


            OnnxAdd, OnnxTranspose, OnnxMul, OnnxIdentity = loadop(
                'Add', 'Transpose', 'Mul', 'Identity')


            ctx = {'OnnxAdd': OnnxAdd,
                   'OnnxTranspose': OnnxTranspose,
                   'OnnxMul': OnnxMul,
                   'OnnxIdentity': OnnxIdentity}

            def trs(x, y):
                z = x + numpy.transpose(y, axes=[1, 0])
                return x * z

            inputs = {'x': numpy.array([[1, 2]], dtype=numpy.float32),
                      'y': numpy.array([[-0.3, 0.4]], dtype=numpy.float32).T}

            original = trs(inputs['x'], inputs['y'])

            print('original output:', original)

            onnx_fct = translate_fct2onnx(
                trs, context={'numpy.transpose': numpy.transpose},
                cpl=True, context_cpl=ctx, output_names=['Z'])

            onnx_code = onnx_fct('x', 'y', op_version=12)

            onnx_g = onnx_code.to_onnx(inputs, target_opset=12)
            print("ONNX model")
            print(onnx_simple_text_plot(onnx_g))

            oinf = OnnxInference(onnx_g)
            res = oinf.run(inputs)

            print('-----------')
            print("ONNX inference:", res['Z'])

    The function to be converted may include python functions
    which must not be converted. In that case, their names
    must be prefixed by ``py_``. Executing the function this one
    builds may produce the following error::

        TypeError: Parameter to MergeFrom() must be instance of same class:
        expected onnx.TensorProto got onnx.AttributeProto.

    It indicates that the constants in the code merge multiple types,
    usually floats and tensors of floats. Floats should be converted
    using the following function::

        def py_make_float_array(cst):
            return numpy.array([cst], dtype=numpy.float32)

    The function replaces empty contexts by default values which
    cover many :epkg:`numpy` functions. The tutorial
    :ref:`l-onnx-tutorial` gives an example of how it can be used
    on a more complex function.
    """
    def compile_code(name, code, context=None):
        """
        Compiles a python function with the given
        context.

        @param      name        function name
        @param      code        python code
        @param      context     context used at compilation
        @return                 compiled function
        """
        if context is None:
            context = {}  # pragma: no cover
        try:
            obj = compile(code, "", "exec")
        except SyntaxError as e:  # pragma: no cover
            raise SyntaxError("Unable to compile\n{}".format(code)) from e
        context_g = context.copy()
        context_l = context.copy()
        exec(obj, context_g, context_l)  # pylint: disable=W0122
        return context_l[name]

    if isinstance(fct, str):
        code = fct
    elif callable(fct):
        code = inspect.getsource(fct)
    else:
        raise TypeError(  # pragma: no cover
            "Unable to guess code from type {}.".format(type(fct)))
    node = ast.parse(dedent(code))
    v = CodeNodeVisitor()
    v.visit(node)
    if context is None:
        context = get_default_context()
    onnx_code = v.export(context=context,
                         output_names=output_names)
    if not cpl:
        return onnx_code
    if verbose > 0 and fLOG is not None:  # pragma: no cover
        fLOG('[translate_fct2onnx] python code')
        fLOG(code)
        fLOG('[translate_fct2onnx] ONNX code')
        fLOG(onnx_code)
    if context_cpl is None:
        context_cpl = get_default_context_cpl()
    if 'numpy' not in context_cpl:
        # copy so that the caller's dictionary is not modified
        context_cpl = context_cpl.copy()
        context_cpl['numpy'] = numpy

    if isinstance(fct, str):
        # A string has no ``__name__``. Take the name of the first function
        # defined in the parsed source instead. The callable branch makes
        # the same assumption through ``fct.__name__``: the generated code
        # defines a function with the original name.
        fct_names = [n.name for n in node.body
                     if isinstance(n, ast.FunctionDef)]
        if not fct_names:
            raise ValueError(
                "Unable to find a function definition in the given code.")
        fct_name = fct_names[0]
    else:
        fct_name = fct.__name__
    return compile_code(fct_name, onnx_code, context_cpl)