Coverage for mlprodict/onnx_tools/exports/tf2onnx_helper.py: 85%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

355 statements  

1""" 

2@file 

3@brief Helpers to run examples created with function 

4@see fn export2tf2onnx. 

5""" 

6import collections 

7import inspect 

8import numpy 

9from onnx.numpy_helper import from_array 

10from onnx.helper import ( 

11 make_node, make_graph, make_model, set_model_props, make_tensor) 

12from onnx import AttributeProto 

13from ..onnx2py_helper import guess_dtype, guess_proto_dtype 

14from ..onnx_tools import ensure_topological_order 

15 

16 

# Module-level counter appended by make_name to build unique names;
# make_name increments it on every call.
_make_name_id = 0

18 

19 

def make_tf2onnx_code(opset, name=None, op_type=None, domain='',
                      inputs=None, outputs=None, attributes=None,
                      used=None, context=None, mark_inits=None, indent=8,
                      **unused):
    """
    Converts an ONNX operators into :epkg:`tf2onnx` code.

    :param opset: target opset for the conversion (usually unused)
    :param name: node name
    :param op_type: operator type
    :param domain: domain
    :param inputs: inputs (list of result names, None means no input)
    :param outputs: outputs (list of result names, None means no output)
    :param attributes: attributes, an iterable of pairs `(name, code)`
        where *code* is inserted as is in the generated code
        (None means no attribute)
    :param used: dictionary `{k: v}`,
        list of nodes taking *k* as input
    :param context: whole context, `context['initializers_dict']`
        maps an initializer name to its value
    :param mark_inits: dictionary `{name: [values]}` filled with the
        initializers replaced by their value in the generated code,
        ignored if None
    :param indent: number of spaces to add on the second
        and following rows
    :return: code as str
    :raises NotImplementedError: if the operator or the number of
        inputs is not supported
    """
    inputs = [] if inputs is None else list(inputs)
    outputs = [] if outputs is None else list(outputs)
    attributes = [] if attributes is None else list(attributes)

    def simplify(name, kind, force=False):
        # Replaces a small int64 initializer used by a single node by
        # its value, otherwise keeps a reference to the result name.
        value = None
        if (used is not None and name in used and
                len(used[name]) == 1 and context is not None):
            inits = context['initializers_dict']
            if name in inits:
                v = inits[name]
                if v.dtype == numpy.int64 and v.size < 10:
                    value = v
                    if mark_inits is not None:
                        mark_inits.setdefault(name, []).append(v)

        if value is None and force:
            inits = context['initializers_dict']
            if name not in inits:
                raise RuntimeError(  # pragma: no cover
                    "Unable to find init %r in %r value=%r." % (
                        name, list(sorted(inits)), value))
            value = inits[name]

        if kind not in ('list', 'list_var'):
            raise NotImplementedError(
                "Unknown scenario to simplify (%r)." % kind)
        if value is not None:
            # tolist() returns python numbers: the generated code must
            # not depend on numpy's scalar repr (numpy>=2 prints
            # np.int64(1) instead of 1).
            return str(value.tolist())
        if kind == 'list':
            return name
        return "varx[%r]" % name

    rows = []
    if op_type == 'Unsqueeze':
        if len(inputs) == 2:
            rows.append(
                "node = GraphBuilder(ctx).make_unsqueeze("
                "{'data': varx[%r], 'axes': %s}, return_node=True)"
                "" % (inputs[0], simplify(inputs[1], 'list_var')))
        else:
            raise NotImplementedError(  # pragma: no cover
                "Unable to create code for operator %r (opset <= 12)"
                "." % op_type)
    elif op_type == 'Squeeze':
        if len(inputs) == 1:
            rows.append(
                "node = GraphBuilder(ctx).make_squeeze("
                "{'data': varx[%r]}, return_node=True)"
                "" % (inputs[0], ))
        elif len(inputs) == 2:
            rows.append(
                "node = GraphBuilder(ctx).make_squeeze("
                "{'data': varx[%r], 'axes': %s}, return_node=True)"
                "" % (inputs[0], simplify(inputs[1], 'list_var')))
        else:
            raise NotImplementedError(  # pragma: no cover
                "Unable to create code for operator %r (opset <= 12)"
                "." % op_type)
    elif op_type == 'Slice':
        if len(inputs) not in (3, 4, 5):
            raise NotImplementedError(  # pragma: no cover
                "Unable to create code for operator %r (opset <= 12)"
                "." % op_type)
        atts = dict(zip(['starts', 'ends', 'axes', 'steps'],
                        inputs[1:]))
        text = ", ".join("'%s': %s" % (k, simplify(v, 'list_var'))
                         for k, v in atts.items())
        rows.append(
            "node = GraphBuilder(ctx).make_slice("
            "{'data': varx[%r], %s}, return_node=True)"
            "" % (inputs[0], text))
    else:
        if attributes:
            attributes_str = ", ".join(
                "%s=%s" % (k, v) for k, v in attributes)
            attr = ", attr=dict(%s)" % attributes_str
        else:
            attr = ""
        rows.append(
            "inputs = [%s]" % ", ".join("varx[%r]" % n for n in inputs))
        sdomain = '' if domain == '' else ("domain=%r, " % domain)
        rows.append(
            "node = ctx.make_node(%r, inputs=inputs%s, %s"
            "name=make_name(%r))" % (
                op_type, attr, sdomain, name))
    for i, n in enumerate(outputs):
        rows.append("varx[%r] = node.output[%d]" % (n, i))
    if indent > 0:
        # The first row is inserted by the caller at the right position,
        # only the following ones need to be indented.
        sind = " " * indent
        rows = rows[:1] + [sind + row for row in rows[1:]]
    return "\n".join(rows)
