Coverage for mlprodict/plotting/text_plot.py: 94%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# pylint: disable=R0912
2"""
3@file
4@brief Text representations of graphs.
5"""
6from collections import OrderedDict
7import numpy
8from onnx import TensorProto, AttributeProto
9from onnx.numpy_helper import to_array
10from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
11from ..tools.graphs import onnx2bigraph
12from ..onnx_tools.onnx2py_helper import _var_as_dict
def onnx_text_plot(model_onnx, recursive=False, graph_type='basic',
                   grid=5, distance=5):
    """
    Uses @see fn onnx2bigraph to convert the ONNX graph
    into text.

    :param model_onnx: onnx representation
    :param recursive: @see fn onnx2bigraph
    :param graph_type: @see fn onnx2bigraph
    :param grid: @see meth display_structure
    :param distance: @see meth display_structure
    :return: text

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        import numpy
        from mlprodict.onnx_conv import to_onnx
        from mlprodict import __max_supported_opset__ as opv
        from mlprodict.plotting.plotting import onnx_text_plot
        from mlprodict.npy.xop import loadop

        OnnxAdd, OnnxSub = loadop('Add', 'Sub')

        idi = numpy.identity(2).astype(numpy.float32)
        A = OnnxAdd('X', idi, op_version=opv)
        B = OnnxSub(A, 'W', output_names=['Y'], op_version=opv)
        onx = B.to_onnx({'X': idi, 'W': idi})
        print(onnx_text_plot(onx))
    """
    # The options are documented as being those of onnx2bigraph and
    # display_structure: forward them instead of silently ignoring them.
    # NOTE(review): assumes both callees accept these keyword names, as the
    # @see references above imply — confirm against tools/graphs.py.
    bigraph = onnx2bigraph(model_onnx, recursive=recursive,
                           graph_type=graph_type)
    graph = bigraph.display_structure(grid=grid, distance=distance)
    return graph.to_text()
