Coverage for mlprodict/plotting/text_plot.py: 94%


400 statements  

# pylint: disable=R0912
"""
@file
@brief Text representations of graphs.
"""
from collections import OrderedDict
import numpy
from onnx import TensorProto, AttributeProto
from onnx.numpy_helper import to_array
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
from ..tools.graphs import onnx2bigraph
from ..onnx_tools.onnx2py_helper import _var_as_dict


def onnx_text_plot(model_onnx, recursive=False, graph_type='basic',
                   grid=5, distance=5):
    """
    Uses @see fn onnx2bigraph to convert the ONNX graph
    into text.

    :param model_onnx: onnx representation
    :param recursive: @see fn onnx2bigraph
    :param graph_type: @see fn onnx2bigraph
    :param grid: @see me display_structure
    :param distance: @see fn display_structure
    :return: text

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        import numpy
        from mlprodict.onnx_conv import to_onnx
        from mlprodict import __max_supported_opset__ as opv
        from mlprodict.plotting.plotting import onnx_text_plot
        from mlprodict.npy.xop import loadop

        OnnxAdd, OnnxSub = loadop('Add', 'Sub')

        idi = numpy.identity(2).astype(numpy.float32)
        A = OnnxAdd('X', idi, op_version=opv)
        B = OnnxSub(A, 'W', output_names=['Y'], op_version=opv)
        onx = B.to_onnx({'X': idi, 'W': idi})
        print(onnx_text_plot(onx))
    """
    bigraph = onnx2bigraph(model_onnx)
    graph = bigraph.display_structure()
    return graph.to_text()


def onnx_text_plot_tree(node):
    """
    Gives a textual representation of a tree ensemble.

    :param node: `TreeEnsemble*`
    :return: text

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        import numpy
        from sklearn.datasets import load_iris
        from sklearn.tree import DecisionTreeRegressor
        from mlprodict.onnx_conv import to_onnx
        from mlprodict.plotting.plotting import onnx_text_plot_tree

        iris = load_iris()
        X, y = iris.data.astype(numpy.float32), iris.target
        clr = DecisionTreeRegressor(max_depth=3)
        clr.fit(X, y)
        onx = to_onnx(clr, X)
        res = onnx_text_plot_tree(onx.graph.node[0])
        print(res)
    """
    def rule(r):
        if r == b'BRANCH_LEQ':
            return '<='
        if r == b'BRANCH_LT':  # pragma: no cover
            return '<'
        if r == b'BRANCH_GEQ':  # pragma: no cover
            return '>='
        if r == b'BRANCH_GT':  # pragma: no cover
            return '>'
        if r == b'BRANCH_EQ':  # pragma: no cover
            return '=='
        if r == b'BRANCH_NEQ':  # pragma: no cover
            return '!='
        raise ValueError(  # pragma: no cover
            "Unexpected rule %r." % r)

    class Node:
        "Node representation."

        def __init__(self, i, atts):
            self.nodes_hitrates = None
            self.nodes_missing_value_tracks_true = None
            for k, v in atts.items():
                if k.startswith('nodes'):
                    setattr(self, k, v[i])
            self.depth = 0
            self.true_false = ''

        def process_node(self):
            "node to string"
            if self.nodes_modes == b'LEAF':  # pylint: disable=E1101
                text = "%s y=%r f=%r i=%r" % (
                    self.true_false,
                    self.target_weights, self.target_ids,  # pylint: disable=E1101
                    self.target_nodeids)  # pylint: disable=E1101
            else:
                text = "%s X%d %s %r" % (
                    self.true_false,
                    self.nodes_featureids,  # pylint: disable=E1101
                    rule(self.nodes_modes),  # pylint: disable=E1101
                    self.nodes_values)  # pylint: disable=E1101
            if self.nodes_hitrates and self.nodes_hitrates != 1:
                text += " hi=%r" % self.nodes_hitrates
            if self.nodes_missing_value_tracks_true:
                text += " miss=%r" % (
                    self.nodes_missing_value_tracks_true)
            return "%s%s" % (" " * self.depth, text)

    def process_tree(atts, treeid):
        "tree to string"
        rows = ['treeid=%r' % treeid]
        if 'base_values' in atts:
            rows.append('base_value=%r' % atts['base_values'][treeid])

        short = {}
        for prefix in ['nodes', 'target', 'class']:
            if ('%s_treeids' % prefix) not in atts:
                continue
            idx = [i for i in range(len(atts['%s_treeids' % prefix]))
                   if atts['%s_treeids' % prefix][i] == treeid]
            for k, v in atts.items():
                if k.startswith(prefix):
                    short[k] = [v[i] for i in idx]

        nodes = OrderedDict()
        for i in range(len(short['nodes_treeids'])):
            nodes[i] = Node(i, short)
        for i in range(len(short['target_treeids'])):
            idn = short['target_nodeids'][i]
            node = nodes[idn]
            node.target_nodeids = idn
            node.target_ids = short['target_ids'][i]
            node.target_weights = short['target_weights'][i]

        def iterate(nodes, node, depth=0, true_false=''):
            node.depth = depth
            node.true_false = true_false
            yield node
            if node.nodes_falsenodeids > 0:
                for n in iterate(nodes, nodes[node.nodes_falsenodeids],
                                 depth=depth + 1, true_false='F'):
                    yield n
                for n in iterate(nodes, nodes[node.nodes_truenodeids],
                                 depth=depth + 1, true_false='T'):
                    yield n

        for node in iterate(nodes, nodes[0]):
            rows.append(node.process_node())
        return rows

    if node.op_type != "TreeEnsembleRegressor":
        raise NotImplementedError(  # pragma: no cover
            "Type %r cannot be displayed." % node.op_type)
    d = {k: v['value'] for k, v in _var_as_dict(node)['atts'].items()}
    atts = {}
    for k, v in d.items():
        atts[k] = v if isinstance(v, int) else list(v)
    trees = list(sorted(set(atts['nodes_treeids'])))
    rows = ['n_targets=%r' % atts['n_targets'],
            'n_trees=%r' % len(trees)]
    for tree in trees:
        r = process_tree(atts, tree)
        rows.append('----')
        rows.extend(r)

    return "\n".join(rows)
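
# Illustrative note (not part of the original module): for a depth-1 stump,
# the text assembled above looks roughly like
#
#     n_targets=1
#     n_trees=1
#     ----
#     treeid=0
#      X2 <= 2.45
#      F y=... f=... i=...
#      T y=... f=... i=...
#
# Branch rows print "X<feature> <rule> <threshold>", leaf rows print the
# target weight, id and node id, and nesting is shown by one leading space
# per tree depth (the false branch is printed before the true branch).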



def reorder_nodes_for_display(nodes, verbose=False):
    """
    Reorders the nodes with a breadth-first search (BFS).

    :param nodes: list of ONNX nodes
    :param verbose: display intermediate information
    :return: reordered list of nodes
    """
    all_outputs = set()
    all_inputs = set()
    for node in nodes:
        all_outputs |= set(node.output)
        all_inputs |= set(node.input)
    common = all_outputs & all_inputs
    dnodes = OrderedDict()
    successors = {}
    predecessors = {}
    for node in nodes:
        node_name = node.name + "#" + "|".join(node.output)
        dnodes[node_name] = node
        successors[node_name] = set()
        predecessors[node_name] = set()
        for name in node.input:
            predecessors[node_name].add(name)
            if name not in successors:
                successors[name] = set()
            successors[name].add(node_name)
        for name in node.output:
            successors[node_name].add(name)
            predecessors[name] = {node_name}

    known = all_inputs - common
    new_nodes = []
    done = set()

    def _find_sequence(node_name, known, done):
        inputs = dnodes[node_name].input
        if any(map(lambda i: i not in known, inputs)):
            return []

        res = [node_name]
        while res[-1] in successors:
            next_names = successors[res[-1]]
            if res[-1] not in dnodes:
                next_names = set(v for v in next_names if v not in known)
                if len(next_names) == 1:
                    next_name = next_names.pop()
                    inputs = dnodes[next_name].input
                    if any(map(lambda i: i not in known, inputs)):
                        break
                    res.append(next_name)
                else:
                    break
            else:
                next_names = set(v for v in next_names if v not in done)
                if len(next_names) == 1:
                    next_name = next_names.pop()
                    res.append(next_name)
                else:
                    break

        return [r for r in res if r in dnodes and r not in done]

    while len(done) < len(nodes):
        # possible
        possibles = OrderedDict()
        for k, v in dnodes.items():
            if k in done:
                continue
            if predecessors[k] <= known:
                possibles[k] = v

        sequences = OrderedDict()
        for k, v in possibles.items():
            if k in done:
                continue
            sequences[k] = _find_sequence(k, known, done)
            if verbose:
                print("[reorder_nodes_for_display] sequence(%s)=%s" % (
                    k, ",".join(sequences[k])))

        if len(sequences) == 0:
            raise RuntimeError(  # pragma: no cover
                "Unexpected empty sequences (len(possibles)=%d, "
                "len(done)=%d, len(nodes)=%d). This is usually due to "
                "a name used both as a result name and a node name."
                "" % (len(possibles), len(done), len(nodes)))

        # find the best sequence
        best = None
        for k, v in sequences.items():
            if best is None or len(v) > len(sequences[best]):
                # if the sequence of successors is longer
                best = k
            elif len(v) == len(sequences[best]):
                if len(new_nodes) > 0:
                    # then choose the next successor sharing input with
                    # previous output
                    so = set(new_nodes[-1].output)
                    first1 = dnodes[sequences[best][0]]
                    first2 = dnodes[v[0]]
                    if len(set(first1.input) & so) < len(set(first2.input) & so):
                        best = k
                else:
                    first1 = dnodes[sequences[best][0]]
                    first2 = dnodes[v[0]]
                    if first1.op_type > first2.op_type:
                        best = k
                    elif (first1.op_type == first2.op_type and
                            first1.name > first2.name):
                        best = k

        if best is None:
            raise RuntimeError(  # pragma: no cover
                "Wrong implementation (len(sequence)=%d)." % len(sequences))
        if verbose:
            print("[reorder_nodes_for_display] BEST: sequence(%s)=%s" % (
                best, ",".join(sequences[best])))

        # process the sequence
        for k in sequences[best]:
            v = dnodes[k]
            new_nodes.append(v)
            done.add(k)
            known |= set(v.output)

    if len(new_nodes) != len(nodes):
        raise RuntimeError(  # pragma: no cover
            "The returned new nodes are different. "
            "len(nodes)=%d != %d=len(new_nodes). done=\n%r"
            "\n%s\n----------\n%s" % (
                len(nodes), len(new_nodes), done,
                "\n".join("%d - %s - %s - %s" % (
                    (n.name + "".join(n.output)) in done,
                    n.op_type, n.name, n.name + "".join(n.output))
                    for n in nodes),
                "\n".join("%d - %s - %s - %s" % (
                    (n.name + "".join(n.output)) in done,
                    n.op_type, n.name, n.name + "".join(n.output))
                    for n in new_nodes)))
    return new_nodes
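
# A minimal usage sketch for reorder_nodes_for_display (illustrative only,
# not part of the original module): build three NodeProto objects with
# onnx.helper.make_node and check that the function returns them in an
# order that can be read from top to bottom.
#
#     from onnx.helper import make_node
#     nodes = [
#         make_node('Mul', ['t', 't'], ['u'], name='N2'),
#         make_node('Add', ['X', 'Y'], ['t'], name='N1'),
#         make_node('Abs', ['u'], ['Z'], name='N3'),
#     ]
#     print([n.op_type for n in reorder_nodes_for_display(nodes)])
#     # expected: ['Add', 'Mul', 'Abs']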



def _get_type(obj0):
    obj = obj0
    if hasattr(obj, 'data_type'):
        if (obj.data_type == TensorProto.FLOAT and  # pylint: disable=E1101
                hasattr(obj, 'float_data')):
            return TENSOR_TYPE_TO_NP_TYPE[TensorProto.FLOAT]  # pylint: disable=E1101
        if (obj.data_type == TensorProto.DOUBLE and  # pylint: disable=E1101
                hasattr(obj, 'double_data')):
            return TENSOR_TYPE_TO_NP_TYPE[TensorProto.DOUBLE]  # pylint: disable=E1101
        if (obj.data_type == TensorProto.INT64 and  # pylint: disable=E1101
                hasattr(obj, 'int64_data')):
            return TENSOR_TYPE_TO_NP_TYPE[TensorProto.INT64]  # pylint: disable=E1101
        if (obj.data_type == TensorProto.INT32 and  # pylint: disable=E1101
                hasattr(obj, 'int32_data')):
            return TENSOR_TYPE_TO_NP_TYPE[TensorProto.INT32]  # pylint: disable=E1101
        raise RuntimeError(  # pragma: no cover
            "Unable to guess type from %r." % obj0)
    if hasattr(obj, 'type'):
        obj = obj.type
    if hasattr(obj, 'tensor_type'):
        obj = obj.tensor_type
    if hasattr(obj, 'elem_type'):
        return TENSOR_TYPE_TO_NP_TYPE.get(obj.elem_type, '?')
    raise RuntimeError(  # pragma: no cover
        "Unable to guess type from %r." % obj0)


def _get_shape(obj):
    obj0 = obj
    if hasattr(obj, 'data_type'):
        if (obj.data_type == TensorProto.FLOAT and  # pylint: disable=E1101
                hasattr(obj, 'float_data')):
            return (len(obj.float_data), )
        if (obj.data_type == TensorProto.DOUBLE and  # pylint: disable=E1101
                hasattr(obj, 'double_data')):
            return (len(obj.double_data), )
        if (obj.data_type == TensorProto.INT64 and  # pylint: disable=E1101
                hasattr(obj, 'int64_data')):
            return (len(obj.int64_data), )
        if (obj.data_type == TensorProto.INT32 and  # pylint: disable=E1101
                hasattr(obj, 'int32_data')):
            return (len(obj.int32_data), )
        raise RuntimeError(  # pragma: no cover
            "Unable to guess the shape from %r." % obj0)
    if hasattr(obj, 'type'):
        obj = obj.type
    if hasattr(obj, 'tensor_type'):
        obj = obj.tensor_type
    if hasattr(obj, 'shape'):
        obj = obj.shape
        dims = []
        for d in obj.dim:
            if hasattr(d, 'dim_value'):
                dims.append(d.dim_value)
            else:
                dims.append(None)
        return tuple(dims)
    raise RuntimeError(  # pragma: no cover
        "Unable to guess the shape from %r." % obj0)



def onnx_simple_text_plot(model, verbose=False, att_display=None,
                          add_links=False, recursive=False, functions=True):
    """
    Displays an ONNX graph as text.

    :param model: ONNX graph
    :param verbose: display debugging information
    :param att_display: list of attributes to display, if None,
        a default list is used
    :param add_links: displays links on the right side
    :param recursive: display subgraphs as well
    :param functions: display functions as well
    :return: str

    An ONNX graph is printed the following way:

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        import numpy
        from sklearn.cluster import KMeans
        from mlprodict.plotting.plotting import onnx_simple_text_plot
        from mlprodict.onnx_conv import to_onnx

        x = numpy.random.randn(10, 3)
        y = numpy.random.randn(10)
        model = KMeans(3)
        model.fit(x, y)
        onx = to_onnx(model, x.astype(numpy.float32),
                      target_opset=15)
        text = onnx_simple_text_plot(onx, verbose=False)
        print(text)

    The same graph with links:

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        import numpy
        from sklearn.cluster import KMeans
        from mlprodict.plotting.plotting import onnx_simple_text_plot
        from mlprodict.onnx_conv import to_onnx

        x = numpy.random.randn(10, 3)
        y = numpy.random.randn(10)
        model = KMeans(3)
        model.fit(x, y)
        onx = to_onnx(model, x.astype(numpy.float32),
                      target_opset=15)
        text = onnx_simple_text_plot(onx, verbose=False, add_links=True)
        print(text)

    Visually, it looks like the following:

    .. gdot::
        :script: DOT-SECTION

        import numpy
        from sklearn.cluster import KMeans
        from mlprodict.onnxrt import OnnxInference
        from mlprodict.onnx_conv import to_onnx

        x = numpy.random.randn(10, 3)
        y = numpy.random.randn(10)
        model = KMeans(3)
        model.fit(x, y)
        model_onnx = to_onnx(model, x.astype(numpy.float32),
                             target_opset=15)
        oinf = OnnxInference(model_onnx, inplace=False)

        print("DOT-SECTION", oinf.to_dot())
    """
    if att_display is None:
        att_display = [
            'activations',
            'align_corners',
            'allowzero',
            'alpha',
            'auto_pad',
            'axis',
            'axes',
            'batch_axis',
            'batch_dims',
            'beta',
            'bias',
            'blocksize',
            'case_change_action',
            'ceil_mode',
            'center_point_box',
            'clip',
            'coordinate_transformation_mode',
            'count_include_pad',
            'cubic_coeff_a',
            'decay_factor',
            'detect_negative',
            'detect_positive',
            'dilation',
            'dilations',
            'direction',
            'dtype',
            'end',
            'epsilon',
            'equation',
            'exclusive',
            'exclude_outside',
            'extrapolation_value',
            'fmod',
            'gamma',
            'group',
            'hidden_size',
            'high',
            'ignore_index',
            'input_forget',
            'is_case_sensitive',
            'k',
            'keepdims',
            'kernel_shape',
            'lambd',
            'largest',
            'layout',
            'linear_before_reset',
            'locale',
            'low',
            'max_gram_length',
            'max_skip_count',
            'mean',
            'min_gram_length',
            'mode',
            'momentum',
            'nearest_mode',
            'ngram_counts',
            'ngram_indexes',
            'noop_with_empty_axes',
            'norm_coefficient',
            'norm_coefficient_post',
            'num_scan_inputs',
            'output_height',
            'output_padding',
            'output_shape',
            'output_width',
            'p',
            'padding_mode',
            'pads',
            'perm',
            'pooled_shape',
            'reduction',
            'reverse',
            'sample_size',
            'sampling_ratio',
            'scale',
            'scan_input_axes',
            'scan_input_directions',
            'scan_output_axes',
            'scan_output_directions',
            'seed',
            'select_last_index',
            'size',
            'sorted',
            'spatial_scale',
            'start',
            'storage_order',
            'strides',
            'time_axis',
            'to',
            'training_mode',
            'transA',
            'transB',
            'type',
            'upper',
            'xs',
            'y',
            'zs',
        ]

    def str_node(indent, node):
        atts = []
        if hasattr(node, 'attribute'):
            for att in node.attribute:
                if att.name in att_display:
                    if att.type == AttributeProto.INT:  # pylint: disable=E1101
                        atts.append("%s=%d" % (att.name, att.i))
                    elif att.type == AttributeProto.FLOAT:  # pylint: disable=E1101
                        atts.append("%s=%1.2f" % (att.name, att.f))
                    elif att.type == AttributeProto.INTS:  # pylint: disable=E1101
                        atts.append("%s=%s" % (att.name, str(
                            list(att.ints)).replace(" ", "")))
        inputs = list(node.input)
        if len(atts) > 0:
            inputs.extend(atts)
        if node.domain in ('', 'ai.onnx.ml'):
            domain = ''
        else:
            domain = '[%s]' % node.domain
        return "%s%s%s(%s) -> %s" % (
            " " * indent, node.op_type, domain,
            ", ".join(inputs), ", ".join(node.output))

    rows = []
    if hasattr(model, 'opset_import'):
        for opset in model.opset_import:
            rows.append("opset: domain=%r version=%r" % (
                opset.domain, opset.version))
    if hasattr(model, 'graph'):
        main_model = model
        model = model.graph
    else:
        main_model = None

    # inputs
    line_name_new = {}
    line_name_in = {}
    for inp in model.input:
        if isinstance(inp, str):
            rows.append("input: %r" % inp)
        else:
            line_name_new[inp.name] = len(rows)
            rows.append("input: name=%r type=%r shape=%r" % (
                inp.name, _get_type(inp), _get_shape(inp)))
    # initializer
    if hasattr(model, 'initializer'):
        for init in model.initializer:
            if numpy.prod(_get_shape(init)) < 5:
                content = " -- %r" % to_array(init).ravel()
            else:
                content = ""
            line_name_new[init.name] = len(rows)
            rows.append("init: name=%r type=%r shape=%r%s" % (
                init.name, _get_type(init), _get_shape(init), content))

    # successors, predecessors
    successors = {}
    predecessors = {}
    subgraphs = []
    for node in model.node:
        node_name = node.name + "#" + "|".join(node.output)
        successors[node_name] = []
        predecessors[node_name] = []
        for name in node.input:
            predecessors[node_name].append(name)
            if name not in successors:
                successors[name] = []
            successors[name].append(node_name)
        for name in node.output:
            successors[node_name].append(name)
            predecessors[name] = [node_name]
        if recursive and node.op_type in {'If', 'Scan', 'Loop'}:
            for att in node.attribute:
                if att.name not in {'body', 'else_branch', 'then_branch'}:
                    continue
                subgraphs.append((node, att.name, att.g))

    # walk through nodes
    init_names = set()
    indents = {}
    for inp in model.input:
        if isinstance(inp, str):
            indents[inp] = 0
            init_names.add(inp)
        else:
            indents[inp.name] = 0
            init_names.add(inp.name)
    if hasattr(model, 'initializer'):
        for init in model.initializer:
            indents[init.name] = 0
            init_names.add(init.name)

    nodes = reorder_nodes_for_display(model.node, verbose=verbose)

    previous_indent = None
    previous_out = None
    previous_in = None
    for node in nodes:
        add_break = False
        name = node.name + "#" + "|".join(node.output)
        if name in indents:
            indent = indents[name]
            if previous_indent is not None and indent < previous_indent:
                if verbose:
                    print("[onnx_simple_text_plot] break1 %s" % node.op_type)
                add_break = True
        elif previous_in is not None and set(node.input) == previous_in:
            indent = previous_indent
        else:
            inds = [indents.get(i, 0)
                    for i in node.input if i not in init_names]
            if len(inds) == 0:
                indent = 0
            else:
                mi = min(inds)
                indent = mi
            if previous_indent is not None and indent < previous_indent:
                if verbose:
                    print(  # pragma: no cover
                        "[onnx_simple_text_plot] break2 %s" %
                        node.op_type)
                add_break = True
        if not add_break and previous_out is not None:
            if len(set(node.input) & previous_out) == 0:
                if verbose:
                    print("[onnx_simple_text_plot] break3 %s" %
                          node.op_type)
                add_break = True
                indent = 0

        if add_break and verbose:
            print("[onnx_simple_text_plot] add break")
        for n in node.input:
            if n in line_name_in:
                line_name_in[n].append(len(rows))
            else:
                line_name_in[n] = [len(rows)]
        for n in node.output:
            line_name_new[n] = len(rows)
        rows.append(str_node(indent, node))
        indents[name] = indent

        for o in node.output:
            indents[o] = indent + 1

        previous_indent = indents[name]
        previous_out = set(node.output)
        previous_in = set(node.input)

    # outputs
    for out in model.output:
        if isinstance(out, str):
            if out in line_name_in:
                line_name_in[out].append(len(rows))
            else:
                line_name_in[out] = [len(rows)]
            rows.append("output: name=%r type=%s shape=%s" % (
                out, '?', '?'))
        else:
            if out.name in line_name_in:
                line_name_in[out.name].append(len(rows))
            else:
                line_name_in[out.name] = [len(rows)]
            rows.append("output: name=%r type=%r shape=%r" % (
                out.name, _get_type(out), _get_shape(out)))

    if add_links:

        def _mark_link(rows, lengths, r1, r2, d):
            maxl = max(lengths[r1], lengths[r2]) + d * 2
            maxl = max(maxl, max(len(rows[r]) for r in range(r1, r2 + 1))) + 2

            if rows[r1][-1] == '|':
                p1, p2 = rows[r1][:lengths[r1] + 2], rows[r1][lengths[r1] + 2:]
                rows[r1] = p1 + p2.replace(' ', '-')
            rows[r1] += ("-" * (maxl - len(rows[r1]) - 1)) + "+"

            if rows[r2][-1] == " ":
                rows[r2] += "<"
            elif rows[r2][-1] == '|':
                if "<" not in rows[r2]:
                    p = lengths[r2]
                    rows[r2] = rows[r2][:p] + '<' + rows[r2][p + 1:]
                p1, p2 = rows[r2][:lengths[r2] + 2], rows[r2][lengths[r2] + 2:]
                rows[r2] = p1 + p2.replace(' ', '-')
            rows[r2] += ("-" * (maxl - len(rows[r2]) - 1)) + "+"

            for r in range(r1 + 1, r2):
                if len(rows[r]) < maxl:
                    rows[r] += " " * (maxl - len(rows[r]) - 1)
                rows[r] += "|"

        diffs = []
        for n, r1 in line_name_new.items():
            if n not in line_name_in:
                continue
            r2s = line_name_in[n]
            for r2 in r2s:
                if r1 >= r2:
                    continue
                diffs.append((r2 - r1, (n, r1, r2)))
        diffs.sort()
        for i in range(len(rows)):  # pylint: disable=C0200
            rows[i] += " "
        lengths = [len(r) for r in rows]

        for d, (n, r1, r2) in diffs:
            if d == 1 and len(line_name_in[n]) == 1:
                # no line for link to the next node
                continue
            _mark_link(rows, lengths, r1, r2, d)

    # subgraphs
    for node, name, g in subgraphs:
        rows.append('----- subgraph ---- %s - %s - att.%s=' % (
            node.op_type, node.name, name))
        res = onnx_simple_text_plot(
            g, verbose=verbose, att_display=att_display,
            add_links=add_links, recursive=recursive)
        rows.append(res)

    # functions
    if functions and main_model is not None:
        for fct in main_model.functions:
            rows.append('----- function name=%s domain=%s' % (
                fct.name, fct.domain))
            res = onnx_simple_text_plot(
                fct, verbose=verbose, att_display=att_display,
                add_links=add_links, recursive=recursive,
                functions=False)
            rows.append(res)

    return "\n".join(rows)


def onnx_text_plot_io(model, verbose=False, att_display=None):
    """
    Displays information about input and output types.

    :param model: ONNX graph
    :param verbose: display debugging information
    :param att_display: unused
    :return: str

    An ONNX graph is printed the following way:

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        import numpy
        from sklearn.cluster import KMeans
        from mlprodict.plotting.plotting import onnx_text_plot_io
        from mlprodict.onnx_conv import to_onnx

        x = numpy.random.randn(10, 3)
        y = numpy.random.randn(10)
        model = KMeans(3)
        model.fit(x, y)
        onx = to_onnx(model, x.astype(numpy.float32),
                      target_opset=15)
        text = onnx_text_plot_io(onx, verbose=False)
        print(text)
    """
    rows = []
    if hasattr(model, 'opset_import'):
        for opset in model.opset_import:
            rows.append("opset: domain=%r version=%r" % (
                opset.domain, opset.version))
    if hasattr(model, 'graph'):
        model = model.graph

    # inputs
    for inp in model.input:
        rows.append("input: name=%r type=%r shape=%r" % (
            inp.name, _get_type(inp), _get_shape(inp)))
    # initializer
    for init in model.initializer:
        rows.append("init: name=%r type=%r shape=%r" % (
            init.name, _get_type(init), _get_shape(init)))
    # outputs
    for out in model.output:
        rows.append("output: name=%r type=%r shape=%r" % (
            out.name, _get_type(out), _get_shape(out)))
    return "\n".join(rows)