Coverage for mlprodict/onnx_tools/onnx_grammar/onnx_translator.py: 97%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief One class which visits a syntax tree.
4"""
5import pprint
6import numpy
class CodeTranslator:
    """
    Base class for translators turning a parsed Python function
    into something else. Subclasses must implement
    methods *visit*, *depart* and *export*.
    """

    # Message raised by every hook a subclass is expected to override.
    _NOT_OVERWRITTEN = "This function should be overwritten."

    def __init__(self, visitor):
        """
        :param visitor: :class:`CodeNodeVisitor
            <mlprodict.onnx_tools.onnx_grammar.node_visitor_translator>`,
            stored as ``self._visitor``
        """
        self._visitor = visitor

    def export(self, context=None, **kwargs):
        """
        Exports the parsed :epkg:`python` code
        into something. Must be overwritten.

        @param      context     left to the subclass
        @param      kwargs      left to the subclass
        """
        message = CodeTranslator._NOT_OVERWRITTEN
        raise NotImplementedError(message)  # pragma: no cover

    def visit(self, node, info):
        """
        Visits a node. Must be overwritten.

        @param      node        visited node
        @param      info        info extracted by the visitor
        """
        message = CodeTranslator._NOT_OVERWRITTEN
        raise NotImplementedError(message)  # pragma: no cover

    def depart(self, node, info):
        """
        Leaves a node. Must be overwritten.

        @param      node        visited node
        @param      info        info extracted by the visitor
        """
        message = CodeTranslator._NOT_OVERWRITTEN
        raise NotImplementedError(message)  # pragma: no cover
class OnnxTranslator(CodeTranslator):
    """
    Class which converts a Python function into
    an :epkg:`ONNX` function. It implements
    methods *visit* and *depart* which fill an internal stack
    while the syntax tree is walked, and method *export*
    which writes python code from what was collected.
    """
    # name of a python binary operator -> suffix of the Onnx* class
    _binary_operators = {
        'Add': 'Add', 'Div': 'Div',
        'Mult': 'Mul', 'Sub': 'Sub',
        'Pow': 'Pow', 'MatMult': 'MatMul',
    }

    # name of a python unary operator -> suffix of the Onnx* class
    _unary_operators = {
        'Sub': 'Neg',
    }

    # numpy function name -> ONNX operator; a lower case value is
    # written as a call to a function prefixed by ``onnx_``
    _numpy2onnx_op = {
        'absolute': 'Abs',
        'cos': 'Cos',
        'exp': 'Exp',
        'power': 'Pow',
        'transpose': 'Transpose',
        'sin': 'Sin',
        # complex function
        'inner': 'inner',
    }

    # ONNX operator -> {numpy keyword name: ONNX attribute name}
    _parameter_mapping = {
        'Transpose': {'axes': 'perm'}
    }

    class Parameter:
        """
        Holds parameter information.
        """

        def __init__(self, name, value=('#NODEFAULT#', ), annotation=None):
            """
            @param      name        parameter name
            @param      value       parameter value, the default tuple
                                    ``('#NODEFAULT#', )`` means no value
            @param      annotation  annotation, only stored
            """
            self.name = name
            self.value = value
            self.annotation = annotation

        @staticmethod
        def format_value(value):
            """
            Returns a formatted value in python code.

            @param      value       str, list, tuple or anything
                                    supported by ``str``
            @return                 string, or None if *value* is the
                                    marker ``('#NODEFAULT#', )``
            """
            if isinstance(value, str):
                # Backslashes must be escaped first, otherwise the
                # backslash added in front of a quote gets doubled and
                # the quote ends the literal too early.
                escaped = value.replace('\\', '\\\\').replace('"', '\\"')
                return '"{}"'.format(escaped)
            if isinstance(value, list):
                return "[{}]".format(
                    ", ".join(map(OnnxTranslator.Parameter.format_value, value)))
            if isinstance(value, tuple):
                if value == ('#NODEFAULT#', ):
                    return None
                items = ", ".join(
                    map(OnnxTranslator.Parameter.format_value, value))
                if len(value) == 1:
                    # ``(x)`` is not a tuple in python, ``(x,)`` is.
                    return "({},)".format(items)
                return "({})".format(items)
            return str(value)

        @property
        def formatted_value(self):
            """
            Returns a formatted value in python code.
            """
            return OnnxTranslator.Parameter.format_value(self.value)

        def __str__(self):
            """
            Into python syntax: ``name`` or ``name=value``.
            """
            rows = [self.name]
            if self.value != ('#NODEFAULT#', ):
                rows.append('=')
                rows.append(self.formatted_value)
            return ''.join(rows)

    def __init__(self, visitor):
        """
        :param visitor: :class:`CodeNodeVisitor
            <mlprodict.onnx_tools.onnx_grammar.node_visitor_translator>`
        """
        CodeTranslator.__init__(self, visitor)
        # list of tuples (kind, dict) for the nodes being visited
        self._stack = []
        # tuple ('FunctionDef', dict) once a function was fully parsed
        self._code_fct = None

    def _is_stacked(self, name):
        """
        Tells if the stack contains an item of kind *name*.
        """
        return any(line[0] == name for line in self._stack)

    def _get_last(self, name, info=None):
        """
        Returns the last stacked item after checking its kind.

        @param      name        expected kind (str) or
                                allowed kinds (tuple)
        @param      info        only used in the error message
        @return                 tuple (kind, dict)

        Raises *RuntimeError* if the stack is empty or
        if the last item has another kind.
        """
        if not self._stack:
            raise RuntimeError("Stack is empty.")  # pragma: no cover
        last = self._stack[-1]
        if ((isinstance(name, str) and last[0] != name) or
                (isinstance(name, tuple) and last[0] not in name)):
            raise RuntimeError(  # pragma: no cover
                "Last item is not '{}'\n{}\n---\n{}".format(
                    name, pprint.pformat(self._stack),
                    pprint.pformat(info) if info else ""))
        return last

    def make_msg(self, info):
        """
        Make a message with line and column information.

        @param      info        a dictionary with key ``'node'`` or keys
                                ``'lineno'``, ``'col_offset'``, or an object
                                with these attributes
        @return                 string, ``'?'`` replaces a missing value
        """
        lineno = '?'
        col_offset = '?'
        if isinstance(info, dict):
            if 'node' in info:
                node = info['node']
                lineno = node.lineno
                col_offset = node.col_offset
            else:
                lineno = info.get('lineno', lineno)
                col_offset = info.get('col_offset', col_offset)
        else:
            lineno = getattr(info, 'lineno', lineno)
            col_offset = getattr(info, 'col_offset', col_offset)

        return "line {}, col {}".format(lineno, col_offset)

    def export(self, context=None, format='code',  # pylint: disable=W0221
               output_names=None):
        """
        Returns an :epkg:`ONNX` graph or a piece
        of code which could generate the graph.

        @param      context         function used in the function code,
                                    dictionary ``{name: function}``
        @param      format          ``'code'``, this implementation does not
                                    read this parameter
        @param      output_names    add code in the final function
                                    to overwrite the names of the
                                    outputs in the :epkg:`ONNX` graph
        @return                     string or :epkg:`onnx` graph

        This method is used in function @see fn translate_fct2onnx.
        An example of code can be found there.
        Raises *RuntimeError* if no function was parsed or
        if the parsed code cannot be converted.
        """
        if self._code_fct is None:
            raise RuntimeError(  # pragma: no cover
                "No python code was parsed.")
        if context is None:
            context = {}

        def find_onnx_correspondance(fct, info):
            # Returns an ONNX operator name (str) or the function
            # itself if its name starts with 'py_'.
            # A callable without __name__ (a functools.partial for example)
            # ends up in the final RuntimeError and not in AttributeError.
            fct_name = getattr(fct, '__name__', '')
            if isinstance(fct, numpy.ufunc):
                name = fct.__name__
            elif callable(fct) and getattr(fct, '__module__', '') in (
                    'numpy', 'numpy.core.fromnumeric'):
                name = fct_name
            elif callable(fct) and fct_name.startswith("py_"):
                return fct
            else:
                name = None
            if name is not None and name not in OnnxTranslator._numpy2onnx_op:
                raise RuntimeError(  # pragma: no cover
                    "Unable to find a correspondance to '{}' at {} in \n{}".format(
                        name, self.make_msg(info),
                        "\n".join(sorted(OnnxTranslator._numpy2onnx_op))))
            if name is not None:
                return OnnxTranslator._numpy2onnx_op[name]
            if isinstance(fct, str):
                return fct
            raise RuntimeError(  # pragma: no cover
                "Unable to find a correspondance for function name '{}' in module '{}', "
                "'{}' (type {}) at {}.".format(
                    name, getattr(fct, '__module__', ''),
                    fct, type(fct), self.make_msg(info)))

        def write_arguments(stack_fct_used, rows, operands, indent,
                            parameter_mapping=None):
            # Writes every operand of a call followed by a comma,
            # then op_version=op_version and the closing bracket.
            for operand in operands:
                sub = write_expression(
                    stack_fct_used, operand, indent + 1, parameter_mapping)
                if any('op_version="op_version"' in s for s in sub):
                    # op_version is always added below
                    continue
                rows.extend(sub)
                rows[-1] += ","
            rows.append('{}op_version=op_version'.format(
                " " * (indent + 1) * 4))
            rows.append('{})'.format(" " * indent * 4))

        def write_expression(stack_fct_used, expr, indent, parameter_mapping=None):
            pad = " " * indent * 4
            if isinstance(expr, str):
                # an argument
                return ['{}{}'.format(pad, expr)]
            if isinstance(expr, (int, float, numpy.integer, numpy.floating)):
                # a constant, numpy scalars are accepted the same way
                # _post_process accepts them
                return ['{}{}'.format(pad, expr)]
            if isinstance(expr, OnnxTranslator.Parameter):
                if parameter_mapping is None:
                    name = expr.name
                else:
                    name = parameter_mapping.get(expr.name, expr.name)
                return ["{}{}={}".format(pad, name, expr.formatted_value)]
            rows = []
            if isinstance(expr, tuple):
                expr = [expr]
            for op, args in expr:
                if op == 'BinOp':
                    onnx_name = OnnxTranslator._binary_operators[args["op"]]
                    rows.append('{}Onnx{}('.format(pad, onnx_name))
                    write_arguments(stack_fct_used, rows, args["args"], indent)
                elif op == 'UnaryOp':
                    onnx_name = OnnxTranslator._unary_operators[args["op"]]
                    rows.append('{}Onnx{}('.format(pad, onnx_name))
                    write_arguments(stack_fct_used, rows, args["args"], indent)
                elif op == 'Call':
                    name = args['name']
                    if name.startswith("onnx_"):
                        raise RuntimeError("The code must not use a function prefixed by 'onnx_' (%s). "
                                           "It indicates that function manipulate ONNX node and "
                                           "the fonction to convert must only deal with arrays." % name)
                    if name not in context:
                        raise RuntimeError(
                            "Unable to find function '{}' at {} in context\n{}\n--\n{}".format(
                                name, self.make_msg(args),
                                '\n'.join(sorted(context)),
                                pprint.pformat(args)))
                    op_conv = find_onnx_correspondance(context[name], args)
                    if callable(op_conv) and op_conv.__name__.startswith('py_'):
                        rows.append('{}{}('.format(pad, op_conv.__name__))
                    elif callable(op_conv) and op_conv.__name__.startswith('onnx_'):
                        stack_fct_used.append(op_conv.__name__)
                        # the name is written, not the function object
                        rows.append('{}{}('.format(pad, op_conv.__name__))
                    else:
                        prefix = "onnx_" if 'a' <= op_conv[0] <= 'z' else 'Onnx'
                        if prefix == "onnx_":
                            # the header of the generated function defines
                            # a wrapper _onnx_<name> for every function
                            # registered here
                            stack_fct_used.append(
                                "{}{}".format(prefix, op_conv))
                            prefix = '_' + prefix
                        rows.append('{}{}{}('.format(pad, prefix, op_conv))

                    # the first argument is the function name, it is skipped
                    write_arguments(
                        stack_fct_used, rows, args["args"][1:], indent,
                        OnnxTranslator._parameter_mapping.get(op_conv, None))
                else:
                    raise RuntimeError(  # pragma: no cover
                        "Unable to interpret '{}'.".format(expr))
            return rows

        def write_function(stack_fct_used, to_replaces, node):
            rows = []
            name, args = node
            if name != 'FunctionDef':
                raise RuntimeError(  # pragma: no cover
                    "The code being translated should be a single function not "
                    "'{}' at {}.".format(name, self.make_msg(args)))
            list_args = list(map(str, args['args']))
            if all('dtype=' not in s for s in list_args):
                list_args.append("dtype=numpy.float32")
            if all('op_version=' not in s for s in list_args):
                list_args.append("op_version=None")
            fct_name = args['name']
            rows.append("def {}({}):".format(
                fct_name, ', '.join(list_args)))
            indent = 1
            pad = " " * (indent * 4)

            # placeholder replaced by the header once the body is written
            to_replace = "# __HEADER__{}".format(id(node))
            to_replaces.append(to_replace)
            rows.append("{}{}".format(pad, to_replace))

            code = args['code']
            for op, args in code:
                if op == "Assign":
                    name = args['name']
                    args = args["args"]
                    rows.append("{}{} = (".format(pad, name))
                    rows.extend(write_expression(
                        stack_fct_used, args, indent + 1))
                    rows.append("{})".format(pad))
                elif op == "Return":
                    args = args["code"]
                    if output_names is None:
                        rows.append("{}return (".format(pad))
                        rows.extend(write_expression(
                            stack_fct_used, args, indent + 1))
                        rows.append("{})".format(pad))
                    else:
                        rows.append("{}return OnnxIdentity(".format(pad))
                        subrows = write_expression(
                            stack_fct_used, args, indent + 1)
                        subrows[-1] += ","
                        rows.extend(subrows)
                        rows.append("{}output_names={},".format(
                            " " * ((indent + 1) * 4), str(output_names)))
                        rows.append("{}op_version=op_version".format(
                            " " * ((indent + 1) * 4)))
                        rows.append("{})".format(pad))
                else:
                    raise RuntimeError(  # pragma: no cover
                        "Unable to process operator '{}' at {}. "
                        "Make sure it is either an affectation, "
                        "either a return.".format(op, self.make_msg(args)))
            return rows

        stack_fct_used = []
        to_replaces = []
        rows = write_function(stack_fct_used, to_replaces, self._code_fct)

        # handling dtype parameter
        if len(to_replaces) != 1:
            raise RuntimeError(  # pragma: no cover
                "The following code misses a placeholder:\n{}".format(
                    "\n".join(rows)))
        index = -1
        for i, row in enumerate(rows):
            if to_replaces[0] in row:
                index = i
                break

        # One wrapper per onnx_* function used in the body,
        # it forwards dtype and op_version. The line is indented like
        # the body of the generated function (4 spaces).
        header = []
        for fct in stack_fct_used:
            header.append(
                "{1}_{0} = lambda *args, op_version=op_version, **kwargs: {0}(*args, dtype=dtype, "
                "op_version=op_version, **kwargs)".format(fct, " " * 4))
        if len(header) > 0:
            header.append('')
        # the placeholder is removed even if the header is empty
        rows[index:index + 1] = header

        return "\n".join(rows)

    def visit(self, node, info):
        """
        Visits a node: checks the node is allowed where it
        appears and stacks a new item for the kinds
        which own children.

        @param      node        visited node
        @param      info        info extracted by the visitor
        """
        if 'type' not in info:
            return

        kind = info['type']
        if kind == "Module":
            return
        if kind == "FunctionDef":
            if self._is_stacked('FunctionDef'):
                raise RuntimeError("Nested functions are not allowed at {}.".format(
                    self.make_msg(node)))
            self._stack.append(
                ('FunctionDef', {'args': [], 'code': [], 'name': info['name'], 'default': [],
                                 'lineno': node.lineno, 'col_offset': node.col_offset}))
            return
        if kind == "arguments":
            # only checks the parent
            self._get_last('FunctionDef')
            return
        if kind == "arg":
            return
        if kind == "Assign":
            self._stack.append(
                ('Assign', {'args': [], 'lineno': node.lineno, 'col_offset': node.col_offset}))
            return
        if kind in ('Name', 'Cst'):
            self._get_last(
                ('Assign', 'BinOp', 'Call', 'Return', 'FunctionDef', 'keyword', 'UnaryOp'))
            return
        if kind == 'BinOp':
            self._stack.append(
                ('BinOp', {'args': [], 'lineno': node.lineno, 'col_offset': node.col_offset}))
            return
        if kind == 'UnaryOp':
            self._stack.append(
                ('UnaryOp', {'args': [], 'lineno': node.lineno, 'col_offset': node.col_offset}))
            return
        if kind in OnnxTranslator._binary_operators:
            _, buf = self._get_last(('BinOp', 'UnaryOp'))
            buf['op'] = kind
            return
        if kind == 'Call':
            self._stack.append(
                ('Call', {'name': info['str'], 'args': [], 'lineno': node.lineno,
                          'col_offset': node.col_offset}))
            return
        if kind == 'Return':
            self._get_last('FunctionDef')
            self._stack.append(
                ('Return', {'code': [], 'lineno': node.lineno, 'col_offset': node.col_offset}))
            return
        if kind == "Attribute":
            if info.get('str', '') == 'T':
                raise NotImplementedError(  # pragma: no cover
                    "Transpose should be done with numpy.transpose not with .T'{}' "
                    "at {}\n{}\n---\n{}".format(
                        info.get('type', '?'), self.make_msg(node),
                        pprint.pformat(info), pprint.pformat(self._stack)))
            self._get_last('Call')
            return
        if kind == 'keyword':
            self._get_last('Call')
            self._stack.append(
                ('keyword', {'name': "{0}".format(node.arg),
                             'lineno': getattr(node, 'lineno', '?'),
                             'col_offset': getattr(node, 'col_offset', '?')}))
            return
        if kind == 'List':
            self._get_last('keyword')
            self._stack.append(
                ('List', {'elts': [], 'lineno': getattr(node, 'lineno', '?'),
                          'col_offset': getattr(node, 'col_offset', '?')}))
            return
        if kind == 'Num':
            self._get_last(('List', 'UnaryOp', 'BinOp', 'FunctionDef', 'Call'))
            return
        if kind == 'Str':
            self._get_last('keyword')
            return

        raise NotImplementedError(  # pragma: no cover
            "Unable to interpret kind '{}' at {}\n{}\n---\n{}".format(
                info.get('type', '?'), self.make_msg(
                    node), pprint.pformat(info),
                pprint.pformat(self._stack)))

    def _fix_default_values(self, code_fct):
        """
        Maps default values with parameter names:
        the defaults are assigned to the last parameters and
        ``code_fct[1]['args']`` is replaced by a list of *Parameter*.

        @param      code_fct        tuple ``('FunctionDef', dict)``
        """
        nbdef = len(code_fct[1]['default'])
        nbpar = len(code_fct[1]['args'])
        args = []
        for i, (name, annotation) in enumerate(code_fct[1]['args']):
            # index of the default value, negative if there is none
            j = nbdef - (nbpar - i)
            if j >= 0:
                default = code_fct[1]['default'][j]
                p = OnnxTranslator.Parameter(
                    name, annotation=annotation, value=default)
            else:
                p = OnnxTranslator.Parameter(name, annotation=annotation)
            args.append(p)
        code_fct[1]['args'] = args

    def _post_process(self, op, node):
        """
        Simplifies some operator such as ``OnnxNeg(2)``:
        the negation of a constant is replaced by the negative constant.
        Does nothing unless *op* is None.

        @param      op          kind of the node or None
        @param      node        dictionary, modified inplace
        """
        if op is None and 'args' in node:
            for i, arg in enumerate(node['args']):
                if not isinstance(arg, tuple):
                    continue
                o, v = arg
                if (o == 'UnaryOp' and len(v['args']) == 1 and
                        isinstance(v['args'][0], (int, float, numpy.int64,
                                                  numpy.float32, numpy.float64))):
                    if v['op'] == 'Sub':
                        node['args'][i] = -v['args'][0]

    def depart(self, node, info):
        """
        Leaves a node: moves what was collected for the node
        into its parent on the stack.

        @param      node        visited node
        @param      info        info extracted by the visitor
        """
        if 'type' not in info:
            return

        kind = info['type']
        if kind == "arg":
            return
        if kind == "arguments":
            _, buf = self._get_last('FunctionDef')
            for child in info['children']:
                if child['type'] == 'Str':
                    buf['default'].append(child['str'])
                elif child['type'] in ('Num', 'Cst'):
                    buf['default'].append(child['n'])
                elif child['type'] == 'arg':
                    buf['args'].append(
                        (child['str'], child.get('annotation', None)))
                else:
                    raise RuntimeError(  # pragma: no cover
                        "Unable to interpret type '{}' in function definition."
                        "\n{}".format(
                            child['type'], pprint.pformat(info)))
            return

        if kind == "Name":
            op, buf = self._get_last(
                ('Assign', 'BinOp', 'Call', 'Return', 'FunctionDef', 'keyword',
                 'UnaryOp'),
                info)
            if op == 'Assign':
                buf['name'] = info['str']
                return
            if op in ('BinOp', 'Call', 'UnaryOp'):
                buf['args'].append(info['str'])
                return
            if op == 'Return':
                buf['code'] = info['str']
                return
            if op == 'keyword':
                buf['value'] = info['str']
                return
            # _get_last only lets 'FunctionDef' reach this line
            raise RuntimeError("Default value must be constant, variable '{}' was "
                               "detected.".format(info['str']))

        if kind in OnnxTranslator._binary_operators:
            # only checks the parent
            self._get_last(('BinOp', 'UnaryOp'))
            return
        if kind in ('Call', 'BinOp', 'Assign', 'Return', 'UnaryOp'):
            op, buf = self._get_last(
                ('Call', 'BinOp', 'Assign', 'Return', 'UnaryOp'))
            self._post_process(op, buf)
            self._stack.pop()
            opp, parent = self._get_last(
                ('Call', 'BinOp', 'Assign', 'FunctionDef', 'Return', 'UnaryOp'))
            if opp in ('FunctionDef', 'Return'):
                parent['code'].append((op, buf))
            else:
                parent['args'].append((op, buf))
            self._post_process(None, parent)
            return
        if kind == 'FunctionDef':
            if len(self._stack) == 1:
                self._code_fct = self._stack[-1]
                self._fix_default_values(self._code_fct)
                self._stack = []
            return
        if kind == 'Module':
            return
        if kind == 'Attribute':
            _, buf = self._get_last(('Call', 'BinOp'))

            if len(info["children"]) > 0:
                fir = info["children"][0]
                if fir["type"] == "Name":
                    parent = fir["node"].id
                    info["str"] = "{0}.{1}".format(parent, info["str"])
                    info["children"][0]["remove"] = True

            buf['name'] = info["str"]
            # overwrites the first argument with the full name
            buf['args'][0] = info["str"]
            return
        if kind in ('Num', 'Cst'):
            op, buf = self._get_last(
                ('List', 'BinOp', 'UnaryOp', 'FunctionDef', 'Call'))
            if op == 'FunctionDef':
                # default values are collected when leaving 'arguments'
                return
            if op == 'List':
                buf['elts'].append(info['n'])
            else:
                buf['args'].append(info['n'])
            return
        if kind == 'Str':
            _, buf = self._get_last('keyword')
            buf['value'] = info['str']
            return
        if kind == 'List':
            op, buf = self._get_last('List')
            value = buf['elts']
            self._post_process(op, buf)
            self._stack.pop()
            _, parent = self._get_last('keyword')
            parent['value'] = value
            self._post_process(None, parent)
            return
        if kind == 'keyword':
            op, buf = self._get_last('keyword')
            name = buf["name"]
            if 'value' not in buf:
                raise RuntimeError(str(buf))  # pragma: no cover
            value = buf['value']
            self._post_process(op, buf)
            self._stack.pop()
            _, parent = self._get_last('Call')
            parent['args'].append(OnnxTranslator.Parameter(name, value))
            self._post_process(None, parent)
            return

        raise NotImplementedError(  # pragma: no cover
            "Unable to interpret kind '{}' at {}\n{}\n---\n{}".format(
                info.get('type', '?'), self.make_msg(
                    node), pprint.pformat(info),
                pprint.pformat(self._stack)))