Coverage for mlprodict/tools/graphs.py: 98%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Alternative to dot to display a graph.
5.. versionadded:: 0.7
6"""
7import pprint
8import hashlib
9import numpy
10import onnx
def make_hash_bytes(data, length=20):
    """
    Returns the first *length* hexadecimal characters of the
    SHA-256 digest of *data*.

    :param data: bytes to hash
    :param length: number of hexadecimal characters to keep
    :return: str
    """
    digest = hashlib.sha256(data).hexdigest()
    return digest[:length]
class AdjacencyGraphDisplay:
    """
    Structure which contains the necessary information to
    display a graph using an adjacency matrix.

    .. versionadded:: 0.7
    """

    class Action:
        "One action to do."

        def __init__(self, x, y, kind, label, orientation=None):
            self.x = x
            self.y = y
            self.kind = kind
            self.label = label
            self.orientation = orientation

        def __repr__(self):
            "usual"
            return "%s(%r, %r, %r, %r, %r)" % (
                self.__class__.__name__,
                self.x, self.y, self.kind, self.label,
                self.orientation)

    def __init__(self):
        self.actions = []

    def __iter__(self):
        "Iterates over actions."
        yield from self.actions

    def __str__(self):
        "usual"
        rows = ["%s(" % self.__class__.__name__]
        for act in self:
            rows.append("    %r" % act)
        rows.append(")")
        return "\n".join(rows)

    def add(self, x, y, kind, label, orientation=None):
        """
        Adds an action to display the graph.

        :param x: x coordinate
        :param y: y coordinate
        :param kind: `'cross'` or `'text'`
        :param label: specific to kind, must be a string; for `'cross'`
            it must start with `'I'` or `'O'`
        :param orientation: a 2-uple `(i,j)` where *i* or *j* in `{-1,0,1}`
        :raises ValueError: unexpected *kind* or invalid 'cross' label
        :raises TypeError: *label* is not a string
        """
        if kind not in {'cross', 'text'}:
            raise ValueError(  # pragma: no cover
                "Unexpected value for kind %r." % kind)
        # The type is checked first so that the 'cross' check below
        # can safely look at the first character (even of '').
        if not isinstance(label, str):
            raise TypeError(  # pragma: no cover
                "Unexpected label type %r." % type(label))
        if kind == 'cross' and label[:1] not in {'I', 'O'}:
            raise ValueError(  # pragma: no cover
                "kind=='cross' and label[0]=%r not in {'I','O'}." % label[:1])
        self.actions.append(
            AdjacencyGraphDisplay.Action(x, y, kind, label=label,
                                         orientation=orientation))

    def to_text(self):
        """
        Displays the graph as a single string.
        See @see fn onnx2bigraph to see how the result
        looks like. Every x coordinate is multiplied by 3
        to leave room for two-character cross labels.

        :return: str, empty if there is nothing to display
        """
        mat = {}
        for act in self:
            if act.kind == 'cross':
                if act.orientation != (1, 0):
                    raise NotImplementedError(  # pragma: no cover
                        "Orientation for 'cross' must be (1, 0) not %r."
                        "" % act.orientation)
                if len(act.label) == 1:
                    mat[act.x * 3, act.y] = act.label
                elif len(act.label) == 2:
                    mat[act.x * 3, act.y] = act.label[0]
                    mat[act.x * 3 + 1, act.y] = act.label[1]
                else:
                    raise NotImplementedError(
                        "Unable to display long cross label (%r)."
                        "" % act.label)
            elif act.kind == 'text':
                x = act.x * 3
                y = act.y
                orient = act.orientation
                # Negative orientations write the label backwards so
                # that it still reads in the natural order.
                charset = list(act.label if max(orient) == 1
                               else reversed(act.label))
                for c in charset:
                    mat[x, y] = c
                    x += orient[0]
                    y += orient[1]
            else:
                raise ValueError(  # pragma: no cover
                    "Unexpected kind value %r." % act.kind)

        if not mat:
            # No action (or only empty labels): nothing to draw.
            return ""

        # Shift every coordinate so that the smallest ones become 0.
        min_i = min(k[0] for k in mat)
        min_j = min(k[1] for k in mat)
        shifted = {(i - min_i, j - min_j): v for (i, j), v in mat.items()}

        max_x = max(k[0] for k in shifted)
        max_y = max(k[1] for k in shifted)

        canvas = numpy.full((max_y + 1, max_x + 1), ' ')
        for (i, j), v in shifted.items():
            canvas[j, i] = v
        return "\n".join(''.join(row) for row in canvas)
class BiGraph:
    """
    BiGraph representation: vertices are split into two disjoint
    sets *v0* and *v1*, every edge links a vertex of one set to a
    vertex of the other set.

    .. versionadded:: 0.7
    """

    class A:
        "Additional information for a vertex or an edge."

        def __init__(self, kind):
            self.kind = kind

        def __repr__(self):
            return "A(%r)" % self.kind

    class B:
        "Additional information for a vertex or an edge."

        def __init__(self, name, content, onnx_name):
            if not isinstance(content, str):
                raise TypeError(  # pragma: no cover
                    "content must be str not %r." % type(content))
            self.name = name
            self.content = content
            self.onnx_name = onnx_name

        def __repr__(self):
            return "B(%r, %r, %r)" % (self.name, self.content, self.onnx_name)

    def __init__(self, v0, v1, edges):
        """
        :param v0: first set of vertices (dictionary)
        :param v1: second set of vertices (dictionary)
        :param edges: edges (dictionary `{(a, b): info}`)
        :raises TypeError: if one argument is not a dictionary
        :raises ValueError: if *v0* and *v1* share vertices or if an
            edge does not link one set to the other

        An edge `(a, b)` where *b* is in *v1* and *a* is unknown
        (an operator input is missing) is accepted: *a* is added
        to *v0* (in place) with kind `'ERROR'`.
        """
        if not isinstance(v0, dict):
            raise TypeError("v0 must be a dictionary.")
        if not isinstance(v1, dict):
            raise TypeError("v1 must be a dictionary.")
        if not isinstance(edges, dict):
            raise TypeError("edges must be a dictionary.")
        self.v0 = v0
        self.v1 = v1
        self.edges = edges
        common = set(self.v0).intersection(set(self.v1))
        if common:
            raise ValueError(
                "Sets v0 and v1 have common nodes (forbidden): %r." % common)
        for a, b in edges:
            if (a in v0 and b in v1) or (a in v1 and b in v0):
                continue
            if b in v1 and a not in v1:
                # One operator is missing one input. We add one.
                # The test on *a* prevents an edge between two vertices
                # of v1 from putting *a* in both sets.
                self.v0[a] = BiGraph.A('ERROR')
                continue
            raise ValueError(
                "Edges (%r, %r) not found among the vertices." % (a, b))

    def __str__(self):
        """
        usual
        """
        return "%s(%d v., %d v., %d edges)" % (
            self.__class__.__name__, len(self.v0),
            len(self.v1), len(self.edges))

    def __iter__(self):
        """
        Iterates over all vertices and edges.
        It produces 3-uples:

        * 0, name, A: vertices in *v0*
        * 1, name, A: vertices in *v1*
        * -1, name, A: edges
        """
        for k, v in self.v0.items():
            yield 0, k, v
        for k, v in self.v1.items():
            yield 1, k, v
        for k, v in self.edges.items():
            yield -1, k, v

    def __getitem__(self, key):
        """
        Returns a vertex if key is a string or an edge
        if it is a tuple.

        :param key: vertex or edge name
        :return: value
        """
        if isinstance(key, tuple):
            return self.edges[key]
        if key in self.v0:
            return self.v0[key]
        return self.v1[key]

    def order_vertices(self):
        """
        Orders the vertices from the input to the output.
        Every vertex gets the length of the longest path
        reaching it.

        :return: dictionary `{vertex name: order}`
        :raises RuntimeError: if the graph has a cycle
        """
        order = {}
        for v in self.v0:
            order[v] = 0
        for v in self.v1:
            order[v] = 0
        modif = 1
        n_iter = 0
        while modif > 0:
            modif = 0
            for a, b in self.edges:
                if order[b] <= order[a]:
                    order[b] = order[a] + 1
                    modif += 1
            n_iter += 1
            # An acyclic graph converges in fewer iterations than
            # the number of vertices.
            if n_iter > len(order):
                break
        if modif > 0:
            raise RuntimeError(
                "The graph has a cycle.\n%s" % pprint.pformat(
                    self.edges))
        return order

    def adjacency_matrix(self):
        """
        Builds an adjacency matrix.

        :return: matrix, list of row vertices, list of column vertices
        """
        order = self.order_vertices()
        ord_v0 = sorted((v, k) for k, v in order.items() if k in self.v0)
        ord_v1 = sorted((v, k) for k, v in order.items() if k in self.v1)
        row = [b for a, b in ord_v0]
        col = [b for a, b in ord_v1]
        row_id = {b: i for i, b in enumerate(row)}
        col_id = {b: i for i, b in enumerate(col)}
        matrix = numpy.zeros((len(row), len(col)), dtype=numpy.int32)
        for a, b in self.edges:
            if a in row_id:
                matrix[row_id[a], col_id[b]] = 1
            else:
                matrix[row_id[b], col_id[a]] = 1
        return matrix, row, col

    def display_structure(self, grid=5, distance=5):
        """
        Creates a display structure which contains
        all the necessary steps to display a graph.

        :param grid: align text to this grid
        :param distance: distance to the text
        :return: instance of @see cl AdjacencyGraphDisplay
        """
        def adjust(c, way):
            if way == 1:
                d = grid * ((c + distance * 2 - (grid // 2 + 1)) // grid)
            else:
                d = -grid * ((-c + distance * 2 - (grid // 2 + 1)) // grid)
            return d

        matrix, row, col = self.adjacency_matrix()
        row_id = {b: i for i, b in enumerate(row)}
        col_id = {b: i for i, b in enumerate(col)}

        interval_y_min = numpy.zeros((matrix.shape[0], ), dtype=numpy.int32)
        interval_y_max = numpy.zeros((matrix.shape[0], ), dtype=numpy.int32)
        interval_x_min = numpy.zeros((matrix.shape[1], ), dtype=numpy.int32)
        interval_x_max = numpy.zeros((matrix.shape[1], ), dtype=numpy.int32)
        interval_y_min[:] = max(matrix.shape)
        interval_x_min[:] = max(matrix.shape)

        graph = AdjacencyGraphDisplay()
        for key, value in self.edges.items():
            if key[0] in row_id:
                y = row_id[key[0]]
                x = col_id[key[1]]
            else:
                x = col_id[key[0]]
                y = row_id[key[1]]
            graph.add(x, y, 'cross', label=value.kind, orientation=(1, 0))
            if x < interval_y_min[y]:
                interval_y_min[y] = x
            if x > interval_y_max[y]:
                interval_y_max[y] = x
            if y < interval_x_min[x]:
                interval_x_min[x] = y
            if y > interval_x_max[x]:
                interval_x_max[x] = y

        for k, v in self.v0.items():
            y = row_id[k]
            x = adjust(interval_y_min[y], -1)
            graph.add(x, y, 'text', label=v.kind, orientation=(-1, 0))
            x = adjust(interval_y_max[y], 1)
            graph.add(x, y, 'text', label=k, orientation=(1, 0))

        for k, v in self.v1.items():
            x = col_id[k]
            y = adjust(interval_x_min[x], -1)
            graph.add(x, y, 'text', label=v.kind, orientation=(0, -1))
            y = adjust(interval_x_max[x], 1)
            graph.add(x, y, 'text', label=k, orientation=(0, 1))

        return graph

    def order(self):
        """
        Orders the vertices, depth first, starting from the vertices
        without any predecessor. Every successor of a visited vertex
        is explored.

        :return: dictionary `{vertex name: position}`, keys mix
            vertices of *v0* and *v1*
        :raises RuntimeError: if every vertex has a predecessor
        """
        forwards = {k: [] for k in list(self.v0) + list(self.v1)}
        has_parent = set()
        for a, b in self.edges:
            forwards[a].append(b)
            has_parent.add(b)

        # roots
        roots = [k for k in forwards if k not in has_parent]
        if not roots:
            raise RuntimeError(  # pragma: no cover
                "This graph has cycles. Not allowed.")

        # ordering: iterative depth first search, the last successor
        # is pushed last and is therefore explored first
        order = {}
        stack = roots
        while stack:
            node = stack.pop()
            if node in order:
                continue
            order[node] = len(order)
            stack.extend(n for n in forwards[node] if n not in order)
        return order

    def summarize(self):
        """
        Creates a text summary of the graph: one line per
        vertex of *v1*, sorted by @see me order.
        """
        order = self.order()
        ranked = sorted(order, key=order.get)
        return "\n".join(str(self.v1[k]) for k in ranked if k in self.v1)

    @staticmethod
    def _onnx2bigraph_basic(model_onnx, recursive=False):
        """
        Implements graph type `'basic'` for function
        @see fn onnx2bigraph.
        """

        if recursive:
            raise NotImplementedError(  # pragma: no cover
                "Option recursive=True is not implemented yet.")
        v0 = {}
        v1 = {}
        edges = {}

        # inputs
        for i, o in enumerate(model_onnx.graph.input):
            v0[o.name] = BiGraph.A('Input-%d' % i)
        for i, o in enumerate(model_onnx.graph.output):
            v0[o.name] = BiGraph.A('Output-%d' % i)
        for o in model_onnx.graph.initializer:
            v0[o.name] = BiGraph.A('Init')
        for n in model_onnx.graph.node:
            # Unnamed nodes get a name built from the Python object id.
            nname = n.name if len(n.name) > 0 else "id%d" % id(n)
            v1[nname] = BiGraph.A(n.op_type)
            for i, o in enumerate(n.input):
                c = str(i) if i < 10 else "+"
                edges[o, nname] = BiGraph.A('I%s' % c)
            for i, o in enumerate(n.output):
                c = str(i) if i < 10 else "+"
                if o not in v0:
                    v0[o] = BiGraph.A('inout')
                edges[nname, o] = BiGraph.A('O%s' % c)

        return BiGraph(v0, v1, edges)

    @staticmethod
    def _onnx2bigraph_simplified(model_onnx, recursive=False):
        """
        Implements graph type `'simplified'` for function
        @see fn onnx2bigraph. Names are replaced by keys
        (`I*`, `O*`, `C*`, `R*`, `N*`) and contents by hashes.
        Omitted optional inputs and outputs (empty names)
        are skipped.
        """
        if recursive:
            raise NotImplementedError(  # pragma: no cover
                "Option recursive=True is not implemented yet.")
        v0 = {}
        v1 = {}
        edges = {}

        # inputs
        for o in model_onnx.graph.input:
            v0["I%d" % len(v0)] = BiGraph.B(
                'In', make_hash_bytes(o.type.SerializeToString(), 2), o.name)
        for o in model_onnx.graph.output:
            v0["O%d" % len(v0)] = BiGraph.B(
                'Ou', make_hash_bytes(o.type.SerializeToString(), 2), o.name)
        for o in model_onnx.graph.initializer:
            v0["C%d" % len(v0)] = BiGraph.B(
                'Cs', make_hash_bytes(o.raw_data, 10), o.name)

        names_v0 = {v.onnx_name: k for k, v in v0.items()}

        for n in model_onnx.graph.node:
            key_node = "N%d" % len(v1)
            if len(n.attribute) > 0:
                ats = [at.SerializeToString() for at in n.attribute]
                ct = make_hash_bytes(b"".join(ats), 10)
            else:
                ct = ""
            v1[key_node] = BiGraph.B(
                n.op_type, ct, n.name)
            for o in n.input:
                if o == '':
                    # Omitted optional input, nothing flows in.
                    continue
                edges[names_v0[o], key_node] = BiGraph.A('I')
            for o in n.output:
                if o == '':
                    # Omitted optional output.
                    continue
                if o not in names_v0:
                    key = "R%d" % len(v0)
                    v0[key] = BiGraph.B('Re', n.op_type, o)
                    names_v0[o] = key
                # The output may already be known (a graph output for
                # example), its key must be looked up, not reused.
                edges[key_node, names_v0[o]] = BiGraph.A('O')

        return BiGraph(v0, v1, edges)

    @staticmethod
    def onnx_graph_distance(onx1, onx2, verbose=0, fLOG=print):
        """
        Computes a distance between two ONNX graphs. They must not
        be too big otherwise this function might take for ever.
        The function relies on package :epkg:`mlstatpy`.

        :param onx1: first graph (ONNX graph or model file name)
        :param onx2: second graph (ONNX graph or model file name)
        :param verbose: verbosity
        :param fLOG: logging function
        :return: distance and differences

        .. warning::

            This is very experimental and very slow.

        .. versionadded:: 0.7
        """
        from mlstatpy.graph.graph_distance import GraphDistance

        if isinstance(onx1, str):
            onx1 = onnx.load(onx1)
        if isinstance(onx2, str):
            onx2 = onnx.load(onx2)

        def make_hash(init):
            return make_hash_bytes(init.raw_data)

        def build_graph(onx):
            edges = []
            labels = {}
            for node in onx.graph.node:
                if len(node.name) == 0:
                    name = str(id(node))
                else:
                    name = node.name
                for i in node.input:
                    edges.append((i, name))
                for p, i in enumerate(node.output):
                    edges.append((name, i))
                    labels[i] = "%s:%d" % (node.op_type, p)
                labels[name] = node.op_type
            for init in onx.graph.initializer:
                labels[init.name] = make_hash(init)

            g = GraphDistance(edges, vertex_label=labels)
            return g

        g1 = build_graph(onx1)
        g2 = build_graph(onx2)

        dist, gdist = g1.distance_matching_graphs_paths(
            g2, verbose=verbose, fLOG=fLOG, use_min=False)
        return dist, gdist
def onnx2bigraph(model_onnx, recursive=False, graph_type='basic'):
    """
    Converts an ONNX graph into a graph representation made of
    edges and vertices.

    :param model_onnx: ONNX graph
    :param recursive: dig into subgraphs too
    :param graph_type: kind of graph it creates
    :return: see @cl BiGraph
    :raises ValueError: if *graph_type* is unknown

    About *graph_type*:

    * `'basic'`: basic graph structure, it returns an instance
      of type @see cl BiGraph. The structure keeps the original
      names.
    * `'simplified'`: simplified graph structure, names are removed
      as they could prevent the algorithm from finding any matching.

    .. exref::
        :title: Displays an ONNX graph as text

        The function uses an adjacency matrix of the graph.
        Results are displayed by rows, operators by columns.
        Result kinds are shown on the left,
        their names on the right. Operator types are displayed
        on the top, their names on the bottom.

        .. runpython::
            :showcode:

            import numpy
            from mlprodict.onnx_conv import to_onnx
            from mlprodict import __max_supported_opset__ as opv
            from mlprodict.tools.graphs import onnx2bigraph
            from mlprodict.npy.xop import loadop

            OnnxAdd, OnnxSub = loadop('Add', 'Sub')

            idi = numpy.identity(2).astype(numpy.float32)
            A = OnnxAdd('X', idi, op_version=opv)
            B = OnnxSub(A, 'W', output_names=['Y'], op_version=opv)
            onx = B.to_onnx({'X': idi, 'W': idi})
            bigraph = onnx2bigraph(onx)
            graph = bigraph.display_structure()
            text = graph.to_text()
            print(text)

    .. versionadded:: 0.7
    """
    if graph_type not in ('basic', 'simplified'):
        raise ValueError(
            "Unknown value for graph_type=%r." % graph_type)
    builder = (BiGraph._onnx2bigraph_basic if graph_type == 'basic'
               else BiGraph._onnx2bigraph_simplified)
    return builder(model_onnx, recursive=recursive)
def onnx_graph_distance(onx1, onx2, verbose=0, fLOG=print):
    """
    Computes a distance between two ONNX graphs by delegating to
    @see me BiGraph.onnx_graph_distance. Graphs must stay small,
    otherwise the computation may take for ever.
    The function relies on package :epkg:`mlstatpy`.

    :param onx1: first graph (ONNX graph or model file name)
    :param onx2: second graph (ONNX graph or model file name)
    :param verbose: verbosity
    :param fLOG: logging function
    :return: distance and differences

    .. warning::

        This is very experimental and very slow.

    .. versionadded:: 0.7
    """
    compute = BiGraph.onnx_graph_distance
    return compute(onx1, onx2, fLOG=fLOG, verbose=verbose)