62 if kind == 'list': 

63 if value is None: 

64 return name 

65 if len(value.shape) == 0: 

66 return str(value) 

67 return str(list(value)) 

68 if kind == 'list_var': 

69 if value is None: 

70 return "varx[%r]" % name 

71 if len(value.shape) == 0: 

72 return str(value) 

73 return str(list(value)) 

74 raise NotImplementedError( 

75 "Unknown scenario to simplify (%r)." % kind) 

76 

77 rows = [] 

78 if op_type == 'Unsqueeze': 

79 if len(inputs) == 2: 

80 rows.append( 

81 "node = GraphBuilder(ctx).make_unsqueeze(" 

82 "{'data': varx[%r], 'axes': %s}, return_node=True)" 

83 "" % (inputs[0], simplify(inputs[1], 'list_var'))) 

84 else: 

85 raise NotImplementedError( # pragma: no cover 

86 "Unable to create code for operator %r (opset <= 12)" 

87 "." % op_type) 

88 elif op_type == 'Squeeze': 

89 if len(inputs) == 1: 

90 rows.append( 

91 "node = GraphBuilder(ctx).make_squeeze(" 

92 "{'data': varx[%r]}, return_node=True)" 

93 "" % (inputs[0], )) 

94 elif len(inputs) == 2: 

95 rows.append( 

96 "node = GraphBuilder(ctx).make_squeeze(" 

97 "{'data': varx[%r], 'axes': %s}, return_node=True)" 

98 "" % (inputs[0], simplify(inputs[1], 'list_var'))) 

99 else: 

100 raise NotImplementedError( # pragma: no cover 

101 "Unable to create code for operator %r (opset <= 12)" 

102 "." % op_type) 

103 elif op_type == 'Slice': 

104 atts = dict(zip(['starts', 'ends', 'axes', 'steps'], 

105 inputs[1:])) 

106 text = ", ".join("'%s': %s" % (k, simplify(v, 'list_var')) 

107 for k, v in atts.items()) 

108 if len(inputs) in (3, 4, 5): 

109 rows.append( 

110 "node = GraphBuilder(ctx).make_slice(" 

111 "{'data': varx[%r], %s}, return_node=True)" 

112 "" % (inputs[0], text)) 

