Coverage for mlprodict/tools/graphs.py: 98%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

337 statements  

1""" 

2@file 

3@brief Alternative to dot to display a graph. 

4 

5.. versionadded:: 0.7 

6""" 

7import pprint 

8import hashlib 

9import numpy 

10import onnx 

11 

12 

def make_hash_bytes(data, length=20):
    """
    Hashes *data* with SHA-256 and keeps only the beginning
    of the hexadecimal digest.

    :param data: bytes to hash
    :param length: number of hexadecimal characters to keep
    :return: str, the first *length* characters of the digest
    """
    digest = hashlib.sha256(data).hexdigest()
    return digest[:length]

21 

22 

class AdjacencyGraphDisplay:
    """
    Structure which contains the necessary information to
    display a graph using an adjacency matrix.

    .. versionadded:: 0.7
    """

    class Action:
        "One action to do."

        def __init__(self, x, y, kind, label, orientation=None):
            self.x = x
            self.y = y
            self.kind = kind
            self.label = label
            self.orientation = orientation

        def __repr__(self):
            "usual"
            return "%s(%r, %r, %r, %r, %r)" % (
                self.__class__.__name__,
                self.x, self.y, self.kind, self.label,
                self.orientation)

    def __init__(self):
        # List of Action, in the order they were added.
        self.actions = []

    def __iter__(self):
        "Iterates over actions."
        yield from self.actions

    def __str__(self):
        "usual"
        rows = ["%s(" % self.__class__.__name__]
        for act in self:
            rows.append(" %r" % act)
        rows.append(")")
        return "\n".join(rows)

    def add(self, x, y, kind, label, orientation=None):
        """
        Adds an action to display the graph.

        :param x: x coordinate
        :param y: y coordinate
        :param kind: `'cross'` or `'text'`
        :param label: specific to kind, must be a string,
            for `'cross'` it must start with `'I'` or `'O'`
        :param orientation: a 2-uple `(i,j)` where *i* or *j* in `{-1,0,1}`
        :raises ValueError: unexpected *kind* or invalid `'cross'` label
        :raises TypeError: *label* is not a string
        """
        if kind not in {'cross', 'text'}:
            raise ValueError(  # pragma: no cover
                "Unexpected value for kind %r." % kind)
        # The type is checked first: the 'cross' check below slices
        # the label and formats it, both assume a string.
        if not isinstance(label, str):
            raise TypeError(  # pragma: no cover
                "Unexpected label type %r." % type(label))
        # label[:1] (not label[0]) so that an empty label raises
        # ValueError instead of IndexError.
        if kind == 'cross' and label[:1] not in {'I', 'O'}:
            raise ValueError(  # pragma: no cover
                "kind=='cross' and label[0]=%r not in {'I','O'}." % label[:1])
        self.actions.append(
            AdjacencyGraphDisplay.Action(x, y, kind, label=label,
                                         orientation=orientation))

    def to_text(self):
        """
        Displays the graph as a single string.
        See @see fn onnx2bigraph to see how the result
        looks like.

        Every column *x* of the adjacency matrix takes three characters.
        A `'cross'` label (one or two characters) is written at its
        position. A `'text'` label is written character by character,
        starting at its position, following its orientation.

        :return: str, empty if there is nothing to display
        """
        mat = {}
        for act in self:
            if act.kind == 'cross':
                if act.orientation != (1, 0):
                    raise NotImplementedError(  # pragma: no cover
                        "Orientation for 'cross' must be (1, 0) not %r."
                        "" % act.orientation)
                if len(act.label) == 1:
                    mat[act.x * 3, act.y] = act.label
                elif len(act.label) == 2:
                    mat[act.x * 3, act.y] = act.label[0]
                    mat[act.x * 3 + 1, act.y] = act.label[1]
                else:
                    raise NotImplementedError(
                        "Unable to display long cross label (%r)."
                        "" % act.label)
            elif act.kind == 'text':
                x = act.x * 3
                y = act.y
                orient = act.orientation
                # For a negative orientation the text is written backwards
                # so that it still reads left-to-right / top-to-bottom.
                charset = list(act.label if max(orient) == 1
                               else reversed(act.label))
                for c in charset:
                    mat[x, y] = c
                    x += orient[0]
                    y += orient[1]
            else:
                raise ValueError(  # pragma: no cover
                    "Unexpected kind value %r." % act.kind)

        if not mat:
            # Nothing to draw: min()/max() below would fail on an empty
            # sequence.
            return ""

        # Shifts coordinates so that the smallest ones become 0.
        min_i = min(k[0] for k in mat)
        min_j = min(k[1] for k in mat)
        mat2 = {}
        for k, v in mat.items():
            mat2[k[0] - min_i, k[1] - min_j] = v

        max_x = max(k[0] for k in mat2)
        max_y = max(k[1] for k in mat2)

        mat = numpy.full((max_y + 1, max_x + 1), ' ')
        for k, v in mat2.items():
            mat[k[1], k[0]] = v
        rows = []
        for i in range(mat.shape[0]):
            rows.append(''.join(mat[i]))
        return "\n".join(rows)

