Coverage for mlprodict/onnx_tools/onnx_grammar/node_visitor_translator.py: 95%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

171 statements  

1""" 

2@file 

3@brief One class which visits a syntax tree. 

4""" 

5 

6import ast 

7from .onnx_translator import OnnxTranslator 

8 

9 

class CodeNodeVisitor(ast.NodeVisitor):
    """
    Defines a visitor which walks through the syntax tree of the code.

    .. exref::
        :title: Get the tree of a simple function

        The following code parses a small numerical function
        and prints the rows the visitor collects for it.

        .. runpython::
            :showcode:
            :warningout: DeprecationWarning
            :process:
            :store_in_file: fct2onnx1.py

            import ast
            import inspect
            from textwrap import dedent
            from mlprodict.onnx_tools.onnx_grammar import CodeNodeVisitor

            def norm2(x, y):
                delta = x - y
                n = delta ** 2
                return n

            code = dedent(inspect.getsource(norm2))
            node = ast.parse(code)
            v = CodeNodeVisitor()
            v.visit(node)
            for r in v.Rows:
                print("{0}{1}: {2}".format(" " * r["indent"], r["type"], r["str"]))
    """

    def __init__(self, translator=None):
        """
        @param      translator      @see cl CodeTranslator

        By default the translator is @see cl OnnxTranslator.
        The translator receives ``visit(node, row)`` before the children
        of every node are visited and ``depart(node, row)`` after.
        """
        ast.NodeVisitor.__init__(self)
        self._rows = []
        self._indent = 0
        self._stack = []
        self._translator = OnnxTranslator(
            self) if translator is None else translator

    def push(self, row):
        """
        Pushes an element into a list.
        """
        self._rows.append(row)

    def generic_visit(self, node):
        """
        Overrides ``generic_visit`` to check it is not used.
        """
        raise AttributeError(  # pragma: no cover
            "generic_visit_args should be used.")

    def generic_visit_args(self, node, row):
        """
        Overrides ``generic_visit`` to keep track of the indentation
        and the node parent. The function will add field
        ``row["children"] = visited`` nodes from here.

        @param      node        node which needs to be visited
        @param      row         row (a dictionary)
        @return                 See ``ast.NodeVisitor.generic_visit``
        """
        if hasattr(node, 'lineno'):
            row['lineno'] = node.lineno
        if hasattr(node, 'col_offset'):
            row['col_offset'] = node.col_offset
        self._indent += 1
        # Rows pushed from here on belong to this node's subtree.
        last = len(self._rows)
        self._translator.visit(node, row)
        res = ast.NodeVisitor.generic_visit(  # pylint: disable=E1111
            self, node)  # pylint: disable=E1111
        # Direct children are the subtree rows one level deeper.
        row["children"] = [
            _ for _ in self._rows[
                last:] if _["indent"] == self._indent]
        self._indent -= 1
        self._translator.depart(node, row)
        return res

    def make_msg(self, node):
        """
        Displays line and column information into a string.
        """
        return "line {}, col {}".format(  # pragma: no cover
            getattr(node, 'lineno', '?'), getattr(node, 'col_offset', '?'))

    def visit(self, node):
        """
        Visits a node, a method must exist for every object class.

        @raise  TypeError   if no ``visit_<ClassName>`` method exists
        """
        method = 'visit_' + node.__class__.__name__
        visitor = getattr(self, method, None)
        if visitor is None:
            raise TypeError(  # pragma: no cover
                "Unable to find a method '{}' at {}.".format(
                    method, self.make_msg(node)))
        return visitor(node)

    def visit_(self, node):
        """
        If an element is not found...

        @raise  NotImplementedError always, with the location of *node*,
                its attributes and the rows collected so far
        """
        raise NotImplementedError(  # pragma: no cover
            "Node '{}' ({}) not recognized at {}\nNode\n{}\n--"
            "Status--\n{}".format(
                node, type(node), self.make_msg(node),
                self.print_node(node), self.print_tree()))

    @staticmethod
    def print_node(node):
        """
        Debugging purpose.
        """
        r = []
        for att in sorted(set(["s", "name", "str", "id", "body", "n",
                               "arg", "targets", "attr", "returns", "ctx",
                               'col_offset', 'lineno',
                               'value'] + list(getattr(node, '_attributes', [])))):
            v = getattr(node, att, None)
            if v is not None or att in getattr(node, '_fields', []):
                r.append("{0}={1}".format(att, v))
        return " ".join(r)

    def print_tree(self):
        """
        Displays the tree of instructions.

        @return     string
        """
        rows = []
        for r in self.Rows:
            rows.append(
                ("{0}{1}: {2}".format(
                    " " *
                    r["indent"],
                    r["type"],
                    r["str"])))
        return "\n".join(rows)

    @property
    def Rows(self):
        """
        returns a list of dictionaries with all the elements of the code
        (rows flagged with ``remove`` are skipped)
        """
        return [_ for _ in self._rows if not _.get("remove", False)]

    def export(self, context=None, **kwargs):
        """
        Calls method *export* from the translator class.

        @param      context     known :epkg:`python` needed to run
                                the translated function
        @param      kwargs      whatever the method *export* from
                                the translator class ingests
        @return                 whatever the method *export* from
                                the translator class returns
        """
        return self._translator.export(context=context, **kwargs)

    ###########
    # Methods for python code elements
    ###########

    @staticmethod
    def _constant_value(node, legacy_field):
        """
        Returns the literal held by *node*. ``node.value`` is read when
        it exists (:class:`ast.Constant`, Python >= 3.8); otherwise the
        legacy field is used (``s`` for ``ast.Str``, ``n`` for ``ast.Num``).
        ``Constant.s`` and ``Constant.n`` are deprecated aliases since
        Python 3.12 and were removed in Python 3.14.
        """
        if hasattr(node, 'value'):
            return node.value
        return getattr(node, legacy_field)

    def _visit_simple(self, node, kind):
        """
        Pushes a row ``{"indent", "type": kind, "str": "", "node"}``
        for *node* and visits its children. Shared by the operator and
        expression visitors which record no extra field.
        """
        cont = {"indent": self._indent, "type": kind,
                "str": "", "node": node}
        self.push(cont)
        return self.generic_visit_args(node, cont)

    def visit_Constant(self, node):
        """
        Dispatches an :class:`ast.Constant` (Python >= 3.8) to the legacy
        handlers by the type of its value, like the base class
        ``ast.NodeVisitor.visit_Constant`` does on Python 3.8-3.13:
        ``None`` and booleans go to :meth:`visit_NameConstant`, strings to
        :meth:`visit_Str`, numbers to :meth:`visit_Num`. Defining it here
        removes the deprecation warning emitted by the base class and keeps
        the visitor working on Python 3.14 where the base class stopped
        calling ``visit_Num``, ``visit_Str``, ``visit_NameConstant``.
        Any other constant (bytes, Ellipsis) is reported by :meth:`visit_`.
        """
        value = node.value
        # bool must be tested before int since bool is a subclass of int.
        if value is None or isinstance(value, bool):
            return self.visit_NameConstant(node)
        if isinstance(value, str):
            return self.visit_Str(node)
        if isinstance(value, (int, float, complex)):
            return self.visit_Num(node)
        return self.visit_(node)  # pragma: no cover

    def visit_Str(self, node):  # pylint: disable=C0111
        value = self._constant_value(node, 's')
        cont = {
            "indent": self._indent,
            "type": "Str",
            "str": value,
            "node": node,
            "value": value}
        self.push(cont)
        return self.generic_visit_args(node, cont)

    def visit_Name(self, node):  # pylint: disable=C0111
        cont = {
            "indent": self._indent,
            "type": "Name",
            "str": node.id,
            "node": node,
            "id": node.id,
            "ctx": node.ctx}
        self.push(cont)
        return self.generic_visit_args(node, cont)

    def visit_Module(self, node):  # pylint: disable=C0111
        cont = {
            "indent": self._indent,
            "type": "Module",
            "str": "",
            "body": node.body,
            "node": node}
        self.push(cont)
        return self.generic_visit_args(node, cont)

    def visit_FunctionDef(self, node):  # pylint: disable=C0111
        cont = {"indent": self._indent, "type": "FunctionDef",
                "str": node.name, "name": node.name, "body": node.body,
                "node": node, "returns": node.returns}
        self.push(cont)
        return self.generic_visit_args(node, cont)

    def visit_List(self, node):  # pylint: disable=C0111
        cont = {"indent": self._indent, "type": "List",
                "str": "", "elts": node.elts,
                "node": node}
        self.push(cont)
        return self.generic_visit_args(node, cont)

    def visit_arguments(self, node):  # pylint: disable=C0111
        cont = {"indent": self._indent, "type": "arguments", "str": "",
                "node": node, "args": node.args}
        self.push(cont)
        return self.generic_visit_args(node, cont)

    def visit_arg(self, node):  # pylint: disable=C0111
        cont = {"indent": self._indent, "type": "arg", "str": node.arg,
                "node": node,
                "arg": node.arg, "annotation": node.annotation}
        self.push(cont)
        return self.generic_visit_args(node, cont)

    def visit_Assign(self, node):  # pylint: disable=C0111
        cont = {"indent": self._indent, "type": "Assign", "str": "", "node": node,
                "targets": node.targets, "value": node.value}
        self.push(cont)
        return self.generic_visit_args(node, cont)

    def visit_Store(self, node):  # pylint: disable=C0111
        # Context nodes are not recorded as rows; the translator still
        # receives them with an empty row.
        cont = {}
        return self.generic_visit_args(node, cont)

    def visit_Call(self, node):  # pylint: disable=C0111
        func = node.func
        if isinstance(func, ast.Attribute):
            name = func.attr
        elif isinstance(func, ast.Name):
            name = func.id
        else:
            # The callee has no simple name (``f()(x)``, ``fs[0](x)``, ...).
            return self.visit_(node)  # pragma: no cover
        cont = {"indent": self._indent, "type": "Call", "str": name,
                "node": node, "func": func}
        self.push(cont)
        return self.generic_visit_args(node, cont)

    def visit_Attribute(self, node):  # pylint: disable=C0111
        cont = {"indent": self._indent, "type": "Attribute", "str": node.attr,
                "node": node, "value": node.value, "ctx": node.ctx, "attr": node.attr}
        self.push(cont)
        res = self.generic_visit_args(node, cont)

        # ``a.b``: merge the owner name into this row and hide the Name row.
        if len(cont["children"]) > 0:
            fir = cont["children"][0]
            if fir["type"] == "Name":
                parent = fir["node"].id
                cont["str"] = "{0}.{1}".format(parent, cont["str"])
                cont["children"][0]["remove"] = True
        return res

    def visit_Load(self, node):  # pylint: disable=C0111
        # Same as visit_Store: not recorded as a row.
        cont = {}
        return self.generic_visit_args(node, cont)

    def visit_keyword(self, node):  # pylint: disable=C0111
        cont = {"indent": self._indent, "type": "keyword", "str": "{0}".format(node.arg),
                "node": node, "arg": node.arg, "value": node.value}
        self.push(cont)
        return self.generic_visit_args(node, cont)

    def visit_BinOp(self, node):  # pylint: disable=C0111
        return self._visit_simple(node, "BinOp")

    def visit_Div(self, node):  # pylint: disable=C0111
        return self._visit_simple(node, "Div")

    def visit_Sub(self, node):  # pylint: disable=C0111
        return self._visit_simple(node, "Sub")

    def visit_USub(self, node):  # pylint: disable=C0111
        # NOTE(review): unary minus is recorded with type "Sub", not "USub";
        # kept as is since the translator may rely on it — confirm.
        return self._visit_simple(node, "Sub")

    def visit_Add(self, node):  # pylint: disable=C0111
        return self._visit_simple(node, "Add")

    def visit_Pow(self, node):  # pylint: disable=C0111
        return self._visit_simple(node, "Pow")

    def visit_Mult(self, node):  # pylint: disable=C0111
        return self._visit_simple(node, "Mult")

    def visit_MatMult(self, node):  # pylint: disable=C0111
        return self._visit_simple(node, "MatMult")

    def visit_Compare(self, node):  # pylint: disable=C0111
        return self._visit_simple(node, "Compare")

    def visit_Gt(self, node):  # pylint: disable=C0111
        return self._visit_simple(node, "Gt")

    def visit_Lt(self, node):  # pylint: disable=C0111
        return self._visit_simple(node, "Lt")

    def visit_UnaryOp(self, node):  # pylint: disable=C0111
        return self._visit_simple(node, "UnaryOp")

    def visit_Num(self, node):  # pylint: disable=C0111
        value = self._constant_value(node, 'n')
        cont = {"indent": self._indent, "type": "Num",
                "node": node, "str": "{0}".format(value),
                'n': value}
        self.push(cont)
        return self.generic_visit_args(node, cont)

    def visit_Return(self, node):  # pylint: disable=C0111
        cont = {"indent": self._indent, "type": "Return", "node": node, "str": "",
                'value': node.value}
        self.push(cont)
        return self.generic_visit_args(node, cont)

    def visit_NameConstant(self, node):
        """
        A name. Only ``None`` is supported (recorded as a ``Cst`` row);
        any other value (``True``, ``False``) is reported by :meth:`visit_`.
        """
        if node.value is None:
            cont = {"indent": self._indent, "type": "Cst",
                    "node": node, "str": "None",
                    'n': None}
            self.push(cont)
            return self.generic_visit_args(node, cont)
        return self.visit_(node)  # pragma: no cover