113 else: 

114 raise NotImplementedError( # pragma: no cover 

115 "Unable to create code for operator %r (opset <= 12)" 

116 "." % op_type) 

117 else: 

118 if len(attributes) > 0: 

119 attributes_str = ", ".join("%s=%s" % (k, v) for k, v in attributes) 

120 attr = ", attr=dict(%s)" % attributes_str 

121 else: 

122 attr = "" 

123 rows.append( 

124 "inputs = [%s]" % ", ".join("varx[%r]" % n for n in inputs)) 

125 sdomain = '' if domain == '' else ("domain=%r, " % domain) 

126 rows.append( 

127 "node = ctx.make_node(%r, inputs=inputs%s, %s" 

128 "name=make_name(%r))" % ( 

129 op_type, attr, sdomain, name)) 

130 for i, n in enumerate(outputs): 

131 rows.append("varx[%r] = node.output[%d]" % (n, i)) 

132 if indent > 0: 

133 sind = " " * indent 

134 for i in range(1, len(rows)): 

135 rows[i] = sind + rows[i] 

136 return "\n".join(rows) 

137 

138 

139def make_name(name): 

140 "Creates a unique name." 

141 global _make_name_id # pylint: disable=W0603 

142 name = "%s_%d" % (name, _make_name_id) 

143 _make_name_id += 1 

144 return name 

145 

146 

147def get_max_value(np_dtype): 

148 "Returns the maximum value for a specific type." 

149 return numpy.iinfo(np_dtype).max 

150 

151 

152def make_sure(cond, msg, *args): 

153 "Raises an exception if cond is not verified." 

154 if not cond: 

155 raise RuntimeError(msg % tuple(args)) # pragma: no cover 

156 

157 

158def map_onnx_to_numpy_type(onnx_dtype): 

159 "Converts ONNX type into numpy type." 

160 if onnx_dtype is None: 

161 return numpy.float32 

162 return guess_dtype(onnx_dtype) 

163 

164 

165class tf_op: 

166 """ 

167 Decorator to register any new converter. 

168 :param name: type of the operator to rewrite 

169 :param domain: domain 

170 """ 

171 _OPSETS = collections.OrderedDict() 

172 

173 def __init__(self, name, domain='', **kwargs): 

174 if not isinstance(name, list): 

175 name = [name] 

176 self.names = name 

177 self.domain = domain 

178 self.kwargs = kwargs 

179 

180 def __call__(self, func): 

181 for ke, va in inspect.getmembers(func, inspect.ismethod): 

182 if ke.startswith("version_"): 

183 version = int(ke.replace("version_", "")) 

184 self._register_handler( 

185 va, version, self.names, self.domain, self.kwargs) 

186 return func 

187 

188 def _register_handler(self, func, version, names, domain, kwargs): 

189 opset = tf_op._OPSETS.get(domain) 

190 if not opset: 

191 opset = [] 

192 tf_op._OPSETS[domain] = opset 

193 while version >= len(opset): 

194 opset.append({}) 

195 opset_dict = opset[version] 

196 for name in names: 

197 opset_dict[name] = (func, kwargs) 

198 

199 

200class Tf2OnnxConvert: 

201 """ 

202 Applies the converter on an ONNX graph. 

203 

204 :param onnx_model: ONNX graph 

205 :param tf_op: class which register 

206 :param verbose: verbosity 

207 :param target_opset: targetted opsets 

208 """ 

209 

210 def __init__(self, onnx_model, _tf_op=None, verbose=None, 

211 target_opset=None, max_iter=5): 

212 self._onnx_model = onnx_model 

213 self._tf_op = _tf_op or tf_op 

214 self.verbose = verbose 

215 self.max_iter = max_iter 

216 if isinstance(target_opset, int): 

217 self.target_opsets = {'': target_opset} 

218 elif isinstance(target_opset, dict): 

