Coverage for mlprodict/onnx_tools/exports/skl2onnx_helper.py: 100%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

45 statements  

1""" 

2@file 

3@brief Helpers to run examples created with :epkg:`sklearn-onnx`. 

4""" 

5from onnx import helper, TensorProto 

6 

7 

def _copy_inout(inout, scope, new_name):
    """
    Creates a tensor value info with the same element type and shape
    as *inout* but named *new_name*.

    Symbolic dimensions (``dim_param``) are kept as strings, dimensions
    with no value become ``None``, and a missing shape (unknown rank)
    produces a value info without any shape, instead of silently turning
    them into ``0`` or into a scalar.

    :param inout: graph input or output (``ValueInfoProto``) holding a
        tensor type
    :param scope: scope, unused, kept for API compatibility
    :param new_name: name of the new value info
    :return: ``ValueInfoProto``
    """
    tensor_type = inout.type.tensor_type
    if not tensor_type.HasField('shape'):
        # Unknown rank: a value info without shape, not a scalar.
        shape = None
    else:
        shape = []
        for dim in tensor_type.shape.dim:
            kind = dim.WhichOneof('value')
            if kind == 'dim_value':
                shape.append(dim.dim_value)
            elif kind == 'dim_param':
                # Keeps the symbolic dimension (e.g. 'N') instead of 0.
                shape.append(dim.dim_param)
            else:
                shape.append(None)
    value_info = helper.make_tensor_value_info(
        new_name, tensor_type.elem_type, shape)
    return value_info

13 

14 

15def _clean_variable_name(name, scope): 

16 return scope.get_unique_variable_name(name) 

17 

18 

19def _clean_operator_name(name, scope): 

20 return scope.get_unique_operator_name(name) 

21 

22 

23def _clean_initializer_name(name, scope): 

24 return scope.get_unique_variable_name(name) 

25 

26 

def add_onnx_graph(scope, operator, container, onx):
    """
    Adds a whole ONNX graph to an existing one following
    :epkg:`skl2onnx` API assuming this ONNX graph implements
    an `operator <http://onnx.ai/sklearn-onnx/api_summary.html?
    highlight=operator#skl2onnx.common._topology.Operator>`_.

    Every variable and initializer of *onx* receives a new name unique
    in *scope* (variable namespace), every named node receives a new name
    unique in *scope* (operator namespace). The inputs of *operator* are
    connected to the graph inputs and the graph outputs to the outputs of
    *operator* with ``Identity`` nodes.

    :param scope: scope (to get unique names)
    :param operator: operator, its ``inputs`` and ``outputs`` must expose
        an ``onnx_name`` attribute
    :param container: container receiving the nodes (``nodes``), the
        initializers (``initializers``) and the opsets
        (``node_domain_version_pair_sets``)
    :param onx: ONNX model (exposes ``graph`` and ``opset_import``)
    """
    graph = onx.graph
    name_mapping = {}

    def _rename(name):
        # Returns the new name of a variable, reserving it only once so
        # that the producer and all consumers agree on the same name.
        if not name:
            # An empty name marks an omitted optional input/output,
            # it must stay empty instead of becoming a dangling variable.
            return name
        new_name = name_mapping.get(name)
        if new_name is None:
            new_name = _clean_variable_name(name, scope)
            name_mapping[name] = new_name
        return new_name

    # Connects the operator inputs to the graph inputs.
    # zip stops at the shortest sequence, as the original code did.
    for inp, to in zip(operator.inputs, graph.input):
        n = helper.make_node('Identity', [inp.onnx_name], [_rename(to.name)],
                             name=_clean_operator_name('Identity', scope))
        container.nodes.append(n)

    # Connects the graph outputs to the operator outputs.
    for inp, to in zip(graph.output, operator.outputs):
        n = helper.make_node('Identity', [_rename(inp.name)], [to.onnx_name],
                             name=_clean_operator_name('Identity', scope))
        container.nodes.append(n)

    for node in graph.node:
        # Node names live in the operator namespace; each node gets its
        # own name even if the original graph reused one.
        new_name = (_clean_operator_name(node.name, scope)
                    if node.name else None)
        n = helper.make_node(
            node.op_type,
            [_rename(o) for o in node.input],
            [_rename(o) for o in node.output],
            name=new_name,
            domain=node.domain if node.domain else None)
        # NOTE(review): attributes are copied verbatim, names used inside
        # subgraph attributes (If/Loop/Scan bodies) are not renamed.
        n.attribute.extend(node.attribute)  # pylint: disable=E1101
        container.nodes.append(n)

    for init in graph.initializer:
        # Initializers are tensors: they share the variable namespace.
        tensor = TensorProto()
        tensor.CopyFrom(init)
        tensor.name = _rename(init.name)
        container.initializers.append(tensor)

    # opset
    for oimp in onx.opset_import:
        container.node_domain_version_pair_sets.add(
            (oimp.domain, oimp.version))

89 (oimp.domain, oimp.version))