def onnx_text_plot_tree(node):
    """
    Gives a textual representation of a tree ensemble.

    :param node: `TreeEnsemble*`
    :return: text

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        import numpy
        from sklearn.datasets import load_iris
        from sklearn.tree import DecisionTreeRegressor
        from mlprodict.onnx_conv import to_onnx
        from mlprodict.plotting.plotting import onnx_text_plot_tree

        iris = load_iris()
        X, y = iris.data.astype(numpy.float32), iris.target
        clr = DecisionTreeRegressor(max_depth=3)
        clr.fit(X, y)
        onx = to_onnx(clr, X)
        res = onnx_text_plot_tree(onx.graph.node[0])
        print(res)
    """
    def rule(r):
        "Returns the comparison operator matching a node mode."
        if r == b'BRANCH_LEQ':
            return '<='
        if r == b'BRANCH_LT':  # pragma: no cover
            return '<'
        if r == b'BRANCH_GEQ':  # pragma: no cover
            return '>='
        if r == b'BRANCH_GT':  # pragma: no cover
            return '>'
        if r == b'BRANCH_EQ':  # pragma: no cover
            return '=='
        if r == b'BRANCH_NEQ':  # pragma: no cover
            return '!='
        # report the offending mode, not the function itself
        raise ValueError(  # pragma: no cover
            "Unexpected rule %r." % r)

    def _single(values):
        "Displays a one-element list as its only element."
        return values[0] if len(values) == 1 else values

    class Node:
        "Node representation."

        def __init__(self, i, atts):
            self.nodes_hitrates = None
            self.nodes_missing_value_tracks_true = None
            for k, v in atts.items():
                if k.startswith('nodes'):
                    setattr(self, k, v[i])
            self.depth = 0
            self.true_false = ''
            # leaf outputs, filled by process_tree, one entry per target
            # (a multi-target regressor has several entries per leaf)
            self.target_nodeids = None
            self.target_ids = []
            self.target_weights = []

        def process_node(self):
            "node to string"
            if self.nodes_modes == b'LEAF':  # pylint: disable=E1101
                text = "%s y=%r f=%r i=%r" % (
                    self.true_false,
                    _single(self.target_weights),
                    _single(self.target_ids),
                    self.target_nodeids)
            else:
                text = "%s X%d %s %r" % (
                    self.true_false,
                    self.nodes_featureids,  # pylint: disable=E1101
                    rule(self.nodes_modes),  # pylint: disable=E1101
                    self.nodes_values)  # pylint: disable=E1101
                if self.nodes_hitrates and self.nodes_hitrates != 1:
                    text += " hi=%r" % self.nodes_hitrates
                if self.nodes_missing_value_tracks_true:
                    text += " miss=%r" % (
                        self.nodes_missing_value_tracks_true)
            # NOTE(review): the reviewed copy of this file had repeated
            # spaces collapsed; confirm the width of this indent literal.
            return "%s%s" % (" " * self.depth, text)

    def process_tree(atts, treeid):
        "tree to string"
        rows = ['treeid=%r' % treeid]
        # NOTE(review): ONNX defines one base value per target, not per
        # tree; only display it when the tree id is a valid index so that
        # ensembles with more trees than targets do not fail.
        if 'base_values' in atts and treeid < len(atts['base_values']):
            rows.append('base_value=%r' % atts['base_values'][treeid])

        # keep only the attribute values belonging to this tree
        short = {}
        for prefix in ['nodes', 'target', 'class']:
            key = '%s_treeids' % prefix
            if key not in atts:
                continue
            idx = [i for i, t in enumerate(atts[key]) if t == treeid]
            for k, v in atts.items():
                if k.startswith(prefix):
                    short[k] = [v[i] for i in idx]

        nodes = OrderedDict()
        for i in range(len(short['nodes_treeids'])):
            nodes[i] = Node(i, short)
        for idn, tid, weight in zip(short['target_nodeids'],
                                    short['target_ids'],
                                    short['target_weights']):
            leaf = nodes[idn]
            leaf.target_nodeids = idn
            leaf.target_ids.append(tid)
            leaf.target_weights.append(weight)

        def iterate(nodes, current, depth=0, true_false=''):
            # depth-first walk, false branch before true branch
            current.depth = depth
            current.true_false = true_false
            yield current
            if current.nodes_falsenodeids > 0:
                yield from iterate(nodes, nodes[current.nodes_falsenodeids],
                                   depth=depth + 1, true_false='F')
                yield from iterate(nodes, nodes[current.nodes_truenodeids],
                                   depth=depth + 1, true_false='T')

        rows.extend(n.process_node() for n in iterate(nodes, nodes[0]))
        return rows

    if node.op_type != "TreeEnsembleRegressor":
        raise NotImplementedError(  # pragma: no cover
            "Type %r cannot be displayed." % node.op_type)
    d = {k: v['value'] for k, v in _var_as_dict(node)['atts'].items()}
    atts = {k: v if isinstance(v, int) else list(v) for k, v in d.items()}
    trees = sorted(set(atts['nodes_treeids']))
    rows = ['n_targets=%r' % atts['n_targets'],
            'n_trees=%r' % len(trees)]
    for tree in trees:
        rows.append('----')
        rows.extend(process_tree(atts, tree))

    return "\n".join(rows)