219 self.target_opsets = target_opset 

220 elif target_opset is None: 

221 opsets = {} 

222 for oimp in onnx_model.opset_import: 

223 if oimp.domain == '': 

224 opsets[oimp.domain] = oimp.version 

225 opset = oimp.version 

226 else: 

227 opsets[oimp.domain] = opset 

228 self.target_opsets = opsets 

229 else: 

230 raise ValueError( # pragma: no cover 

231 "Unexepected value for target_opset=%r." % target_opset) 

232 self._names = {} 

233 for node in onnx_model.graph.node: 

234 self._names[node.name] = node 

235 for init in onnx_model.graph.initializer: 

236 self._names[init.name] = init 

237 # _forbidden_new_names contains current names and deleted names. 

238 self._forbidden_new_names = set(self._names) 

239 if '' in self.target_opsets: 

240 self.opset = self.target_opsets[''] 

241 if not hasattr(self, 'opset'): 

242 raise RuntimeError( # pragma: no cover 

243 "Attribute opset is missing, target_opset=%r." % target_opset) 

244 

245 def get_node_by_name(self, name): 

246 """ 

247 Retrieves a node by its name. 

248 

249 :param name: node name 

250 :return: node name 

251 """ 

252 if name not in self._names: 

253 raise RuntimeError( # pragma: no cover 

254 "Unable to find node name %r among %r." % ( 

255 name, ", ".join(sorted(self._names)))) 

256 return self._names[name] 

257 

258 def _add_node_name(self, obj): 

259 """ 

260 Registers an object in in the graph by its name. 

261 :param name: node or initializer 

262 """ 

263 if obj.name in self._forbidden_new_names: 

264 raise RuntimeError( # pragma: no cover 

265 "Name %r is already registered." % obj.name) 

266 self._names[obj.name] = obj 

267 self._forbidden_new_names.add(obj.name) 

268 

269 def make_node(self, op_type, inputs, attr=None, outputs=None, 

270 name=None, domain='', output_count=1, 

271 shapes=None, dtypes=None): 

272 """ 

273 Adds a node to the list of nodes. 

274 

275 :param op_type: operator type 

276 :param inputs: list of strings 

277 :param attr: dictionary of attributes 

278 :param outputs: None or list of strings 

279 :param output_count: used if outputs is None to guess 

280 the number of outputs of this node 

281 :param name: name of the node 

282 :param domain: domain 

283 :param shapes: unused 

284 :param dtypes: unused 

285 :return: created node 

286 """ 

287 if self.verbose: 

288 print( # pragma: no cover 

289 "[Tf2OnnxConvert.make_node] op_type=%r inputs=%r" % ( 

290 op_type, inputs)) 

291 

292 if attr is None: 

293 attr = {} 

294 if name is None: 

295 name = make_name(op_type) 

296 if name in self._names: 

297 raise RuntimeError( # pragma: no cover 

298 "Node name %r already exists in %r." % ( 

299 name, ", ".join(sorted(self._names)))) 

300 

301 if outputs is None: 

302 outputs = [(name + ":" + str(i)) for i in range(output_count)] 

303 

304 output_count = len(outputs) 

305 raw_attr = {} 

306 onnx_attrs = [] 

307 for a, v in attr.items(): 

308 if isinstance(v, AttributeProto): 

309 onnx_attrs.append(v) 

310 else: 

311 raw_attr[a] = v 

312 

313 onnx_node = make_node( 

314 op_type, inputs, outputs, name=name, domain=domain, **raw_attr) 

315 

316 self._add_node_name(onnx_node) 

317 return onnx_node 

318 

319 def make_const(self, name, np_val, skip_conversion=False, raw=True): 

320 """ 

321 Make a new constants in the graph. 

322 :param name: const node name, must be unique. 

323 :param np_val: value of type numpy ndarray. 

324 :param skip_conversion: 

325 bool, indicate whether this created node would be mapped 

326 during conversion 

327 :param raw: whether to store data at field of raw_data or the 

328 specific field according to its dtype 

329 :return: create initializer 

330 """ 