184 "indent": self._indent, 

185 "type": "Str", 

186 "str": node.s, 

187 "node": node, 

188 "value": node.s} 

189 self.push(cont) 

190 return self.generic_visit_args(node, cont) 

191 

192 def visit_Name(self, node): # pylint: disable=C0111 

193 cont = { 

194 "indent": self._indent, 

195 "type": "Name", 

196 "str": node.id, 

197 "node": node, 

198 "id": node.id, 

199 "ctx": node.ctx} 

200 self.push(cont) 

201 return self.generic_visit_args(node, cont) 

202 

203 def visit_Module(self, node): # pylint: disable=C0111 

204 cont = { 

205 "indent": self._indent, 

206 "type": "Module", 

207 "str": "", 

208 "body": node.body, 

209 "node": node} 

210 self.push(cont) 

211 return self.generic_visit_args(node, cont) 

212 

213 def visit_FunctionDef(self, node): # pylint: disable=C0111 

214 cont = {"indent": self._indent, "type": "FunctionDef", "str": node.name, "name": node.name, "body": node.body, 

215 "node": node, "returns": node.returns} 

216 self.push(cont) 

217 return self.generic_visit_args(node, cont) 

218 

219 def visit_List(self, node): # pylint: disable=C0111 

220 cont = {"indent": self._indent, "type": "List", 

221 "str": "", "elts": node.elts, 

222 "node": node} 

