Coverage for mlprodict/onnx_tools/onnx_grammar/node_visitor_translator.py: 95%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief One class which visits a syntax tree.
4"""
6import ast
7from .onnx_translator import OnnxTranslator
class CodeNodeVisitor(ast.NodeVisitor):
    """
    Defines a visitor which walks through the syntax tree of the code.

    .. exref::
        :title: Get the tree of a simple function

        The following code uses Python syntax but follows a SQL logic.

        .. runpython::
            :showcode:
            :warningout: DeprecationWarning
            :process:
            :store_in_file: fct2onnx1.py

            import ast
            import inspect
            from textwrap import dedent
            from mlprodict.onnx_tools.onnx_grammar import CodeNodeVisitor

            def norm2(x, y):
                delta = x - y
                n = delta ** 2
                return n

            code = dedent(inspect.getsource(norm2))
            node = ast.parse(code)
            v = CodeNodeVisitor()
            v.visit(node)
            for r in v.Rows:
                print("{0}{1}: {2}".format(" " * r["indent"], r["type"], r["str"]))

    Every node of interest is recorded as a *row*, a dictionary holding
    at least ``indent`` (depth of the node in the tree), ``type``
    (kind of node), ``str`` (short textual representation) and ``node``
    (the :mod:`ast` node itself). Every visited node, recorded or not,
    is also forwarded with its row to the translator's ``visit`` and
    ``depart`` methods.
    """

    def __init__(self, translator=None):
        """
        @param      translator      @see cl CodeTranslator

        By default the translator is @see cl OnnxTranslator.
        """
        ast.NodeVisitor.__init__(self)
        # rows recorded by the visit_* methods, in visiting order
        self._rows = []
        # current depth in the syntax tree
        self._indent = 0
        # NOTE(review): not read by this class; presumably used by the
        # translator through its reference to the visitor -- confirm.
        self._stack = []
        self._translator = OnnxTranslator(
            self) if translator is None else translator

    def push(self, row):
        """
        Pushes an element into a list.
        """
        self._rows.append(row)

    def generic_visit(self, node):
        """
        Overrides ``generic_visit`` to check it is not used.
        """
        raise AttributeError(  # pragma: no cover
            "generic_visit_args should be used.")

    def generic_visit_args(self, node, row):
        """
        Replacement for ``generic_visit``: visits the children of *node*
        one level deeper, notifies the translator before (``visit``) and
        after (``depart``) and stores the rows of the direct children in
        ``row["children"]``.

        @param      node        node which needs to be visited
        @param      row         row (a dictionary)
        @return                 See ``ast.NodeVisitor.generic_visit``
        """
        if hasattr(node, 'lineno'):
            row['lineno'] = node.lineno
        if hasattr(node, 'col_offset'):
            row['col_offset'] = node.col_offset
        self._indent += 1
        try:
            last = len(self._rows)
            self._translator.visit(node, row)
            res = ast.NodeVisitor.generic_visit(  # pylint: disable=E1111
                self, node)
            # direct children are the rows pushed exactly one level
            # deeper while this node was being visited
            row["children"] = [
                child for child in self._rows[last:]
                if child["indent"] == self._indent]
        finally:
            # restore the depth even when a child cannot be handled,
            # otherwise the visitor is left in an inconsistent state
            self._indent -= 1
        self._translator.depart(node, row)
        return res

    def make_msg(self, node):
        """
        Displays line and column information into a string.
        """
        return "line {}, col {}".format(  # pragma: no cover
            getattr(node, 'lineno', '?'), getattr(node, 'col_offset', '?'))

    def visit(self, node):
        """
        Visits a node, a method must exist for every object class.

        @raise      TypeError       if no method ``visit_<ClassName>`` exists
        """
        method = 'visit_' + node.__class__.__name__
        visitor = getattr(self, method, None)
        if visitor is None:
            raise TypeError(  # pragma: no cover
                "Unable to find a method '{}' at {}.".format(
                    method, self.make_msg(node)))
        return visitor(node)

    def visit_(self, node):
        """
        Called when an element is not supported.

        @raise      NotImplementedError     always, with the node and the
                                            tree built so far
        """
        raise NotImplementedError(  # pragma: no cover
            "Node '{}' ({}) not recognized at {}\nNode\n{}\n--"
            "Status--\n{}".format(
                node, type(node), self.make_msg(node),
                self.print_node(node), self.print_tree()))

    @staticmethod
    def print_node(node):
        """
        Debugging purpose.
        """
        r = []
        for att in sorted(set(["s", "name", "str", "id", "body", "n",
                               "arg", "targets", "attr", "returns", "ctx",
                               'col_offset', 'lineno',
                               'value'] + list(getattr(node, '_attributes', [])))):
            v = getattr(node, att, None)
            if v is not None or att in getattr(node, '_fields', []):
                r.append("{0}={1}".format(att, v))
        return " ".join(r)

    def print_tree(self):
        """
        Displays the tree of instructions.

        @return                 string
        """
        return "\n".join(
            "{0}{1}: {2}".format(" " * r["indent"], r["type"], r["str"])
            for r in self.Rows)

    @property
    def Rows(self):
        """
        returns a list of dictionaries with all the elements of the code
        """
        return [_ for _ in self._rows if not _.get("remove", False)]

    def export(self, context=None, **kwargs):
        """
        Calls method *export* from the translator class.

        @param      context     known :epkg:`python` needed to run
                                the translated function
        @param      kwargs      whatever the method *export* from
                                the translator class ingests
        @return                 whatever the method *export* from
                                the translator class returns
        """
        return self._translator.export(context=context, **kwargs)

    ###########
    # Helpers
    ###########

    @staticmethod
    def _constant_value(node):
        """
        Returns the literal held by a constant node: ``value`` for
        :class:`ast.Constant` (what the parser produces since Python 3.8),
        ``n`` or ``s`` for legacy :class:`ast.Num` / :class:`ast.Str` nodes.
        """
        if isinstance(node, ast.Constant):
            return node.value
        if hasattr(node, 'n'):
            return node.n
        return node.s

    def _visit_operator(self, node, kind):
        """
        Records *node* as a row of type *kind* with an empty ``str``
        and visits its children.
        """
        cont = {"indent": self._indent, "type": kind,
                "str": "", "node": node}
        self.push(cont)
        return self.generic_visit_args(node, cont)

    ###########
    # Methods for python code elements
    ###########

    def visit_Constant(self, node):
        """
        Dispatches an :class:`ast.Constant` node to the method handling
        the same kind of literal (:meth:`visit_NameConstant`,
        :meth:`visit_Num`, :meth:`visit_Str`). :class:`ast.NodeVisitor`
        would otherwise do this dispatch itself and emit a
        DeprecationWarning. Other literals (bytes, Ellipsis) are not
        supported and reported through :meth:`visit_`.
        """
        value = node.value
        # bool must be tested before int, bool is a subclass of int
        if value is None or isinstance(value, bool):
            return self.visit_NameConstant(node)
        if isinstance(value, (int, float, complex)):
            return self.visit_Num(node)
        if isinstance(value, str):
            return self.visit_Str(node)
        return self.visit_(node)

    def visit_Str(self, node):
        """
        Records a string literal as a row of type ``Str``.
        """
        value = self._constant_value(node)
        cont = {
            "indent": self._indent,
            "type": "Str",
            "str": value,
            "node": node,
            "value": value}
        self.push(cont)
        return self.generic_visit_args(node, cont)

    def visit_Name(self, node):
        """
        Records a variable name as a row of type ``Name``.
        """
        cont = {
            "indent": self._indent,
            "type": "Name",
            "str": node.id,
            "node": node,
            "id": node.id,
            "ctx": node.ctx}
        self.push(cont)
        return self.generic_visit_args(node, cont)

    def visit_Module(self, node):
        """
        Records the module, root of the tree.
        """
        cont = {
            "indent": self._indent,
            "type": "Module",
            "str": "",
            "body": node.body,
            "node": node}
        self.push(cont)
        return self.generic_visit_args(node, cont)

    def visit_FunctionDef(self, node):
        """
        Records a function definition, ``str`` is the function name.
        """
        cont = {"indent": self._indent, "type": "FunctionDef",
                "str": node.name, "name": node.name, "body": node.body,
                "node": node, "returns": node.returns}
        self.push(cont)
        return self.generic_visit_args(node, cont)

    def visit_List(self, node):
        """
        Records a list display.
        """
        cont = {"indent": self._indent, "type": "List",
                "str": "", "elts": node.elts,
                "node": node}
        self.push(cont)
        return self.generic_visit_args(node, cont)

    def visit_arguments(self, node):
        """
        Records the argument list of a function.
        """
        cont = {"indent": self._indent, "type": "arguments", "str": "",
                "node": node, "args": node.args}
        self.push(cont)
        return self.generic_visit_args(node, cont)

    def visit_arg(self, node):
        """
        Records one function argument, ``str`` is its name.
        """
        cont = {"indent": self._indent, "type": "arg", "str": node.arg,
                "node": node,
                "arg": node.arg, "annotation": node.annotation}
        self.push(cont)
        return self.generic_visit_args(node, cont)

    def visit_Assign(self, node):
        """
        Records an assignment.
        """
        cont = {"indent": self._indent, "type": "Assign", "str": "",
                "node": node, "targets": node.targets, "value": node.value}
        self.push(cont)
        return self.generic_visit_args(node, cont)

    def visit_Store(self, node):
        """
        Store contexts are not recorded as rows; an empty row is still
        passed to :meth:`generic_visit_args` and so to the translator.
        """
        cont = {}
        return self.generic_visit_args(node, cont)

    def visit_Call(self, node):
        """
        Records a function call. ``str`` is the method name for
        ``obj.method(...)`` and the function name for ``name(...)``.
        Any other callee (``f()()``, ``fs[0]()``, ...) is not supported
        and reported through :meth:`visit_`.
        """
        func = node.func
        if isinstance(func, ast.Attribute):
            name = func.attr
        elif isinstance(func, ast.Name):
            name = func.id
        else:
            return self.visit_(node)
        cont = {"indent": self._indent, "type": "Call", "str": name,
                "node": node, "func": func}
        self.push(cont)
        return self.generic_visit_args(node, cont)

    def visit_Attribute(self, node):
        """
        Records an attribute access. When the object is a plain name,
        the name is merged into ``str`` (``np.exp``) and its own row is
        flagged ``remove`` so that it disappears from :meth:`Rows`.
        """
        cont = {"indent": self._indent, "type": "Attribute",
                "str": node.attr, "node": node, "value": node.value,
                "ctx": node.ctx, "attr": node.attr}
        self.push(cont)
        res = self.generic_visit_args(node, cont)

        if cont["children"]:
            fir = cont["children"][0]
            if fir["type"] == "Name":
                parent = fir["node"].id
                cont["str"] = "{0}.{1}".format(parent, cont["str"])
                fir["remove"] = True
        return res

    def visit_Load(self, node):
        """
        Load contexts are not recorded as rows; an empty row is still
        passed to :meth:`generic_visit_args` and so to the translator.
        """
        cont = {}
        return self.generic_visit_args(node, cont)

    def visit_keyword(self, node):
        """
        Records a keyword argument of a call, ``str`` is its name.
        """
        cont = {"indent": self._indent, "type": "keyword",
                "str": "{0}".format(node.arg),
                "node": node, "arg": node.arg, "value": node.value}
        self.push(cont)
        return self.generic_visit_args(node, cont)

    def visit_BinOp(self, node):
        """
        Records a binary operation.
        """
        return self._visit_operator(node, "BinOp")

    def visit_Div(self, node):
        """
        Records operator ``/``.
        """
        return self._visit_operator(node, "Div")

    def visit_Sub(self, node):
        """
        Records operator ``-`` (binary).
        """
        return self._visit_operator(node, "Sub")

    def visit_USub(self, node):
        """
        Records operator ``-`` (unary).
        """
        # NOTE(review): recorded with type "Sub", like the binary operator,
        # not "USub"; the translator presumably relies on it -- confirm
        # before changing.
        return self._visit_operator(node, "Sub")

    def visit_Add(self, node):
        """
        Records operator ``+``.
        """
        return self._visit_operator(node, "Add")

    def visit_Pow(self, node):
        """
        Records operator ``**``.
        """
        return self._visit_operator(node, "Pow")

    def visit_Mult(self, node):
        """
        Records operator ``*``.
        """
        return self._visit_operator(node, "Mult")

    def visit_MatMult(self, node):
        """
        Records operator ``@``.
        """
        return self._visit_operator(node, "MatMult")

    def visit_Compare(self, node):
        """
        Records a comparison.
        """
        return self._visit_operator(node, "Compare")

    def visit_Gt(self, node):
        """
        Records operator ``>``.
        """
        return self._visit_operator(node, "Gt")

    def visit_Lt(self, node):
        """
        Records operator ``<``.
        """
        return self._visit_operator(node, "Lt")

    def visit_UnaryOp(self, node):
        """
        Records a unary operation.
        """
        return self._visit_operator(node, "UnaryOp")

    def visit_Num(self, node):
        """
        Records a numeric literal as a row of type ``Num``.
        """
        value = self._constant_value(node)
        cont = {"indent": self._indent, "type": "Num",
                "node": node, "str": "{0}".format(value),
                'n': value}
        self.push(cont)
        return self.generic_visit_args(node, cont)

    def visit_Return(self, node):
        """
        Records a ``return`` statement.
        """
        cont = {"indent": self._indent, "type": "Return", "node": node,
                "str": "", 'value': node.value}
        self.push(cont)
        return self.generic_visit_args(node, cont)

    def visit_NameConstant(self, node):
        """
        Records the constant ``None`` as a row of type ``Cst``.
        Other named constants (``True``, ``False``) are not supported
        and reported through :meth:`visit_`.
        """
        if node.value is None:
            cont = {"indent": self._indent, "type": "Cst",
                    "node": node, "str": "None",
                    'n': None}
            self.push(cont)
            return self.generic_visit_args(node, cont)
        return self.visit_(node)  # pragma: no cover