331 if name in self._names: 

332 raise RuntimeError( # pragma: no cover 

333 "Initializer name %r already exists in %r." % ( 

334 name, ", ".join(sorted(self._names)))) 

335 np_val_flat = np_val.flatten() 

336 is_bytes = (np_val.dtype == numpy.object and len(np_val_flat) > 0 and 

337 isinstance(np_val_flat[0], bytes)) 

338 if raw and not is_bytes: 

339 onnx_tensor = from_array(np_val, name) 

340 else: 

341 onnx_tensor = make_tensor( 

342 name, guess_proto_dtype(np_val.dtype), 

343 np_val.shape, np_val_flat, raw=False) 

344 

345 self._add_node_name(onnx_tensor) 

346 return onnx_tensor 

347 

348 def get_dtype(self, input_name): 

349 """ 

350 Returns the type of one node or None if unknown. 

351 :param input_name: result name 

352 :return: numpy dtype 

353 """ 

354 inputs = self._onnx_model.graph.input 

355 names = [_.name for _ in inputs] 

356 if input_name not in names: 

357 return None # pragma: no cover 

358 ind = names.index(input_name) 

359 return inputs[ind].type.tensor_type.elem_type 

360 

361 def replace_all_inputs(self, old_name, new_name): 

362 """ 

363 Every taking *old_name* as inputs will take *new_name* instead. 

364 Looks in the output as well but in that case, it creates an identity 

365 node to avoid changing an output name. 

366 :param old_name: name to replace 

367 :param new_name: new name 

368 :return: list of impacted nodes 

369 """ 

370 if self.verbose: 

371 print( # pragma: no cover 

372 "[Tf2OnnxConvert.replace_all_inputs] replace %r by %r" % ( 

373 old_name, new_name)) 

374 res = [] 

375 for node in self._names.values(): 

376 if not hasattr(node, 'input'): 

377 continue 

378 if old_name not in node.input: 

379 continue 

380 new_inputs = [new_name if i == old_name else i 

381 for i in node.input] 

382 node.input[:] = new_inputs[:] 

383 res.append(node) 

384 if self.verbose: 

385 print( # pragma: no cover 

386 "[Tf2OnnxConvert.replace_all_inputs] replace %r by %r in node %r" % ( 

387 old_name, new_name, node.name)) 

388 for o in self._onnx_model.graph.output: 

389 if o.name != old_name: 

390 continue 

391 n = self.make_node("Identity", [new_name], outputs=[old_name], 

392 name=make_name("IdOutputReplaced")) 

393 res.append(n) 

394 if self.verbose: 

395 print( # pragma: no cover 

396 "[Tf2OnnxConvert.replace_all_inputs] add id node from %r to %r " 

397 "with node %r." % ( 

398 old_name, new_name, n.name)) # pylint: disable=E1101 

399 if self.verbose: 

400 print( # pragma: no cover 

401 "[Tf2OnnxConvert.replace_all_inputs] end") 

402 return res 

403 

404 def remove_node(self, name): 

405 """ 

406 Removes a node name from the list. 

407 """ 

408 if name not in self._names: 

409 raise RuntimeError( # pragma: no cover 

410 "Unable to delete name %r because it does not exists." % name) 

411 del self._names[name] 

412 if self.verbose: 

413 print( # pragma: no cover 

414 "[Tf2OnnxConvert.remove_node] delete name %r" % name) 

415 

416 def get_shape(self, input_name): 

417 """ 

418 Returns the type of one node or None if unknown. 

419 :param input_name: result name 

420 :return: numpy dtype 

421 """ 

422 inputs = self._onnx_model.graph.input 

423 names = [_.name for _ in inputs] 

424 if input_name not in names: 

425 return None # pragma: no cover 

426 ind = names.index(input_name) 

427 dims = inputs[ind].type.tensor_type.shape.dim 

428 return tuple(dims) 