223 self.push(cont) 

224 return self.generic_visit_args(node, cont) 

225 

226 def visit_arguments(self, node): # pylint: disable=C0111 

227 cont = {"indent": self._indent, "type": "arguments", "str": "", 

228 "node": node, "args": node.args} 

229 self.push(cont) 

230 return self.generic_visit_args(node, cont) 

231 

232 def visit_arg(self, node): # pylint: disable=C0111 

233 cont = {"indent": self._indent, "type": "arg", "str": node.arg, 

234 "node": node, 

235 "arg": node.arg, "annotation": node.annotation} 

236 self.push(cont) 

237 return self.generic_visit_args(node, cont) 

238 

239 def visit_Assign(self, node): # pylint: disable=C0111 

240 cont = {"indent": self._indent, "type": "Assign", "str": "", "node": node, 

241 "targets": node.targets, "value": node.value} 

242 self.push(cont) 

243 return self.generic_visit_args(node, cont) 

244 

245 def visit_Store(self, node): # pylint: disable=C0111 

246 #cont = { "indent":self._indent, "type": "Store", "str": "" } 

247 # self.push(cont) 

248 cont = {} 

249 return self.generic_visit_args(node, cont) 

250 

251 def visit_Call(self, node): # pylint: disable=C0111 

252 if "attr" in node.func.__dict__: 