def reorder_nodes_for_display(nodes, verbose=False):
    """
    Reorders the nodes with a breadth first search (BFS).

    At every step, each remaining node whose inputs are all known is a
    candidate. From each candidate, a chain is built by following a result
    with a unique consumer (whose inputs are all known) or a node with a
    unique output. The longest chain is appended. Ties are broken by
    preferring the chain whose first node consumes more outputs of the
    last appended node, or, before anything was appended, the smaller
    op_type then the smaller name.

    :param nodes: list of ONNX nodes
    :param verbose: display intermediate information
    :return: reordered list of nodes
    :raises RuntimeError: if the nodes cannot all be ordered
    """
    def _node_key(node):
        # key combining the node name and its outputs
        return node.name + "#" + "|".join(node.output)

    all_outputs = set()
    all_inputs = set()
    for node in nodes:
        all_outputs |= set(node.output)
        all_inputs |= set(node.input)
    # results both produced and consumed by these nodes
    common = all_outputs & all_inputs

    # bipartite graph: node key -> its outputs, result name -> its consumers
    dnodes = OrderedDict()
    successors = {}
    predecessors = {}
    for node in nodes:
        node_name = _node_key(node)
        dnodes[node_name] = node
        successors[node_name] = set()
        predecessors[node_name] = set()
        for name in node.input:
            predecessors[node_name].add(name)
            if name not in successors:
                successors[name] = set()
            successors[name].add(node_name)
        for name in node.output:
            successors[node_name].add(name)
            predecessors[name] = {node_name}

    # results consumed but never produced by these nodes
    known = all_inputs - common
    new_nodes = []
    done = set()

    def _find_sequence(node_name, known, done):
        # the first node must be computable from known results
        if any(i not in known for i in dnodes[node_name].input):
            return []

        res = [node_name]
        visited = {node_name}
        while res[-1] in successors:
            last = res[-1]
            if last not in dnodes:
                # ``last`` is a result: follow its consumer if it is unique
                # and if all the consumer inputs are already known
                next_names = set(
                    v for v in successors[last] if v not in known)
                if len(next_names) != 1:
                    break
                next_name = next_names.pop()
                if any(i not in known for i in dnodes[next_name].input):
                    break
            else:
                # ``last`` is a node: follow its output if it is unique
                next_names = set(
                    v for v in successors[last] if v not in done)
                if len(next_names) != 1:
                    break
                next_name = next_names.pop()
            if next_name in visited:
                # cycle (malformed graph), stop instead of looping forever
                break
            visited.add(next_name)
            res.append(next_name)

        return [r for r in res if r in dnodes and r not in done]

    while len(done) < len(nodes):
        # nodes whose inputs are all known
        possibles = OrderedDict()
        for k, v in dnodes.items():
            if k in done:
                continue
            if predecessors[k] <= known:
                possibles[k] = v

        sequences = OrderedDict()
        for k in possibles:
            sequences[k] = _find_sequence(k, known, done)
            if verbose:
                print("[reorder_nodes_for_display] sequence(%s)=%s" % (
                    k, ",".join(sequences[k])))

        if len(sequences) == 0:
            raise RuntimeError(  # pragma: no cover
                "Unexpected empty sequences (len(possibles)=%d, "
                "len(done)=%d, len(nodes)=%d). This is usually due to "
                "a name used both as result name and node name."
                "" % (len(possibles), len(done), len(nodes)))

        # find the best sequence
        best = None
        for k, v in sequences.items():
            if best is None or len(v) > len(sequences[best]):
                # if the sequence of successors is longer
                best = k
            elif len(v) == len(sequences[best]):
                first1 = dnodes[sequences[best][0]]
                first2 = dnodes[v[0]]
                if len(new_nodes) > 0:
                    # then choose the next successor sharing input with
                    # previous output
                    so = set(new_nodes[-1].output)
                    if len(set(first1.input) & so) < len(set(first2.input) & so):
                        best = k
                elif first1.op_type > first2.op_type:
                    best = k
                elif (first1.op_type == first2.op_type and
                        first1.name > first2.name):
                    best = k

        if best is None:
            raise RuntimeError(  # pragma: no cover
                "Wrong implementation (len(sequence)=%d)." % len(sequences))
        if verbose:
            print("[reorder_nodes_for_display] BEST: sequence(%s)=%s" % (
                best, ",".join(sequences[best])))

        # process the sequence
        for k in sequences[best]:
            v = dnodes[k]
            new_nodes.append(v)
            done.add(k)
            known |= set(v.output)

    if len(new_nodes) != len(nodes):
        # use the same key as ``done`` so that the dump is meaningful
        raise RuntimeError(  # pragma: no cover
            "The returned new nodes are different. "
            "len(nodes)=%d != %d=len(new_nodes). done=\n%r"
            "\n%s\n----------\n%s" % (
                len(nodes), len(new_nodes), done,
                "\n".join("%d - %s - %s - %s" % (
                    _node_key(n) in done, n.op_type, n.name, _node_key(n))
                    for n in nodes),
                "\n".join("%d - %s - %s - %s" % (
                    _node_key(n) in done, n.op_type, n.name, _node_key(n))
                    for n in new_nodes)))
    return new_nodes