429 

430 def run(self): 

431 """ 

432 Calls the registered converters on the graph 

433 held by this instance. Returns the new onnx graph. 

434 

435 :return: ONNX graph 

436 """ 

437 if len(self._tf_op._OPSETS) == 0: 

438 raise RuntimeError( # pragma: no cover 

439 "No converter was registered.") 

440 if self.verbose: 

441 print("[Tf2OnnxConvert.run]") # pragma: no cover 

442 

443 done = {} 

444 modif = 1 

445 turn = 0 

446 while modif > 0 and turn < self.max_iter: 

447 modif = 0 

448 turn += 1 

449 # The converter may alter the current list of nodes, we freeze it. 

450 current_values = list(self._names.values()) 

451 for node in current_values: 

452 if not hasattr(node, 'domain'): 

453 # initializer 

454 continue 

455 if done.get(node.name, False): 

456 continue 

457 domain = node.domain 

458 if domain not in self._tf_op._OPSETS: 

459 continue 

460 

461 # look for a converter 

462 rews = self._tf_op._OPSETS[domain] 

463 target = min(self.target_opsets[domain], len(rews)) 

464 conv = None 

465 for i in range(len(rews) - 1, -1, -1): 

466 if node.op_type in rews[i]: 

467 conv = rews[i][node.op_type] 

468 break 

469 if conv is None: 

470 continue 

471 

472 # applies the converter 

473 if self.verbose: 

474 print( # pragma: no cover 

475 "[Tf2OnnxConvert.run] convert node type=%r opset=%r name=%r" 

476 "" % (node.op_type, target, node.name)) 

477 fct, kwargs = conv 

478 fct(self, node, target_opset=target, **kwargs) 

479 modif += 1 

480 

481 if turn >= self.max_iter: 

482 raise RuntimeError( # pragma: no cover 

483 "Too many iterations and no stable ONNX was reached, " 

484 "iter=%d\n%s" % (turn, str(self.make_model()))) 

485 return self.make_model() 

486 

487 def make_model(self): 

488 """ 

489 Produces the new ONNX graph with the updated sets of nodes. 

490 """ 

491 inputs = self._onnx_model.graph.input 

492 outputs = self._onnx_model.graph.output 

493 inits = [init[1] for init in sorted(self._names.items()) 

494 if not hasattr(init[1], 'domain')] 

495 nodes = [node[1] for node in sorted(self._names.items()) 

496 if hasattr(node[1], 'domain')] 

497 nodes = ensure_topological_order(inputs, inits, nodes) 

498 

499 if self.verbose: 

500 print( # pragma: no cover 

501 "[Tf2OnnxConvert.make_node] %d nodes %d inputs %d " 

502 "outputs %d initializers" 

503 "" % (len(nodes), len(inputs), len(outputs), len(inits))) 

504 graph = make_graph(nodes, self._onnx_model.graph.name, 

505 inputs, outputs, inits) 

506 onnx_model = make_model(graph) 

507 onnx_model.ir_version = self._onnx_model.ir_version 

508 onnx_model.producer_name = self._onnx_model.producer_name + "-mlprodict" 

509 onnx_model.producer_version = self._onnx_model.producer_version 

510 onnx_model.domain = self._onnx_model.domain 

511 onnx_model.model_version = self._onnx_model.model_version 

512 onnx_model.doc_string = self._onnx_model.doc_string 

513 metadata = {p.key: p.value for p in self._onnx_model.metadata_props} 

514 set_model_props(onnx_model, metadata) 

515 

516 # opsets 

517 del onnx_model.opset_import[:] # pylint: disable=E1101 

518 for dom, value in self.target_opsets.items(): 

519 op_set = onnx_model.opset_import.add() # pylint: disable=E1101 

520 op_set.domain = dom 

521 op_set.version = value 

522 return onnx_model 

523 

524 

525class GraphBuilder: 

526 """ 

527 Helpers to build graph. 

528 :param graph! 

529 """ 