253 cont = {"indent": self._indent, "type": "Call", "str": node.func.attr, 

254 "node": node, "func": node.func} 

255 else: 

256 cont = {"indent": self._indent, "type": "Call", "str": node.func.id, 

257 "node": node, "func": node.func} 

258 self.push(cont) 

259 return self.generic_visit_args(node, cont) 

260 

261 def visit_Attribute(self, node): # pylint: disable=C0111 

262 cont = {"indent": self._indent, "type": "Attribute", "str": node.attr, 

263 "node": node, "value": node.value, "ctx": node.ctx, "attr": node.attr} 

264 self.push(cont) 

265 # last = len(self._rows) 

266 res = self.generic_visit_args(node, cont) 

267 

268 if len(cont["children"]) > 0: 

269 fir = cont["children"][0] 

270 if fir["type"] == "Name": 

271 parent = fir["node"].id 

272 cont["str"] = "{0}.{1}".format(parent, cont["str"]) 

273 cont["children"][0]["remove"] = True 

274 return res 

275 

276 def visit_Load(self, node): # pylint: disable=C0111 

277 cont = {} 

278 return self.generic_visit_args(node, cont) 

279 

280 def visit_keyword(self, node): # pylint: disable=C0111 

281 cont = {"indent": self._indent, "type": "keyword", "str": "{0}".format(node.arg), 

282 "node": node, "arg": node.arg, "value": node.value} 