141 

142 

class BiGraph:
    """
    BiGraph representation.

    .. versionadded:: 0.7
    """

    class A:
        "Additional information for a vertex or an edge."

        def __init__(self, kind):
            self.kind = kind

        def __repr__(self):
            return "A(%r)" % self.kind

    class B:
        "Additional information for a vertex or an edge."

        def __init__(self, name, content, onnx_name):
            if not isinstance(content, str):
                raise TypeError(  # pragma: no cover
                    "content must be str not %r." % type(content))
            self.name = name
            self.content = content
            self.onnx_name = onnx_name

        def __repr__(self):
            return "B(%r, %r, %r)" % (self.name, self.content, self.onnx_name)

    def __init__(self, v0, v1, edges):
        """
        :param v0: first set of vertices (dictionary)
        :param v1: second set of vertices (dictionary)
        :param edges: edges, dictionary whose keys are pairs `(a, b)`,
            every edge must link a vertex of *v0* to a vertex of *v1*
            or the reverse
        :raises TypeError: if one argument is not a dictionary
        :raises ValueError: if *v0* and *v1* share vertices or if an
            edge cannot be attached to the vertices

        If an edge `(a, b)` ends in *v1* but *a* is unknown (an operator
        input nothing produces), *a* is added to *v0* with kind `'ERROR'`.
        Dictionary *v0* is modified in place in that case.
        """
        if not isinstance(v0, dict):
            raise TypeError("v0 must be a dictionary.")
        if not isinstance(v1, dict):
            raise TypeError("v1 must be a dictionary.")
        if not isinstance(edges, dict):
            raise TypeError("edges must be a dictionary.")
        self.v0 = v0
        self.v1 = v1
        self.edges = edges
        common = set(self.v0).intersection(set(self.v1))
        if len(common) > 0:
            raise ValueError(
                "Sets v0 and v1 have common nodes (forbidden): %r." % common)
        for a, b in edges:
            if a in v0 and b in v1:
                continue
            if a in v1 and b in v0:
                continue
            # `a not in v1`: adding a vertex of v1 to v0 would break
            # the disjointness checked above.
            if b in v1 and a not in v1:
                # One operator is missing one input.
                # We add one.
                self.v0[a] = BiGraph.A('ERROR')
                continue
            raise ValueError(
                "Edges (%r, %r) not found among the vertices." % (a, b))

    def __str__(self):
        """
        usual
        """
        return "%s(%d v., %d v., %d edges)" % (
            self.__class__.__name__, len(self.v0),
            len(self.v1), len(self.edges))

    def __iter__(self):
        """
        Iterates over all vertices and edges.
        It produces 3-uples:

        * 0, name, A: vertices in *v0*
        * 1, name, A: vertices in *v1*
        * -1, name, A: edges
        """
        for k, v in self.v0.items():
            yield 0, k, v
        for k, v in self.v1.items():
            yield 1, k, v
        for k, v in self.edges.items():
            yield -1, k, v

    def __getitem__(self, key):
        """
        Returns a vertex if key is a string or an edge
        if it is a tuple.

        :param key: vertex or edge name
        :return: value
        """
        if isinstance(key, tuple):
            return self.edges[key]
        if key in self.v0:
            return self.v0[key]
        return self.v1[key]

    def order_vertices(self):
        """
        Orders the vertices from the input to the output.
        Every vertex gets the length of the longest path leading to it.

        :return: dictionary `{vertex name: order}`
        :raises RuntimeError: if the graph has a cycle
        """
        order = {}
        for v in self.v0:
            order[v] = 0
        for v in self.v1:
            order[v] = 0
        modif = 1
        n_iter = 0
        while modif > 0:
            modif = 0
            for a, b in self.edges:
                if order[b] <= order[a]:
                    order[b] = order[a] + 1
                    modif += 1
            n_iter += 1
            # Longest paths are stable after len(order) iterations
            # unless there is a cycle.
            if n_iter > len(order):
                break
        if modif > 0:
            raise RuntimeError(
                "The graph has a cycle.\n%s" % pprint.pformat(
                    self.edges))
        return order

    def adjacency_matrix(self):
        """
        Builds an adjacency matrix.

        :return: matrix, list of row vertices, list of column vertices
        """
        order = self.order_vertices()
        ord_v0 = [(v, k) for k, v in order.items() if k in self.v0]
        ord_v1 = [(v, k) for k, v in order.items() if k in self.v1]
        ord_v0.sort()
        ord_v1.sort()
        row = [b for a, b in ord_v0]
        col = [b for a, b in ord_v1]
        row_id = {b: i for i, b in enumerate(row)}
        col_id = {b: i for i, b in enumerate(col)}
        matrix = numpy.zeros((len(row), len(col)), dtype=numpy.int32)
        for a, b in self.edges:
            if a in row_id:
                matrix[row_id[a], col_id[b]] = 1
            else:
                matrix[row_id[b], col_id[a]] = 1
        return matrix, row, col

    def display_structure(self, grid=5, distance=5):
        """
        Creates a display structure which contains
        all the necessary steps to display a graph.

        :param grid: align text to this grid
        :param distance: distance to the text
        :return: instance of @see cl AdjacencyGraphDisplay
        """
        def adjust(c, way):
            if way == 1:
                d = grid * ((c + distance * 2 - (grid // 2 + 1)) // grid)
            else:
                d = -grid * ((-c + distance * 2 - (grid // 2 + 1)) // grid)
            return d

        matrix, row, col = self.adjacency_matrix()
        row_id = {b: i for i, b in enumerate(row)}
        col_id = {b: i for i, b in enumerate(col)}

        interval_y_min = numpy.zeros((matrix.shape[0], ), dtype=numpy.int32)
        interval_y_max = numpy.zeros((matrix.shape[0], ), dtype=numpy.int32)
        interval_x_min = numpy.zeros((matrix.shape[1], ), dtype=numpy.int32)
        interval_x_max = numpy.zeros((matrix.shape[1], ), dtype=numpy.int32)
        interval_y_min[:] = max(matrix.shape)
        interval_x_min[:] = max(matrix.shape)

        graph = AdjacencyGraphDisplay()
        for key, value in self.edges.items():
            if key[0] in row_id:
                y = row_id[key[0]]
                x = col_id[key[1]]
            else:
                x = col_id[key[0]]
                y = row_id[key[1]]
            graph.add(x, y, 'cross', label=value.kind, orientation=(1, 0))
            if x < interval_y_min[y]:
                interval_y_min[y] = x
            if x > interval_y_max[y]:
                interval_y_max[y] = x
            if y < interval_x_min[x]:
                interval_x_min[x] = y
            if y > interval_x_max[x]:
                interval_x_max[x] = y

        for k, v in self.v0.items():
            y = row_id[k]
            x = adjust(interval_y_min[y], -1)
            graph.add(x, y, 'text', label=v.kind, orientation=(-1, 0))
            x = adjust(interval_y_max[y], 1)
            graph.add(x, y, 'text', label=k, orientation=(1, 0))

        for k, v in self.v1.items():
            x = col_id[k]
            y = adjust(interval_x_min[x], -1)
            graph.add(x, y, 'text', label=v.kind, orientation=(0, -1))
            y = adjust(interval_x_max[x], 1)
            graph.add(x, y, 'text', label=k, orientation=(0, 1))

        return graph

    def order(self):
        """
        Orders nodes, depth first.
        Returns a dictionary `{key: rank}` over the keys of *v0* and *v1*
        reachable from a root (a vertex without any incoming edge).
        Ranks are `0, 1, ..., n-1` in the order vertices are first
        visited; every vertex gets exactly one rank.

        :raises RuntimeError: if there is no root (the graph has cycles)
        """
        # Forward and backward adjacency lists, v0 keys first then v1 keys.
        forwards = {k: [] for k in self.v0}
        forwards.update({k: [] for k in self.v1})
        backwards = {k: [] for k in forwards}
        # Edge keys are unique in a dictionary, a single pass is enough.
        for a, b in self.edges:
            forwards[a].append(b)
            backwards[b].append(a)

        # roots
        roots = [b for b, backs in backwards.items() if len(backs) == 0]
        if len(roots) == 0:
            raise RuntimeError(  # pragma: no cover
                "This graph has cycles. Not allowed.")

        # ordering
        order = {}
        stack = list(roots)
        while stack:
            node = stack.pop()
            if node not in order:
                order[node] = len(order)
            children = forwards[node]
            if not children:
                continue
            # The node goes back on the stack so that its remaining
            # children are explored after this one, every edge is
            # consumed exactly once which guarantees termination.
            stack.append(node)
            stack.append(children.pop())

        return order

    def summarize(self):
        """
        Creates a text summary of the graph: one line per vertex
        of *v1*, sorted by the rank returned by method *order*.
        """
        order = self.order()
        keys = [(o, k) for k, o in order.items()]
        keys.sort()

        rows = []
        for _, k in keys:
            if k in self.v1:
                rows.append(str(self.v1[k]))
        return "\n".join(rows)

    @staticmethod
    def _onnx2bigraph_basic(model_onnx, recursive=False):
        """
        Implements graph type `'basic'` for function
        @see fn onnx2bigraph.
        """

        if recursive:
            raise NotImplementedError(  # pragma: no cover
                "Option recursive=True is not implemented yet.")
        v0 = {}
        v1 = {}
        edges = {}

        # inputs
        for i, o in enumerate(model_onnx.graph.input):
            v0[o.name] = BiGraph.A('Input-%d' % i)
        for i, o in enumerate(model_onnx.graph.output):
            v0[o.name] = BiGraph.A('Output-%d' % i)
        for o in model_onnx.graph.initializer:
            v0[o.name] = BiGraph.A('Init')
        for index, n in enumerate(model_onnx.graph.node):
            # Unnamed nodes get a name built from their position.
            # id(n) is only unique among objects alive at the same time
            # and the wrapper returned while iterating may not outlive
            # its iteration, so two unnamed nodes could share an id.
            nname = n.name if len(n.name) > 0 else "id%d" % index
            v1[nname] = BiGraph.A(n.op_type)
            for i, o in enumerate(n.input):
                c = str(i) if i < 10 else "+"
                edges[o, nname] = BiGraph.A('I%s' % c)
            for i, o in enumerate(n.output):
                c = str(i) if i < 10 else "+"
                if o not in v0:
                    v0[o] = BiGraph.A('inout')
                edges[nname, o] = BiGraph.A('O%s' % c)

        return BiGraph(v0, v1, edges)

    @staticmethod
    def _onnx2bigraph_simplified(model_onnx, recursive=False):
        """
        Implements graph type `'simplified'` for function
        @see fn onnx2bigraph.
        """
        if recursive:
            raise NotImplementedError(  # pragma: no cover
                "Option recursive=True is not implemented yet.")
        v0 = {}
        v1 = {}
        edges = {}

        # inputs
        for o in model_onnx.graph.input:
            v0["I%d" % len(v0)] = BiGraph.B(
                'In', make_hash_bytes(o.type.SerializeToString(), 2), o.name)
        for o in model_onnx.graph.output:
            v0["O%d" % len(v0)] = BiGraph.B(
                'Ou', make_hash_bytes(o.type.SerializeToString(), 2), o.name)
        for o in model_onnx.graph.initializer:
            v0["C%d" % len(v0)] = BiGraph.B(
                'Cs', make_hash_bytes(o.raw_data, 10), o.name)

        names_v0 = {v.onnx_name: k for k, v in v0.items()}

        for n in model_onnx.graph.node:
            key_node = "N%d" % len(v1)
            if len(n.attribute) > 0:
                ats = []
                for at in n.attribute:
                    ats.append(at.SerializeToString())
                ct = make_hash_bytes(b"".join(ats), 10)
            else:
                ct = ""
            v1[key_node] = BiGraph.B(
                n.op_type, ct, n.name)
            for o in n.input:
                if o == "":
                    # Optional input left unset: there is no tensor to
                    # link (names_v0[""] would raise KeyError).
                    continue
                key_in = names_v0[o]
                edges[key_in, key_node] = BiGraph.A('I')
            for o in n.output:
                if o not in names_v0:
                    key = "R%d" % len(v0)
                    v0[key] = BiGraph.B('Re', n.op_type, o)
                    names_v0[o] = key
                edges[key_node, key] = BiGraph.A('O')

        return BiGraph(v0, v1, edges)

    @staticmethod
    def onnx_graph_distance(onx1, onx2, verbose=0, fLOG=print):
        """
        Computes a distance between two ONNX graphs. They must not
        be too big otherwise this function might take for ever.
        The function relies on package :epkg:`mlstatpy`.

        :param onx1: first graph (ONNX graph or model file name)
        :param onx2: second graph (ONNX graph or model file name)
        :param verbose: verbosity
        :param fLOG: logging function
        :return: distance and differences

        .. warning::

            This is very experimental and very slow.

        .. versionadded:: 0.7
        """
        from mlstatpy.graph.graph_distance import GraphDistance

        if isinstance(onx1, str):
            onx1 = onnx.load(onx1)
        if isinstance(onx2, str):
            onx2 = onnx.load(onx2)

        def make_hash(init):
            return make_hash_bytes(init.raw_data)

        def build_graph(onx):
            edges = []
            labels = {}
            for index, node in enumerate(onx.graph.node):
                # Position-based name for unnamed nodes, id(node) may be
                # reused once a temporary wrapper is released.
                if len(node.name) == 0:
                    name = "id%d" % index
                else:
                    name = node.name
                for i in node.input:
                    edges.append((i, name))
                for p, i in enumerate(node.output):
                    edges.append((name, i))
                    labels[i] = "%s:%d" % (node.op_type, p)
                labels[name] = node.op_type
            for init in onx.graph.initializer:
                labels[init.name] = make_hash(init)

            g = GraphDistance(edges, vertex_label=labels)
            return g

        g1 = build_graph(onx1)
        g2 = build_graph(onx2)

        dist, gdist = g1.distance_matching_graphs_paths(
            g2, verbose=verbose, fLOG=fLOG, use_min=False)
        return dist, gdist

556 

557 

def onnx2bigraph(model_onnx, recursive=False, graph_type='basic'):
    """
    Converts an ONNX graph into a graph representation made of
    edges and vertices.

    :param model_onnx: ONNX graph
    :param recursive: dig into subgraphs too
    :param graph_type: kind of graph it creates, `'basic'`
        or `'simplified'`
    :return: see @cl BiGraph
    :raises ValueError: if *graph_type* is not supported

    About *graph_type*:

    * `'basic'`: basic graph structure, it returns an instance
      of type @see cl BiGraph. The structure keeps the original
      names.
    * `'simplified'`: simplified graph structure, names are removed
      as they could prevent the algorithm from finding any matching.

    .. exref::
        :title: Displays an ONNX graph as text

        The function uses an adjacency matrix of the graph.
        Results are displayed by rows, operators by columns.
        Result kinds are shown on the left,
        their names on the right. Operator types are displayed
        on the top, their names on the bottom.

        .. runpython::
            :showcode:

            import numpy
            from mlprodict.onnx_conv import to_onnx
            from mlprodict import __max_supported_opset__ as opv
            from mlprodict.tools.graphs import onnx2bigraph
            from mlprodict.npy.xop import loadop

            OnnxAdd, OnnxSub = loadop('Add', 'Sub')

            idi = numpy.identity(2).astype(numpy.float32)
            A = OnnxAdd('X', idi, op_version=opv)
            B = OnnxSub(A, 'W', output_names=['Y'], op_version=opv)
            onx = B.to_onnx({'X': idi, 'W': idi})
            bigraph = onnx2bigraph(onx)
            graph = bigraph.display_structure()
            text = graph.to_text()
            print(text)

    .. versionadded:: 0.7
    """
    # Unsupported values are rejected before any dispatch.
    if graph_type not in ('basic', 'simplified'):
        raise ValueError(
            "Unknown value for graph_type=%r." % graph_type)
    if graph_type == 'simplified':
        return BiGraph._onnx2bigraph_simplified(
            model_onnx, recursive=recursive)
    return BiGraph._onnx2bigraph_basic(
        model_onnx, recursive=recursive)

615 

616 

def onnx_graph_distance(onx1, onx2, verbose=0, fLOG=print):
    """
    Computes a distance between two ONNX graphs. They must not
    be too big otherwise this function might take for ever.
    The function relies on package :epkg:`mlstatpy` and
    delegates the work to static method
    *BiGraph.onnx_graph_distance*.

    :param onx1: first graph (ONNX graph or model file name)
    :param onx2: second graph (ONNX graph or model file name)
    :param verbose: verbosity
    :param fLOG: logging function
    :return: distance and differences

    .. warning::

        This is very experimental and very slow.

    .. versionadded:: 0.7
    """
    compute = BiGraph.onnx_graph_distance
    return compute(onx1, onx2, verbose=verbose, fLOG=fLOG)