530 

531 def __init__(self, graph): 

532 self._g = graph 

533 

534 @property 

535 def graph(self): 

536 "Returns the graph." 

537 return self._g 

538 

539 def make_slice(self, kwargs, name=None, shapes=None, dtypes=None, 

540 return_node=False): 

541 """ 

542 slice changes its schema at opset 10: it treats some 

543 attributes as dynamic input so this function has to process 

544 inputs according to graph's opset version 

545 to get "inputs" and "attr" to feed "make_node" 

546 kwargs: key could be `["data", "starts", "ends", 

547 "axes", "steps", "outputs"]`. 

548 """ 

549 outputs = kwargs.pop("outputs", None) 

550 

551 if self.graph.opset < 10: 

552 # "data" is string 

553 # "starts", "ends" and "axes" are attributes, 

554 # and "axes" is optional. 

555 data = kwargs.pop("data") # pragma: no cover 

556 starts = self._convert_to_attribute( # pragma: no cover 

557 kwargs.pop("starts")) 

558 ends = self._convert_to_attribute( # pragma: no cover 

559 kwargs.pop("ends")) 

560 axes = self._convert_to_attribute( # pragma: no cover 

561 kwargs.pop("axes", None), is_optional=True) 

562 attr = {"starts": starts, "ends": ends, 

563 "axes": axes} # pragma: no cover 

564 inputs = [data] # pragma: no cover 

565 else: 

566 # slice-10 has 3 required inputs "data", "starts", "ends"l 

567 # and 2 optional inputs "axes", "steps" 

568 # input sequence should be "data", "starts", "ends", 

569 # "axes", "steps" 

570 attr = {} 

571 data = kwargs.pop("data") 

572 starts = self._convert_to_input( 

573 kwargs.pop("starts"), "const_starts", dtype=numpy.int64) 

574 ends = self._convert_to_input( 

575 kwargs.pop("ends"), "const_ends", dtype=numpy.int64) 

576 axes = self._convert_to_input( 

577 kwargs.pop("axes", None), "const_axes", 

578 is_optional=True, dtype=numpy.int64) 

579 steps = self._convert_to_input( 

580 kwargs.pop("steps", None), "const_steps", 

581 is_optional=True, dtype=numpy.int64) 

582 inputs = [data, starts, ends, axes, steps] 

583 

584 # pro-process inputs and attr 

585 make_sure(not kwargs, "kwargs contains un-used key") 

586 

587 new_attr = {} 

588 for key, val in attr.items(): 

589 if val is not None: 

590 new_attr[key] = val 

591 attr = new_attr 

592 

593 for ind, val in enumerate(inputs): 

594 if val is None: 

595 inputs[ind] = "" # empty string means no connection in ONNX 

596 # remove tailing "" 

597 while inputs[-1] == "": 

598 inputs = inputs[:-1] 

599 

600 if self.graph.opset >= 10: 

601 dtype = self.graph.get_dtype(inputs[1]) 

602 for input_data in inputs[1:]: 

603 if input_data != "": 

604 make_sure(dtype == self.graph.get_dtype( 

605 input_data), "dtype should be same") 

606 

607 node = self.graph.make_node(op_type="Slice", inputs=inputs, attr=attr, 

608 name=name, outputs=outputs, shapes=shapes, 

609 dtypes=dtypes) 

610 if return_node: 

611 return node 

612 raise NotImplementedError( # pragma: no cover 

613 "return_node must be True") 

614 

615 def make_squeeze(self, kwargs, name=None, shapes=None, dtypes=None, 

616 return_node=False, op_name_scope=None): 

617 """ 

618 Squeeze changes its schema at opset 13: it treats axes as a dynamic input 

619 kwargs: key could be ["data", "axes"]. 

620 """ 

621 outputs = kwargs.pop("outputs", None) 

622 

623 if self.graph.opset < 13: 

624 data = kwargs.pop("data") 

