Coverage for mlprodict/onnx_tools/onnx_manipulations.py: 95%


356 statements  

1""" 

2@file 

3@brief Implements a class able to compute the predictions 

4from on an :epkg:`ONNX` model. 

5""" 

6import hashlib 

7from onnx import helper, shape_inference 

8from .onnx2py_helper import guess_proto_dtype, from_array 

9from .optim import onnx_remove_node_unused 

10 

11 

def enumerate_model_node_outputs(model, add_node=False, order=False):
    """
    Enumerates all the node outputs of a model.

    :param model: :epkg:`ONNX` graph
    :param add_node: if False, the function enumerates
        all output names from every node, otherwise, it
        enumerates tuples (output name, node)
    :param order: goes through outputs following the graph order
    :return: enumerator
    """
    if not hasattr(model, "graph"):
        raise TypeError(  # pragma: no cover
            "Parameter model is not an ONNX model but "
            "{}".format(type(model)))
    if order:
        edges = []
        order = {}
        node_names = {}
        for inp in model.graph.input:
            order[0, inp.name] = 0
        for node in model.graph.node:
            order[1, node.name] = 0
            for i in node.input:
                edges.append(('in', i, node.name))
            for o in node.output:
                edges.append(('out', o, node.name))
                node_names[o] = node
                order[0, o] = 0

        modif = 1
        while modif > 0:
            modif = 0
            for kind, data_name, node_name in edges:
                if kind == 'in':
                    if (0, data_name) not in order:
                        continue
                    if order[0, data_name] + 1 > order[1, node_name]:
                        modif += 1
                        order[1, node_name] = order[0, data_name] + 1
                else:
                    if order[1, node_name] + 1 > order[0, data_name]:
                        modif += 1
                        order[0, data_name] = order[1, node_name] + 1

        orders = [(v, k) for k, v in order.items()]
        orders.sort()

        for _, k in orders:
            if k[0] == 1:
                continue
            out = k[1]
            if out not in node_names:
                continue
            yield (out, node_names[out]) if add_node else out
    else:
        for node in model.graph.node:
            for out in node.output:
                yield (out, node) if add_node else out

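
# Illustrative usage sketch, not part of the original module: builds a tiny
# Abs -> Neg graph with onnx.helper and lists every node output, first in
# declaration order, then following the computed graph order.
def _example_enumerate_model_node_outputs():  # pragma: no cover
    from onnx import TensorProto
    X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [None, 2])
    Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [None, 2])
    n1 = helper.make_node('Abs', ['X'], ['T'], name='n1')
    n2 = helper.make_node('Neg', ['T'], ['Y'], name='n2')
    model = helper.make_model(helper.make_graph([n1, n2], 'g', [X], [Y]))
    print(list(enumerate_model_node_outputs(model)))
    # expected: ['T', 'Y']
    print(list(enumerate_model_node_outputs(model, order=True)))
    # expected: ['T', 'Y'] as well, sorted by the propagated order
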

def select_model_inputs_outputs(model, outputs=None, inputs=None,
                                infer_shapes=False, overwrite=None,
                                remove_unused=True,
                                verbose=0, fLOG=None):
    """
    Takes a model and changes its outputs.

    :param model: :epkg:`ONNX` model
    :param inputs: new inputs, same ones if None
    :param outputs: new outputs, same ones if None
    :param infer_shapes: infer inputs and outputs shapes
    :param overwrite: overwrite type and shapes for
        inputs or outputs, *overwrite* is a
        dictionary `{'name': (numpy dtype, shape)}`
    :param remove_unused: remove unused nodes from the graph
    :param verbose: display information while converting
    :param fLOG: logging function
    :return: modified model

    The function removes unneeded nodes.

    .. exref::
        :title: Change ONNX model inputs

        The following example shows how to change the inputs of a model
        to bypass the first nodes. Shape inference fails to determine
        the new input types, so they need to be overwritten.
        `verbose=1, fLOG=print` shows the number of deleted nodes.

        ::

            import numpy
            import onnx
            from mlprodict.onnx_tools.onnx_manipulations import select_model_inputs_outputs

            onx = onnx.load(path)
            onx2 = select_model_inputs_outputs(
                onx, inputs=["SentenceTokenizer/SentencepieceTokenizeOp:0",
                             "SentenceTokenizer/SentencepieceTokenizeOp:1"],
                infer_shapes=True, verbose=1, fLOG=print,
                overwrite={'SentenceTokenizer/SentencepieceTokenizeOp:0': (numpy.int32, None),
                           'SentenceTokenizer/SentencepieceTokenizeOp:1': (numpy.int64, None)})
            onnx.save(onx2, path2)

    .. versionchanged:: 0.6
        Supports the case where inputs are changed.

    .. versionchanged:: 0.7
        Parameter *remove_unused* was added. Unused nodes are removed by default.
    """
    if inputs is not None and not isinstance(inputs, list):
        inputs = [inputs]
    if outputs is not None and not isinstance(outputs, list):
        outputs = [outputs]
    if inputs is None:
        inputs = [i.name for i in model.graph.input]
    if outputs is None:
        outputs = [o.name for o in model.graph.output]

    mark_var = {}
    for out in enumerate_model_node_outputs(model):
        mark_var[out] = 0
    for inp in inputs:
        mark_var[inp] = 0
    for out in outputs:
        if out not in mark_var:
            raise ValueError(  # pragma: no cover
                "Output '{}' not found in model.".format(out))
        mark_var[out] = 1

    nodes = model.graph.node[::-1]
    mark_op = {}
    for node in nodes:
        mark_op[node.name] = 0

    # We mark all the nodes we need to keep.
    nb = 1
    while nb > 0:
        nb = 0
        for node in nodes:
            if mark_op[node.name] == 1:
                continue
            mod = False
            for out in node.output:
                if mark_var[out] == 1:
                    mark_op[node.name] = 1
                    mod = True
                    break
            if not mod:
                continue

            nb += 1
            for inp in node.input:
                if inp in inputs:
                    continue
                if mark_var.get(inp, 0) == 1:
                    continue
                mark_var[inp] = 1
                nb += 1

    # Every kept node verifies mark_op[node.name] == 1.
    keep_nodes = [node for node in nodes if mark_op[node.name] == 1]

    known_shapes = {}
    if infer_shapes:
        shapes = shape_inference.infer_shapes(model)
        for shape in shapes.graph.value_info:  # pylint: disable=E1101
            known_shapes[shape.name] = shape.type
        for shape in shapes.graph.input:  # pylint: disable=E1101
            known_shapes[shape.name] = shape.type
        for shape in shapes.graph.output:  # pylint: disable=E1101
            known_shapes[shape.name] = shape.type
    else:
        for shape in model.graph.input:  # pylint: disable=E1101
            known_shapes[shape.name] = shape.type
        for shape in model.graph.output:  # pylint: disable=E1101
            known_shapes[shape.name] = shape.type

    var_in = []
    for name in inputs:
        if overwrite is not None and name in overwrite:
            dtype, shape = overwrite[name]
            proto_dtype = guess_proto_dtype(dtype)
            value_info = helper.make_tensor_value_info(
                name, proto_dtype, shape)
        elif name in known_shapes:
            info = known_shapes[name].tensor_type
            proto_dtype = info.elem_type
            if proto_dtype == 0:
                value_info = helper.ValueInfoProto()
                value_info.name = name
            else:
                shape = [getattr(d, 'dim_value', None) for d in info.shape.dim]
                if len(shape) == 0:
                    shape = None
                else:
                    shape = [None if s == 0 else s for s in shape]
                value_info = helper.make_tensor_value_info(
                    name, proto_dtype, shape)
        else:
            value_info = helper.ValueInfoProto()
            value_info.name = name
        var_in.append(value_info)

    var_out = []
    for name in outputs:
        if overwrite is not None and name in overwrite:
            dtype, shape = overwrite[name]
            proto_dtype = guess_proto_dtype(dtype)
            value_info = helper.make_tensor_value_info(
                name, proto_dtype, shape)
        elif name in known_shapes:
            info = known_shapes[name].tensor_type
            proto_dtype = info.elem_type
            if proto_dtype == 0:
                value_info = helper.ValueInfoProto()
                value_info.name = name
            else:
                shape = [getattr(d, 'dim_value', None) for d in info.shape.dim]
                if len(shape) == 0:
                    shape = None
                else:
                    shape = [None if s == 0 else s for s in shape]
                value_info = helper.make_tensor_value_info(
                    name, proto_dtype, shape)
        else:
            value_info = helper.ValueInfoProto()
            value_info.name = name
        var_out.append(value_info)

    if verbose > 0 and fLOG is not None:  # pragma: no cover
        fLOG("[select_model_inputs_outputs] nodes %r --> %r" % (
            len(model.graph.node), len(keep_nodes)))
        fLOG("[select_model_inputs_outputs] inputs: %r" % var_in)
        fLOG("[select_model_inputs_outputs] outputs: %r" % var_out)

    graph = helper.make_graph(keep_nodes, model.graph.name, var_in,
                              var_out, model.graph.initializer)
    onnx_model = helper.make_model(graph)
    onnx_model.ir_version = model.ir_version
    onnx_model.producer_name = model.producer_name
    onnx_model.producer_version = model.producer_version
    onnx_model.domain = model.domain
    onnx_model.model_version = model.model_version
    onnx_model.doc_string = model.doc_string
    if len(model.metadata_props) > 0:  # pragma: no cover
        values = {p.key: p.value for p in model.metadata_props}
        helper.set_model_props(onnx_model, values)

    del onnx_model.opset_import[:]  # pylint: disable=E1101
    for oimp in model.opset_import:
        op_set = onnx_model.opset_import.add()  # pylint: disable=E1101
        op_set.domain = oimp.domain
        op_set.version = oimp.version

    # remove unused nodes
    if remove_unused:
        onnx_model = onnx_remove_node_unused(onnx_model, recursive=False)

    return onnx_model

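
# Minimal sketch, not part of the original module and complementary to the
# docstring example: keeps an intermediate result as the only output so that
# the graph is truncated after the node producing it.
def _example_select_model_inputs_outputs():  # pragma: no cover
    from onnx import TensorProto
    X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [None, 2])
    Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [None, 2])
    n1 = helper.make_node('Abs', ['X'], ['T'], name='n1')
    n2 = helper.make_node('Neg', ['T'], ['Y'], name='n2')
    model = helper.make_model(helper.make_graph([n1, n2], 'g', [X], [Y]))
    # 'T' becomes the only output, the Neg node is no longer needed.
    sub = select_model_inputs_outputs(model, outputs=['T'])
    print([n.op_type for n in sub.graph.node])
    # expected: ['Abs']
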

def overwrite_opset(model, new_opset):
    """
    Overwrites the main opset in an ONNX file.
    Does not change any node definition.

    :param model: ONNX model
    :param new_opset: new opset
    :return: ONNX model
    """
    graph = helper.make_graph(
        model.graph.node, model.graph.name, model.graph.input,
        model.graph.output, model.graph.initializer)
    onnx_model = helper.make_model(graph)
    onnx_model.ir_version = model.ir_version
    onnx_model.producer_name = model.producer_name
    onnx_model.producer_version = model.producer_version
    onnx_model.domain = model.domain
    onnx_model.model_version = model.model_version
    onnx_model.doc_string = model.doc_string
    if len(model.metadata_props) > 0:  # pragma: no cover
        values = {p.key: p.value for p in model.metadata_props}
        helper.set_model_props(onnx_model, values)

    del onnx_model.opset_import[:]  # pylint: disable=E1101
    for oimp in model.opset_import:
        op_set = onnx_model.opset_import.add()  # pylint: disable=E1101
        if oimp.domain == '':
            op_set.domain = oimp.domain
            op_set.version = new_opset
        else:
            op_set.domain = oimp.domain
            op_set.version = oimp.version
    return onnx_model

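
# Small sketch, not part of the original module: forces the opset of the
# default domain to a chosen value and prints the resulting opset imports.
def _example_overwrite_opset():  # pragma: no cover
    from onnx import TensorProto
    X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1])
    Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1])
    node = helper.make_node('Identity', ['X'], ['Y'])
    model = helper.make_model(helper.make_graph([node], 'g', [X], [Y]))
    new_model = overwrite_opset(model, 12)
    print([(op.domain, op.version) for op in new_model.opset_import])
    # expected: [('', 12)]
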

def hash_onnx_object(obj, max_size):
    """
    Hash the content of an object.
    """
    m = hashlib.sha256()
    if hasattr(obj, 'op_type'):
        # An operator.
        m.update(obj.op_type.encode('ascii'))
        m.update(str(len(obj.input)).encode('ascii'))
        m.update(str(len(obj.output)).encode('ascii'))
        if hasattr(obj, 'attribute'):
            for att in obj.attribute:
                m.update(att.name.encode('ascii'))
                m.update(att.SerializeToString())
    else:
        # An initializer.
        name = obj.name
        docf = obj.doc_string
        obj.name = ''
        obj.doc_string = ''
        try:
            m.update(obj.SerializeToString())
        except AttributeError as e:  # pragma: no cover
            raise RuntimeError(
                "Unable to hash object type %r, value=%r."
                "" % (type(obj), obj)) from e
        finally:
            obj.name = name
            obj.doc_string = docf

    content = m.hexdigest()
    if len(content) > max_size:
        content = content[:max_size]
    return content.upper()

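
# Sketch, not part of the original module: reduces a node to a short
# uppercase hash prefix, the building block of the 'type' renaming
# strategy implemented below.
def _example_hash_onnx_object():  # pragma: no cover
    node = helper.make_node('Add', ['A', 'B'], ['C'])
    print(hash_onnx_object(node, 8))
    # prints the first 8 characters of the sha256 hash, in upper case
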

def onnx_rename_names(model, strategy='simple', recursive=True,
                      verbose=0, fLOG=print,
                      counts=None, replace=None, taken=None):
    """
    Renames all names except the inputs and outputs.

    :param model: onnx model
    :param strategy: two strategies are implemented, see below
    :param recursive: walk through subgraphs
    :param verbose: if positive, reports all changed names
    :param fLOG: logging function
    :param counts: used for recursion
    :param replace: used for recursion, it can also be used to
        fix some replacements
    :param taken: used for recursion
    :return: onnx model (the model is modified in place)

    Strategies:

    * `'simple'`: uses the letter `n` for nodes, `r` for results,
      `i` for initializers, followed by a number
    * `'type'`: the name depends on the node type and content,
      the hash is kept as small as possible
    """
    counts = counts or {'init': 0, 'node': 0, 'result': 0}
    replace = replace or {}
    taken = taken or set()
    graph = model.graph if hasattr(model, 'graph') else model

    for obj in graph.input:
        replace[obj.name] = obj.name
    for obj in graph.output:
        replace[obj.name] = obj.name

    def _check_name_simple(prefix):
        if prefix not in replace:
            return prefix
        c = 1
        final = "%s_%d" % (prefix, c)
        while final in taken:
            c += 1
            final = "%s_%d" % (prefix, c)
        taken.add(final)
        return final

    def _check_name_type(obj, prefix):
        c = 2
        hash = hash_onnx_object(obj, c)
        final = "%s_%s" % (prefix, hash)
        while final in taken:
            c += 2
            hash = hash_onnx_object(obj, c)
            final = "%s_%s" % (prefix, hash)
        taken.add(final)
        return final

    def get_name_init(init):
        if init.name in replace:
            return replace[init.name]
        if strategy == 'simple':
            name = _check_name_simple('i%d' % counts['init'])
            counts['init'] += 1
            replace[init.name] = name
            if verbose > 0 and fLOG is not None:
                fLOG('[onnx_rename_names] init: %r -> %r' % (init.name, name))
            return name
        if strategy == 'type':
            name = _check_name_type(init, 'i')
            counts['init'] += 1
            replace[init.name] = name
            if verbose > 0 and fLOG is not None:
                fLOG('[onnx_rename_names] init: %r -> %r' % (init.name, name))
            return name
        raise ValueError(  # pragma: no cover
            "Unknown strategy %r." % strategy)

    def get_name_node(node):
        node_name = 'node_%s_%d' % (node.name, id(node))
        if node_name in replace:
            return replace[node_name]
        if strategy == 'simple':
            name = _check_name_simple('n%d' % counts['node'])
            counts['node'] += 1
            replace[node_name] = name
            if verbose > 0 and fLOG is not None:
                fLOG('[onnx_rename_names] node: %r -> %r' % (node_name, name))
            return name
        if strategy == 'type':
            name = _check_name_type(node, 'n')
            counts['node'] += 1
            replace[node_name] = name
            if verbose > 0 and fLOG is not None:
                fLOG('[onnx_rename_names] node: %r -> %r' % (node_name, name))
            return name
        raise ValueError(  # pragma: no cover
            "Unknown strategy %r." % strategy)

    def get_name_result(node, i, name, suffix):
        if name in replace:
            return replace[name]
        if strategy == 'simple':
            new_name = _check_name_simple('r%d' % counts['result'])
            counts['result'] += 1
            replace[name] = new_name
            if verbose > 0 and fLOG is not None:
                fLOG('[onnx_rename_names] result: %r -> %r' % (name, new_name))
            return new_name
        if strategy == 'type':
            new_name = _check_name_type(node, 'r%s%d' % (suffix, i))
            counts['result'] += 1
            replace[name] = new_name
            if verbose > 0 and fLOG is not None:
                fLOG('[onnx_rename_names] result: %r -> %r' % (name, new_name))
            return new_name
        raise ValueError(  # pragma: no cover
            "Unknown strategy %r." % strategy)

    def get_name_input(node, i):
        return get_name_result(node, i, node.input[i], 'i')

    def get_name_output(node, i):
        return get_name_result(node, i, node.output[i], 'o')

    for init in graph.initializer:
        init.name = get_name_init(init)

    for node in graph.node:
        node.name = get_name_node(node)
        for i in range(len(node.input)):  # pylint: disable=C0200
            node.input[i] = get_name_input(node, i)
        for i in range(len(node.output)):  # pylint: disable=C0200
            node.output[i] = get_name_output(node, i)
        if not recursive or node.op_type not in {'Scan', 'If', 'Loop'}:
            continue
        # recursion into subgraphs
        for att in node.attribute:
            if att.name not in {'then_branch', 'else_branch', 'body'}:
                continue
            onnx_rename_names(
                att.g, strategy=strategy, fLOG=fLOG, verbose=verbose,
                counts=counts, replace=replace, taken=taken)

    return model

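
# Sketch, not part of the original module: renames nodes and intermediate
# results of a small graph with the 'simple' strategy; graph inputs and
# outputs keep their names.
def _example_onnx_rename_names():  # pragma: no cover
    from onnx import TensorProto
    X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [None, 2])
    Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [None, 2])
    n1 = helper.make_node('Abs', ['X'], ['T'], name='op_abs')
    n2 = helper.make_node('Neg', ['T'], ['Y'], name='op_neg')
    model = helper.make_model(helper.make_graph([n1, n2], 'g', [X], [Y]))
    onnx_rename_names(model, strategy='simple')
    print([n.name for n in model.graph.node])
    # expected: ['n0', 'n1']
    print([list(n.input) for n in model.graph.node])
    # expected: [['X'], ['r0']], 'T' was renamed, 'X' and 'Y' were not
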

def insert_results_into_onnx(model, results, as_parameter=True, suffix='_DBG',
                             param_name=None, node_type='DEBUG',
                             domain='DEBUG', domain_opset=1):
    """
    Inserts results into an ONNX graph to produce an extended
    ONNX graph. It can be saved and inspected with a tool such as
    :epkg:`netron`.

    :param model: ONNX graph
    :param results: results to be added, as a dictionary
    :param as_parameter: add new nodes with results as one parameter
        (True) or as initializer (False)
    :param suffix: suffix to add to new results
    :param param_name: name of the parameter to add
        (by default the result name), it can be a function
        `param_name(result_name) -> parameter_name`
    :param node_type: type of the new node
    :param domain: domain of the new node
    :param domain_opset: opset for *domain*
    :return: new ONNX graph

    See method :meth:`OnnxInference.run2onnx
    <mlprodict.onnxrt.onnx_inference.OnnxInference.run2onnx>`
    to see a graph this function produces.

    .. image:: debug.png

    .. versionadded:: 0.7
    """
    inputs = list(model.graph.input)
    outputs = list(model.graph.output)
    inits = list(model.graph.initializer)
    nodes = {id(n): n for n in model.graph.node}
    order = {id(n): i for i, n in enumerate(model.graph.node)}
    nodes_copy = {}

    names_init = set(init.name for init in inits)
    names_input = set(init.name for init in inputs)
    names_output = {}
    for node in nodes.values():
        for i, o in enumerate(node.output):
            names_output[o] = (i, node)

    for k, v in results.items():
        if k in names_init:
            # initializers are not inserted again
            continue
        if k in names_input:
            # adding debug information on an input is not supported
            raise NotImplementedError(
                "Unable to add debug information on input %r." % k)

        if k not in names_output:
            raise RuntimeError(
                "Unable to find result %r in the ONNX graph. Available="
                "[%s]." % (k, ", ".join(sorted(names_output))))

        index, node = names_output[k]
        new_name = k + suffix

        if id(node) not in nodes_copy:
            new_node = helper.make_node(
                node.op_type, list(node.input), list(node.output),
                domain=node.domain if node.domain else None,
                name=node.name + suffix)
            new_node.attribute.extend(node.attribute)  # pylint: disable=E1101
            nodes_copy[id(node)] = new_node
            order[id(new_node)] = order[id(node)]
        new_node = nodes_copy[id(node)]
        new_node.output[index] = new_name

        if as_parameter:
            pname = k if param_name is None else param_name(k)
            atts = {pname: from_array(v, name=pname)}
            inserted_node = helper.make_node(
                node_type, [new_name], [k], domain=domain,
                **atts)
        else:
            pname = k if param_name is None else param_name(k)
            pname += suffix + 'i'
            inserted_node = helper.make_node(
                node_type, [new_name, pname], [k], domain=domain)
            inits.append(from_array(v, name=pname))

        order[id(inserted_node)] = order[id(node)] + 1. / (index + 2)
        nodes[id(inserted_node)] = inserted_node

    new_nodes = [(order[id(n)], n)
                 for n in nodes.values() if id(n) not in nodes_copy]
    new_nodes.extend((order[id(n)], n) for n in nodes_copy.values())
    new_nodes = [n[1] for n in sorted(new_nodes)]

    graph = helper.make_graph(new_nodes, model.graph.name, inputs,
                              outputs, inits)
    onnx_model = helper.make_model(graph)
    onnx_model.ir_version = model.ir_version
    onnx_model.producer_name = model.producer_name
    onnx_model.producer_version = model.producer_version
    onnx_model.domain = model.domain
    onnx_model.model_version = model.model_version
    onnx_model.doc_string = model.doc_string
    if len(model.metadata_props) > 0:  # pragma: no cover
        values = {p.key: p.value for p in model.metadata_props}
        helper.set_model_props(onnx_model, values)

    del onnx_model.opset_import[:]  # pylint: disable=E1101
    for oimp in model.opset_import:
        op_set = onnx_model.opset_import.add()  # pylint: disable=E1101
        op_set.domain = oimp.domain
        op_set.version = oimp.version
    op_set = onnx_model.opset_import.add()  # pylint: disable=E1101
    op_set.domain = domain
    op_set.version = domain_opset
    return onnx_model
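
# Sketch, not part of the original module: attaches an intermediate result
# observed at runtime to the graph so that it becomes visible in a viewer;
# the extra node uses the custom 'DEBUG' domain registered above.
def _example_insert_results_into_onnx():  # pragma: no cover
    import numpy
    from onnx import TensorProto
    X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [2])
    Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [2])
    n1 = helper.make_node('Abs', ['X'], ['T'], name='n1')
    n2 = helper.make_node('Neg', ['T'], ['Y'], name='n2')
    model = helper.make_model(helper.make_graph([n1, n2], 'g', [X], [Y]))
    debug = insert_results_into_onnx(
        model, {'T': numpy.array([1.5, 2.5], dtype=numpy.float32)})
    print([(n.op_type, n.domain) for n in debug.graph.node])
    # expected: the original nodes plus one 'DEBUG' node holding the value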