def _get_type(obj0):
    """
    Returns the element type of an ONNX tensor or value info.

    :param obj0: an object with a ``data_type`` attribute (such as an
        initializer) or with ``type.tensor_type.elem_type`` (such as a
        graph input or output)
    :return: the entry of ``TENSOR_TYPE_TO_NP_TYPE`` for that element
        type, ``'?'`` for a value info whose element type is not in
        that mapping
    :raises RuntimeError: when no type can be extracted
    """
    obj = obj0
    if hasattr(obj, 'data_type'):
        # any element type known by the mapping is supported, not only
        # FLOAT, DOUBLE, INT64 and INT32
        if obj.data_type in TENSOR_TYPE_TO_NP_TYPE:
            return TENSOR_TYPE_TO_NP_TYPE[obj.data_type]
        raise RuntimeError(  # pragma: no cover
            "Unable to guess type from %r." % obj0)
    if hasattr(obj, 'type'):
        obj = obj.type
    if hasattr(obj, 'tensor_type'):
        obj = obj.tensor_type
    if hasattr(obj, 'elem_type'):
        return TENSOR_TYPE_TO_NP_TYPE.get(obj.elem_type, '?')
    raise RuntimeError(  # pragma: no cover
        "Unable to guess type from %r." % obj0)
def _get_shape(obj):
    """
    Returns the shape of an ONNX tensor or value info.

    :param obj: an object with ``data_type`` and ``dims`` (such as an
        initializer) or with ``type.tensor_type.shape`` (such as a graph
        input or output)
    :return: tuple of dimensions
    :raises RuntimeError: when no shape can be extracted
    """
    obj0 = obj
    if hasattr(obj, 'data_type'):
        # ``dims`` holds the declared shape whatever the storage: the length
        # of ``float_data``/``int64_data``... is 0 when the content is stored
        # in ``raw_data`` and loses the rank in any case.
        if hasattr(obj, 'dims'):
            return tuple(obj.dims)
        raise RuntimeError(  # pragma: no cover
            "Unable to guess shape from %r." % obj0)
    if hasattr(obj, 'type'):
        obj = obj.type
    if hasattr(obj, 'tensor_type'):
        obj = obj.tensor_type
    if hasattr(obj, 'shape'):
        # NOTE(review): protobuf dimensions always expose ``dim_value``
        # (0 when unset), so a symbolic dimension shows up as 0, not None.
        return tuple(d.dim_value if hasattr(d, 'dim_value') else None
                     for d in obj.shape.dim)
    raise RuntimeError(  # pragma: no cover
        "Unable to guess shape from %r." % obj0)