625 axes = self._convert_to_attribute( 

626 kwargs.pop("axes", None), is_optional=True) 

627 attr = {"axes": axes} 

628 inputs = [data] 

629 else: 

630 data = kwargs.pop("data") 

631 axes = self._convert_to_input( 

632 kwargs.pop("axes", None), "const_axes", 

633 is_optional=True, dtype=numpy.int64) 

634 attr = {} 

635 inputs = [data, axes] 

636 

637 make_sure(not kwargs, "kwargs contains un-used key") 

638 

639 new_attr = {} 

640 for key, val in attr.items(): 

641 if val is not None: 

642 new_attr[key] = val 

643 attr = new_attr 

644 

645 for ind, val in enumerate(inputs): 

646 if val is None: 

647 inputs[ind] = "" # empty string means no connection in ONNX 

648 # remove tailing "" 

649 while inputs[-1] == "": 

650 inputs = inputs[:-1] 

651 

652 node = self.graph.make_node( 

653 op_type="Squeeze", inputs=inputs, attr=attr, name=name, 

654 outputs=outputs) 

655 if return_node: 

656 return node 

657 raise NotImplementedError( # pragma: no cover 

658 "return_node must be True") 

659 

660 def make_unsqueeze(self, kwargs, name=None, shapes=None, dtypes=None, 

661 return_node=False, op_name_scope=None): 

662 """ 

663 Unsqueeze changes its schema at opset 13: it treats axes as a dynamic input 

664 kwargs: key could be ["data", "axes"]. 

665 """ 

666 outputs = kwargs.pop("outputs", None) 

667 

668 if self.graph.opset < 13: 

669 data = kwargs.pop("data") # pragma: no cover 

670 axes = self._convert_to_attribute( # pragma: no cover 

671 kwargs.pop("axes", None), is_optional=True) 

672 attr = {"axes": axes} # pragma: no cover 

673 inputs = [data] # pragma: no cover 

674 else: 

675 data = kwargs.pop("data") 

676 axes = self._convert_to_input( 

677 kwargs.pop("axes", None), "const_axes", 

678 is_optional=True, dtype=numpy.int64) 

679 attr = {} 

680 inputs = [data, axes] 

681 

682 make_sure(not kwargs, "kwargs contains un-used key") 

683 

684 new_attr = {} 

685 for key, val in attr.items(): 

686 if val is not None: 

687 new_attr[key] = val 

688 attr = new_attr 

689 

690 for ind, val in enumerate(inputs): 

691 if val is None: 

692 inputs[ind] = "" # empty string means no connection in ONNX 

693 # remove tailing "" 

694 while inputs[-1] == "": 

695 inputs = inputs[:-1] 

696 

697 node = self.graph.make_node( 

698 op_type="Unsqueeze", inputs=inputs, attr=attr, name=name, 

699 outputs=outputs) 

700 if return_node: 

701 return node 

702 raise NotImplementedError( # pragma: no cover 

703 "return_node must be True") 

704 

705 def _convert_to_input(self, tensor, const_name, is_optional=False, dtype=None): 

706 """in ONNX, input shold come from node, so it must be a string""" 

707 if is_optional and tensor is None: 

708 return None 

709 

710 make_sure(tensor is not None, 

711 "input is required so it couldn't be None") 

712 

713 res = tensor 

714 if isinstance(tensor, list): 

715 res = self.graph.make_const( 

716 make_name(const_name), numpy.array(tensor, dtype)).name 

717 return res 

718 

719 def _convert_to_attribute(self, tensor, is_optional=False): 

720 if is_optional and tensor is None: 

721 return None 

722 

723 make_sure(tensor is not None, 

724 "input is required so it couldn't be None") 

725 

726 res = tensor 

727 if isinstance(tensor, str): 

728 const_node = self.graph.get_node_by_output(tensor) 

729 res = const_node.get_tensor_value(as_list=True) 

730 

731 make_sure(isinstance(res, list), 

732 "input is an attr, so a list is needed") 

733 

734 return res