Coverage for mlprodict/testing/verify_code.py: 97%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

333 statements  

1""" 

2@file 

3@brief Looks into the code and detects errors

4before finalizing the benchmark. 

5""" 

6import ast 

7import collections 

8import inspect 

9import numpy 

10 

11 

class ImperfectPythonCode(RuntimeError):
    """
    Exception signalling that the inspected code shows errors.
    """

17 

18 

def verify_code(source, exc=True):
    """
    Verifies :epkg:`python` code.

    The source is parsed with :mod:`ast` and walked with a
    :class:`CodeNodeVisitor`. An identifier is considered known when it is
    one of the names listed in ``known`` below, an imported name (or its
    alias), the target of an assignment or a function argument. Every other
    identifier is reported as missing.

    :param source: source to look into
    :param exc: raise an exception or return the list of
        missing identifiers
    :return: tuple(missing identifiers, :class:`CodeNodeVisitor
        <mlprodict.onnx_tools.onnx_grammar.node_visitor_translator>`)
    :raises ImperfectPythonCode: if *exc* is true and at least one
        identifier is unknown
    """
    node = ast.parse(source)
    v = CodeNodeVisitor()
    v.visit(node)

    known = {'super': None, 'ImportError': None, 'print': print,
             'classmethod': classmethod, 'numpy': numpy,
             'dict': dict, 'list': list, 'sorted': sorted, 'len': len,
             'collections': collections, 'inspect': inspect, 'range': range,
             'int': int, 'str': str, 'isinstance': isinstance}

    for kn in v._imports:
        # kn is (name, asname, node) or (name, asname, module, node).
        name, asname = kn[0], kn[1]
        # The imported name stays registered so that code accepted before
        # this change is still accepted.
        known[name] = kn
        if asname:
            # ``import a as b`` / ``from m import a as b`` bind ``b``.
            known[asname] = kn
        elif '.' in name:
            # ``import a.b`` binds the top-level package ``a``.
            known[name.split('.')[0]] = kn
    for kn in v._assign:
        known[kn[0]] = kn
    for kn in v._args:
        known[kn[0]] = kn

    issues = {name[0] for name in v._names if name[0] not in known}
    if exc and issues:
        raise ImperfectPythonCode(
            "Unknown identifiers: {} in source\n{}".format(
                issues, source))
    return issues, v

56 

57 