def onnx_simple_text_plot(model, verbose=False, att_display=None,
                          add_links=False, recursive=False, functions=True):
    """
    Displays an ONNX graph into text.

    :param model: ONNX graph
    :param verbose: display debugging information
    :param att_display: list of attributes to display, if None,
        a default list is used
    :param add_links: displays links on the right side
    :param recursive: display subgraphs as well
    :param functions: display functions as well
    :return: str

    An ONNX graph is printed the following way:

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        import numpy
        from sklearn.cluster import KMeans
        from mlprodict.plotting.plotting import onnx_simple_text_plot
        from mlprodict.onnx_conv import to_onnx

        x = numpy.random.randn(10, 3)
        y = numpy.random.randn(10)
        model = KMeans(3)
        model.fit(x, y)
        onx = to_onnx(model, x.astype(numpy.float32),
                      target_opset=15)
        text = onnx_simple_text_plot(onx, verbose=False)
        print(text)

    The same graphs with links.

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        import numpy
        from sklearn.cluster import KMeans
        from mlprodict.plotting.plotting import onnx_simple_text_plot
        from mlprodict.onnx_conv import to_onnx

        x = numpy.random.randn(10, 3)
        y = numpy.random.randn(10)
        model = KMeans(3)
        model.fit(x, y)
        onx = to_onnx(model, x.astype(numpy.float32),
                      target_opset=15)
        text = onnx_simple_text_plot(onx, verbose=False, add_links=True)
        print(text)

    Visually, it looks like the following:

    .. gdot::
        :script: DOT-SECTION

        import numpy
        from sklearn.cluster import KMeans
        from mlprodict.onnxrt import OnnxInference
        from mlprodict.onnx_conv import to_onnx

        x = numpy.random.randn(10, 3)
        y = numpy.random.randn(10)
        model = KMeans(3)
        model.fit(x, y)
        model_onnx = to_onnx(model, x.astype(numpy.float32),
                             target_opset=15)
        oinf = OnnxInference(model_onnx, inplace=False)

        print("DOT-SECTION", oinf.to_dot())
    """
    if att_display is None:
        att_display = [
            'activations', 'align_corners', 'allowzero', 'alpha',
            'auto_pad', 'axis', 'axes', 'batch_axis', 'batch_dims',
            'beta', 'bias', 'blocksize', 'case_change_action',
            'ceil_mode', 'center_point_box', 'clip',
            'coordinate_transformation_mode', 'count_include_pad',
            'cubic_coeff_a', 'decay_factor', 'detect_negative',
            'detect_positive', 'dilation', 'dilations', 'direction',
            'dtype', 'end', 'epsilon', 'equation', 'exclusive',
            'exclude_outside', 'extrapolation_value', 'fmod', 'gamma',
            'group', 'hidden_size', 'high', 'ignore_index',
            'input_forget', 'is_case_sensitive', 'k', 'keepdims',
            'kernel_shape', 'lambd', 'largest', 'layout',
            'linear_before_reset', 'locale', 'low', 'max_gram_length',
            'max_skip_count', 'mean', 'min_gram_length', 'mode',
            'momentum', 'nearest_mode', 'ngram_counts', 'ngram_indexes',
            'noop_with_empty_axes', 'norm_coefficient',
            'norm_coefficient_post', 'num_scan_inputs', 'output_height',
            'output_padding', 'output_shape', 'output_width', 'p',
            'padding_mode', 'pads', 'perm', 'pooled_shape', 'reduction',
            'reverse', 'sample_size', 'sampling_ratio', 'scale',
            'scan_input_axes', 'scan_input_directions',
            'scan_output_axes', 'scan_output_directions', 'seed',
            'select_last_index', 'size', 'sorted', 'spatial_scale',
            'start', 'storage_order', 'strides', 'time_axis', 'to',
            'training_mode', 'transA', 'transB', 'type', 'upper', 'xs',
            'y', 'zs',
        ]

    def str_node(indent, node):
        "Returns the text line describing one node."
        atts = []
        for att in getattr(node, 'attribute', []):
            if att.name not in att_display:
                continue
            if att.type == AttributeProto.INT:  # pylint: disable=E1101
                atts.append("%s=%d" % (att.name, att.i))
            elif att.type == AttributeProto.FLOAT:  # pylint: disable=E1101
                atts.append("%s=%1.2f" % (att.name, att.f))
            elif att.type == AttributeProto.INTS:  # pylint: disable=E1101
                atts.append("%s=%s" % (att.name, str(
                    list(att.ints)).replace(" ", "")))
            elif att.type == AttributeProto.STRING:  # pylint: disable=E1101
                # several whitelisted attributes (mode, auto_pad, equation,
                # direction, ...) are strings and were silently dropped
                atts.append("%s=%r" % (
                    att.name, att.s.decode('utf-8', errors='replace')))
            elif att.type == AttributeProto.STRINGS:  # pylint: disable=E1101
                atts.append("%s=[%s]" % (att.name, ",".join(
                    repr(s.decode('utf-8', errors='replace'))
                    for s in att.strings)))
        inputs = list(node.input) + atts
        if node.domain in ('', 'ai.onnx.ml'):
            domain = ''
        else:
            domain = '[%s]' % node.domain
        # NOTE(review): the reviewed copy of this file had repeated spaces
        # collapsed; confirm the width of this indentation literal.
        return "%s%s%s(%s) -> %s" % (
            " " * indent, node.op_type, domain,
            ", ".join(inputs), ", ".join(node.output))

    rows = []
    if hasattr(model, 'opset_import'):
        for opset in model.opset_import:
            rows.append("opset: domain=%r version=%r" % (
                opset.domain, opset.version))
    if hasattr(model, 'graph'):
        main_model = model
        model = model.graph
    else:
        main_model = None

    # inputs: line_name_new maps a result to the row creating it,
    # line_name_in maps a result to the rows consuming it
    line_name_new = {}
    line_name_in = {}
    for inp in model.input:
        if isinstance(inp, str):
            # function inputs are plain names, register them too so that
            # add_links can draw links starting from them
            line_name_new[inp] = len(rows)
            rows.append("input: %r" % inp)
        else:
            line_name_new[inp.name] = len(rows)
            rows.append("input: name=%r type=%r shape=%r" % (
                inp.name, _get_type(inp), _get_shape(inp)))
    # initializer, small ones are printed with their content
    if hasattr(model, 'initializer'):
        for init in model.initializer:
            if numpy.prod(_get_shape(init)) < 5:
                content = " -- %r" % to_array(init).ravel()
            else:
                content = ""
            line_name_new[init.name] = len(rows)
            rows.append("init: name=%r type=%r shape=%r%s" % (
                init.name, _get_type(init), _get_shape(init), content))

    # subgraphs held by control flow operators
    subgraphs = []
    if recursive:
        for node in model.node:
            if node.op_type not in {'If', 'Scan', 'Loop'}:
                continue
            for att in node.attribute:
                if att.name in {'body', 'else_branch', 'then_branch'}:
                    subgraphs.append((node, att.name, att.g))

    # walk through nodes, graph inputs and initializers start at indent 0
    init_names = set()
    indents = {}
    for inp in model.input:
        name = inp if isinstance(inp, str) else inp.name
        indents[name] = 0
        init_names.add(name)
    if hasattr(model, 'initializer'):
        for init in model.initializer:
            indents[init.name] = 0
            init_names.add(init.name)

    nodes = reorder_nodes_for_display(model.node, verbose=verbose)

    previous_indent = None
    previous_out = None
    previous_in = None
    for node in nodes:
        add_break = False
        name = node.name + "#" + "|".join(node.output)
        if name in indents:
            indent = indents[name]
            if previous_indent is not None and indent < previous_indent:
                if verbose:
                    print("[onnx_simple_text_plot] break1 %s" % node.op_type)
                add_break = True
        elif previous_in is not None and set(node.input) == previous_in:
            # same inputs as the previous node: sibling, same indentation
            indent = previous_indent
        else:
            inds = [indents.get(i, 0)
                    for i in node.input if i not in init_names]
            if len(inds) == 0:
                indent = 0
            else:
                indent = min(inds)
                if previous_indent is not None and indent < previous_indent:
                    if verbose:
                        print(  # pragma: no cover
                            "[onnx_simple_text_plot] break2 %s" %
                            node.op_type)
                    add_break = True
            if not add_break and previous_out is not None:
                if len(set(node.input) & previous_out) == 0:
                    if verbose:
                        print("[onnx_simple_text_plot] break3 %s" %
                              node.op_type)
                    add_break = True
                    indent = 0

        if add_break and verbose:
            print("[onnx_simple_text_plot] add break")
        for n in node.input:
            line_name_in.setdefault(n, []).append(len(rows))
        for n in node.output:
            line_name_new[n] = len(rows)
        rows.append(str_node(indent, node))
        indents[name] = indent

        # results produced by this node are displayed one level deeper
        for o in node.output:
            indents[o] = indent + 1

        previous_indent = indents[name]
        previous_out = set(node.output)
        previous_in = set(node.input)

    # outputs
    for out in model.output:
        if isinstance(out, str):
            line_name_in.setdefault(out, []).append(len(rows))
            rows.append("output: name=%r type=%s shape=%s" % (
                out, '?', '?'))
        else:
            line_name_in.setdefault(out.name, []).append(len(rows))
            rows.append("output: name=%r type=%r shape=%r" % (
                out.name, _get_type(out), _get_shape(out)))

    if add_links:

        def _mark_link(rows, lengths, r1, r2, d):
            # draws a vertical link on the right from row r1 to row r2
            maxl = max(lengths[r1], lengths[r2]) + d * 2
            maxl = max(maxl, max(len(rows[r]) for r in range(r1, r2 + 1))) + 2

            if rows[r1][-1] == '|':
                p1, p2 = rows[r1][:lengths[r1] + 2], rows[r1][lengths[r1] + 2:]
                rows[r1] = p1 + p2.replace(' ', '-')
            rows[r1] += ("-" * (maxl - len(rows[r1]) - 1)) + "+"

            if rows[r2][-1] == " ":
                rows[r2] += "<"
            elif rows[r2][-1] == '|':
                if "<" not in rows[r2]:
                    p = lengths[r2]
                    rows[r2] = rows[r2][:p] + '<' + rows[r2][p + 1:]
                p1, p2 = rows[r2][:lengths[r2] + 2], rows[r2][lengths[r2] + 2:]
                rows[r2] = p1 + p2.replace(' ', '-')
            rows[r2] += ("-" * (maxl - len(rows[r2]) - 1)) + "+"

            for r in range(r1 + 1, r2):
                if len(rows[r]) < maxl:
                    rows[r] += " " * (maxl - len(rows[r]) - 1)
                rows[r] += "|"

        # shortest links are drawn first
        diffs = []
        for n, r1 in line_name_new.items():
            if n not in line_name_in:
                continue
            for r2 in line_name_in[n]:
                if r1 >= r2:
                    continue
                diffs.append((r2 - r1, (n, r1, r2)))
        diffs.sort()
        for i in range(len(rows)):  # pylint: disable=C0200
            rows[i] += " "
        lengths = [len(r) for r in rows]

        for d, (n, r1, r2) in diffs:
            if d == 1 and len(line_name_in[n]) == 1:
                # no line for link to the next node
                continue
            _mark_link(rows, lengths, r1, r2, d)

    # subgraphs
    for node, name, g in subgraphs:
        rows.append('----- subgraph ---- %s - %s - att.%s=' % (
            node.op_type, node.name, name))
        res = onnx_simple_text_plot(
            g, verbose=verbose, att_display=att_display,
            add_links=add_links, recursive=recursive)
        rows.append(res)

    # functions
    if functions and main_model is not None:
        for fct in main_model.functions:
            rows.append('----- function name=%s domain=%s' % (
                fct.name, fct.domain))
            res = onnx_simple_text_plot(
                fct, verbose=verbose, att_display=att_display,
                add_links=add_links, recursive=recursive,
                functions=False)
            rows.append(res)

    return "\n".join(rows)