283 self.push(cont) 

284 return self.generic_visit_args(node, cont) 

285 

286 def visit_BinOp(self, node): # pylint: disable=C0111 

287 cont = {"indent": self._indent, "type": "BinOp", 

288 "str": "", "node": node} 

289 self.push(cont) 

290 return self.generic_visit_args(node, cont) 

291 

292 def visit_Div(self, node): # pylint: disable=C0111 

293 cont = {"indent": self._indent, "type": "Div", 

294 "str": "", "node": node} 

295 self.push(cont) 

296 return self.generic_visit_args(node, cont) 

297 

298 def visit_Sub(self, node): # pylint: disable=C0111 

299 cont = {"indent": self._indent, "type": "Sub", 

300 "str": "", "node": node} 

301 self.push(cont) 

302 return self.generic_visit_args(node, cont) 

303 

304 def visit_USub(self, node): # pylint: disable=C0111 

305 cont = {"indent": self._indent, "type": "Sub", 

306 "str": "", "node": node} 

307 self.push(cont) 

308 return self.generic_visit_args(node, cont) 

309 

310 def visit_Add(self, node): # pylint: disable=C0111 

311 cont = {"indent": self._indent, "type": "Add", 

312 "str": "", "node": node} 

313 self.push(cont) 

314 return self.generic_visit_args(node, cont) 

315 

316 def visit_Pow(self, node): # pylint: disable=C0111 

317 cont = {"indent": self._indent, "type": "Pow", 

318 "str": "", "node": node} 

319 self.push(cont) 

320 return self.generic_visit_args(node, cont) 

321 

322 def visit_Mult(self, node): # pylint: disable=C0111 

323 cont = {"indent": self._indent, "type": "Mult", 

324 "str": "", "node": node} 

325 self.push(cont) 

326 return self.generic_visit_args(node, cont) 

327 

328 def visit_MatMult(self, node): # pylint: disable=C0111 

329 cont = {"indent": self._indent, "type": "MatMult", 

330 "str": "", "node": node} 

331 self.push(cont) 

332 return self.generic_visit_args(node, cont) 

333 

334 def visit_Compare(self, node): # pylint: disable=C0111 

335 cont = {"indent": self._indent, "type": "Compare", 

336 "str": "", "node": node} 

337 self.push(cont) 

338 return self.generic_visit_args(node, cont) 

339 

340 def visit_Gt(self, node): # pylint: disable=C0111 

341 cont = {"indent": self._indent, "type": "Gt", "str": "", "node": node} 

342 self.push(cont) 

343 return self.generic_visit_args(node, cont) 

344 

345 def visit_Lt(self, node): # pylint: disable=C0111 

346 cont = {"indent": self._indent, "type": "Lt", "str": "", "node": node} 

347 self.push(cont) 

348 return self.generic_visit_args(node, cont) 

349 

350 def visit_UnaryOp(self, node): # pylint: disable=C0111 

351 cont = {"indent": self._indent, 

352 "type": "UnaryOp", "str": "", "node": node} 

353 self.push(cont) 

354 return self.generic_visit_args(node, cont) 

355 

356 def visit_Num(self, node): # pylint: disable=C0111 

357 cont = {"indent": self._indent, "type": "Num", 

358 "node": node, "str": "{0}".format(node.n), 

359 'n': node.n} 

360 self.push(cont) 

361 return self.generic_visit_args(node, cont) 

362 

363 def visit_Return(self, node): # pylint: disable=C0111 

364 cont = {"indent": self._indent, "type": "Return", "node": node, "str": "", 

365 'value': node.value} 

366 self.push(cont) 

367 return self.generic_visit_args(node, cont) 

368 

369 def visit_NameConstant(self, node): 

370 """ 

371 A name. 

372 """ 

373 if node.value is None: 

374 cont = {"indent": self._indent, "type": "Cst", 

375 "node": node, "str": "None", 

376 'n': None} 

377 self.push(cont) 

378 return self.generic_visit_args(node, cont) 

379 return self.visit_(node) # pragma: no cover