class CodeNodeVisitor(ast.NodeVisitor):
    """
    Visits the code, implements verification rules.

    Every visited node is recorded as a *row*, a dictionary holding at
    least the keys ``indent`` (depth in the tree) and ``node``. Most rows
    also hold ``type`` and ``str`` and, once the node has been fully
    visited, ``children``. While walking the tree, the visitor also
    collects imports, names, aliases, assignment targets, function
    arguments and calls to ``fit`` in dedicated lists.
    """

    def __init__(self):
        super().__init__()
        self._rows = []     # every row pushed while walking the tree
        self._indent = 0    # current depth in the tree
        self._stack = []
        self._imports = []  # (name, asname, node) or (name, asname, module, node)
        self._names = []    # (identifier, node) for every Name node
        self._alias = []    # (name, asname, node) for every alias node
        self._assign = []   # (target, node) for every target of an Assign
        self._args = []     # (argument name, node) for every arg node
        self._fits = []     # rows of the calls labelled 'fit'

    def push(self, row):
        """
        Pushes an element into a list.
        """
        self._rows.append(row)

    def generic_visit(self, node):
        """
        Overrides ``generic_visit`` to check it is not used.
        """
        raise AttributeError(
            "generic_visit_args should be used.")  # pragma: no cover

    def generic_visit_args(self, node, row):
        """
        Overrides ``generic_visit`` to keep track of the indentation
        and the node parent. The function will add field
        ``row["children"] = visited`` nodes from here.

        @param node node which needs to be visited
        @param row row (a dictionary)
        @return See ``ast.NodeVisitor.generic_visit``
        """
        self._indent += 1
        first_new = len(self._rows)
        # The base implementation is called explicitly because
        # generic_visit is disabled on this class.
        res = ast.NodeVisitor.generic_visit(  # pylint: disable=E1111
            self, node)  # pylint: disable=E1111
        # Direct children are the rows added by this visit which sit
        # exactly one level below the parent.
        row["children"] = [
            child for child in self._rows[first_new:]
            if child["indent"] == self._indent]
        self._indent -= 1
        return res

    def visit(self, node):
        """
        Visits a node. If no method ``visit_<ClassName>`` exists for the
        node class, a generic row holding the class name in ``str``
        (and no ``type``) is pushed and the children are visited.
        """
        class_name = node.__class__.__name__
        visitor = getattr(self, 'visit_' + class_name, None)
        if visitor is not None:
            return visitor(node)
        cont = {
            "indent": self._indent,
            "str": class_name,
            "node": node}
        return self._push_and_visit(node, cont)

    def _push_and_visit(self, node, cont):
        """
        Pushes row *cont* then visits the children of *node*.
        """
        self.push(cont)
        return self.generic_visit_args(node, cont)

    def _visit_typed(self, node, kind):
        """
        Records a row with ``type`` *kind* and an empty ``str``.
        """
        cont = {
            "indent": self._indent,
            "type": kind,
            "str": "",
            "node": node}
        return self._push_and_visit(node, cont)

    def _visit_typed_only(self, node, kind):
        """
        Records a row with ``type`` *kind* and no ``str`` key.
        """
        cont = {
            "indent": self._indent,
            "type": kind,
            "node": node}
        return self._push_and_visit(node, cont)

    def _visit_labelled(self, node, label):
        """
        Records a row with ``str`` *label* and no ``type`` key.
        """
        cont = {
            "indent": self._indent,
            "str": label,
            "node": node}
        return self._push_and_visit(node, cont)

    @staticmethod
    def print_node(node):
        """
        Debugging purpose. Returns the attributes of *node* among a fixed
        list as a string ``name=value`` separated by spaces.
        """
        attributes = ["s", "name", "str", "id", "body", "n",
                      "arg", "targets", "attr", "returns", "ctx"]
        return " ".join(
            "{0}={1}".format(att, str(node.__dict__[att]))
            for att in attributes if att in node.__dict__)

    def print_tree(self):  # pylint: disable=C0116
        """
        Displays the tree of instructions.

        @return string
        """
        # NOTE(review): the indentation unit below is reproduced as it
        # appears in the rendered report, which may have collapsed
        # repeated spaces -- confirm against the original file.
        return "\n".join(
            "{0}{1}: {2}".format(
                " " * r["indent"],
                r.get("type", ''),
                r.get("str", ''))
            for r in self.Rows)

    @property
    def Rows(self):
        """
        returns a list of dictionaries with all the elements of the code
        (rows flagged with ``remove`` are skipped)
        """
        return [row for row in self._rows if not row.get("remove", False)]

    # -- nodes with specific fields ------------------------------------

    def visit_Str(self, node):  # pylint: disable=C0116
        cont = {
            "indent": self._indent,
            "type": "Str",
            "str": node.s,
            "node": node,
            "value": node.s}
        return self._push_and_visit(node, cont)

    def visit_Name(self, node):  # pylint: disable=C0116
        cont = {
            "indent": self._indent,
            "type": "Name",
            "str": node.id,
            "node": node,
            "id": node.id,
            "ctx": node.ctx}
        self.push(cont)
        self._names.append((node.id, node))
        return self.generic_visit_args(node, cont)

    def visit_Expr(self, node):  # pylint: disable=C0116
        cont = {
            "indent": self._indent,
            "type": "Expr",
            "str": '',
            "node": node,
            "value": node.value}
        return self._push_and_visit(node, cont)

    def visit_alias(self, node):  # pylint: disable=C0116
        cont = {
            "indent": self._indent,
            "type": "alias",
            "str": "",
            "node": node,
            "name": node.name,
            "asname": node.asname}
        self.push(cont)
        self._alias.append((node.name, node.asname, node))
        return self.generic_visit_args(node, cont)

    def visit_Module(self, node):  # pylint: disable=C0116
        cont = {
            "indent": self._indent,
            "type": "Module",
            "str": "",
            "body": node.body,
            "node": node}
        return self._push_and_visit(node, cont)

    def visit_Import(self, node):  # pylint: disable=C0116
        cont = {
            "indent": self._indent,
            "type": "Import",
            "str": "",
            "names": node.names,
            "node": node}
        self.push(cont)
        self._imports.extend(
            (name.name, name.asname, node) for name in node.names)
        return self.generic_visit_args(node, cont)

    def visit_ImportFrom(self, node):  # pylint: disable=C0116
        cont = {
            "indent": self._indent,
            "type": "ImportFrom",
            "str": "",
            "module": node.module,
            "names": node.names,
            "node": node}
        self.push(cont)
        self._imports.extend(
            (name.name, name.asname, node.module, node)
            for name in node.names)
        return self.generic_visit_args(node, cont)

    def visit_ClassDef(self, node):  # pylint: disable=C0116
        cont = {
            "indent": self._indent,
            "type": "ClassDef",
            "str": "",
            "name": node.name,
            "body": node.body,
            "node": node}
        return self._push_and_visit(node, cont)

    def visit_FunctionDef(self, node):  # pylint: disable=C0116
        cont = {
            "indent": self._indent,
            "type": "FunctionDef",
            "str": node.name,
            "name": node.name,
            "body": node.body,
            "node": node,
            "returns": node.returns}
        return self._push_and_visit(node, cont)

    def visit_arguments(self, node):  # pylint: disable=C0116
        cont = {
            "indent": self._indent,
            "type": "arguments",
            "str": "",
            "node": node,
            "args": node.args}
        return self._push_and_visit(node, cont)

    def visit_arg(self, node):  # pylint: disable=C0116
        cont = {
            "indent": self._indent,
            "type": "arg",
            "str": node.arg,
            "node": node,
            "arg": node.arg,
            "annotation": node.annotation}
        self.push(cont)
        self._args.append((node.arg, node))
        return self.generic_visit_args(node, cont)

    def visit_Assign(self, node):  # pylint: disable=C0116
        cont = {
            "indent": self._indent,
            "type": "Assign",
            "str": "",
            "node": node,
            "targets": node.targets,
            "value": node.value}
        self.push(cont)
        for target in node.targets:
            # Targets without an ``id`` attribute are registered under the
            # integer returned by ``id(target)`` instead of a name.
            key = target.id if hasattr(target, 'id') else id(target)
            self._assign.append((key, node))
        return self.generic_visit_args(node, cont)

    def visit_Store(self, node):  # pylint: disable=C0116
        # No row is pushed for a Store context.
        return self.generic_visit_args(node, {})

    def visit_Call(self, node):  # pylint: disable=C0116
        func = node.func
        if "attr" in func.__dict__:
            label = func.attr
        elif "id" in func.__dict__:
            label = func.id
        else:
            label = ""  # pragma: no cover
        cont = {
            "indent": self._indent,
            "type": "Call",
            "str": label,
            "node": node,
            "func": func}
        self.push(cont)
        if label == 'fit':
            self._fits.append(cont)
        return self.generic_visit_args(node, cont)

    def visit_Attribute(self, node):  # pylint: disable=C0116
        cont = {
            "indent": self._indent,
            "type": "Attribute",
            "str": node.attr,
            "node": node,
            "value": node.value,
            "ctx": node.ctx,
            "attr": node.attr}
        res = self._push_and_visit(node, cont)

        children = cont["children"]
        if children:
            first = children[0]
            if first.get("type") == "Name":
                # ``name.attr``: the name is merged into the attribute row
                # and the row of the name is hidden from ``Rows``.
                cont["str"] = "{0}.{1}".format(first["node"].id, cont["str"])
                first["remove"] = True
        return res

    def visit_Load(self, node):  # pylint: disable=C0116
        # No row is pushed for a Load context.
        return self.generic_visit_args(node, {})

    def visit_keyword(self, node):  # pylint: disable=C0116
        cont = {
            "indent": self._indent,
            "type": "keyword",
            "str": "{0}".format(node.arg),
            "node": node,
            "arg": node.arg,
            "value": node.value}
        return self._push_and_visit(node, cont)

    # -- nodes recorded with a type and an empty string ----------------

    def visit_BinOp(self, node):  # pylint: disable=C0116
        return self._visit_typed(node, "BinOp")

    def visit_UnaryOp(self, node):  # pylint: disable=C0116
        return self._visit_typed(node, "UnaryOp")

    def visit_Not(self, node):  # pylint: disable=C0116
        return self._visit_typed(node, "Not")

    def visit_Invert(self, node):  # pylint: disable=C0116
        return self._visit_typed(node, "Invert")

    def visit_BoolOp(self, node):  # pylint: disable=C0116
        return self._visit_typed(node, "BoolOp")

    def visit_Mult(self, node):  # pylint: disable=C0116
        return self._visit_typed(node, "Mult")

    def visit_Div(self, node):  # pylint: disable=C0116
        return self._visit_typed(node, "Div")

    def visit_FloorDiv(self, node):  # pylint: disable=C0116
        return self._visit_typed(node, "FloorDiv")

    def visit_Add(self, node):  # pylint: disable=C0116
        return self._visit_typed(node, "Add")

    def visit_Pow(self, node):  # pylint: disable=C0116
        return self._visit_typed(node, "Pow")

    def visit_In(self, node):  # pylint: disable=C0116
        return self._visit_typed(node, "In")

    def visit_AugAssign(self, node):  # pylint: disable=C0116
        return self._visit_typed(node, "AugAssign")

    def visit_Eq(self, node):  # pylint: disable=C0116
        return self._visit_typed(node, "Eq")

    def visit_IsNot(self, node):  # pylint: disable=C0116
        return self._visit_typed(node, "IsNot")

    def visit_Is(self, node):  # pylint: disable=C0116
        return self._visit_typed(node, "Is")

    def visit_And(self, node):  # pylint: disable=C0116
        return self._visit_typed(node, "And")

    def visit_BitAnd(self, node):  # pylint: disable=C0116
        return self._visit_typed(node, "BitAnd")

    def visit_Or(self, node):  # pylint: disable=C0116
        return self._visit_typed(node, "Or")

    def visit_NotEq(self, node):  # pylint: disable=C0116
        return self._visit_typed(node, "NotEq")

    def visit_Mod(self, node):  # pylint: disable=C0116
        return self._visit_typed(node, "Mod")

    def visit_Sub(self, node):  # pylint: disable=C0116
        return self._visit_typed(node, "Sub")

    def visit_USub(self, node):  # pylint: disable=C0116
        return self._visit_typed(node, "USub")

    def visit_Compare(self, node):  # pylint: disable=C0116
        return self._visit_typed(node, "Compare")

    def visit_Gt(self, node):  # pylint: disable=C0116
        return self._visit_typed(node, "Gt")

    def visit_GtE(self, node):  # pylint: disable=C0116
        return self._visit_typed(node, "GtE")

    def visit_Lt(self, node):  # pylint: disable=C0116
        return self._visit_typed(node, "Lt")

    def visit_Num(self, node):  # pylint: disable=C0116
        cont = {
            "indent": self._indent,
            "type": "Num",
            "node": node,
            "str": "{0}".format(node.n),
            'n': node.n}
        return self._push_and_visit(node, cont)

    def visit_Return(self, node):  # pylint: disable=C0116
        cont = {
            "indent": self._indent,
            "type": "Return",
            "node": node,
            "str": "",
            'value': node.value}
        return self._push_and_visit(node, cont)

    # -- nodes recorded with a type only --------------------------------

    def visit_List(self, node):  # pylint: disable=C0116
        return self._visit_typed_only(node, "List")

    def visit_ListComp(self, node):  # pylint: disable=C0116
        return self._visit_typed_only(node, "ListComp")

    def visit_comprehension(self, node):  # pylint: disable=C0116
        return self._visit_typed_only(node, "comprehension")

    def visit_Dict(self, node):  # pylint: disable=C0116
        return self._visit_typed_only(node, "Dict")

    def visit_Tuple(self, node):  # pylint: disable=C0116
        return self._visit_typed_only(node, "Tuple")

    def visit_NameConstant(self, node):  # pylint: disable=C0116
        return self._visit_typed_only(node, "NameConstant")

    def visit_(self, node):  # pylint: disable=C0116
        raise RuntimeError(  # pragma: no cover
            "This node is not handled: {}".format(node))

    # -- nodes recorded with a label only (no type) ---------------------

    def visit_Subscript(self, node):  # pylint: disable=C0116
        return self._visit_labelled(node, "Subscript")

    def visit_ExtSlice(self, node):  # pylint: disable=C0116
        return self._visit_labelled(node, "ExtSlice")

    def visit_Slice(self, node):  # pylint: disable=C0116
        return self._visit_labelled(node, "Slice")

    def visit_Index(self, node):  # pylint: disable=C0116
        return self._visit_labelled(node, "Index")

    def visit_If(self, node):  # pylint: disable=C0116
        return self._visit_labelled(node, "If")

    def visit_IfExp(self, node):  # pylint: disable=C0116
        return self._visit_labelled(node, "IfExp")

    def visit_Lambda(self, node):  # pylint: disable=C0116
        return self._visit_labelled(node, "Lambda")

    def visit_GeneratorExp(self, node):  # pylint: disable=C0116
        return self._visit_labelled(node, "GeneratorExp")