def onnx_text_plot_io(model, verbose=False, att_display=None):
    """
    Displays information about input and output types.

    :param model: ONNX graph (a model, a graph or a function)
    :param verbose: display debugging information (unused)
    :param att_display: unused, kept for signature compatibility with
        @see fn onnx_simple_text_plot
    :return: str

    An ONNX graph is printed the following way:

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        import numpy
        from sklearn.cluster import KMeans
        from mlprodict.plotting.plotting import onnx_text_plot_io
        from mlprodict.onnx_conv import to_onnx

        x = numpy.random.randn(10, 3)
        y = numpy.random.randn(10)
        model = KMeans(3)
        model.fit(x, y)
        onx = to_onnx(model, x.astype(numpy.float32),
                      target_opset=15)
        text = onnx_text_plot_io(onx, verbose=False)
        print(text)
    """
    rows = []
    if hasattr(model, 'opset_import'):
        for opset in model.opset_import:
            rows.append("opset: domain=%r version=%r" % (
                opset.domain, opset.version))
    if hasattr(model, 'graph'):
        model = model.graph

    # inputs, a function only gives their names
    for inp in model.input:
        if isinstance(inp, str):
            rows.append("input: name=%r type=%s shape=%s" % (
                inp, '?', '?'))
        else:
            rows.append("input: name=%r type=%r shape=%r" % (
                inp.name, _get_type(inp), _get_shape(inp)))
    # initializer, a function has none
    for init in getattr(model, 'initializer', []):
        rows.append("init: name=%r type=%r shape=%r" % (
            init.name, _get_type(init), _get_shape(init)))
    # outputs, a function only gives their names
    for out in model.output:
        if isinstance(out, str):
            rows.append("output: name=%r type=%s shape=%s" % (
                out, '?', '?'))
        else:
            rows.append("output: name=%r type=%r shape=%r" % (
                out.name, _get_type(out), _get_shape(out)))
    return "\n".join(rows)