Coverage for mlprodict/onnx_tools/optim/onnx_optimisation_unused.py: 100%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

49 statements  

1""" 

2@file 

3@brief Optimisation of :epkg:`ONNX` graphs. 

4""" 

5from onnx import FunctionProto 

6from onnx.helper import make_graph, make_function 

7from ._onnx_optimisation_common import ( # pylint: disable=E0611 

8 _apply_optimisation_on_graph, _apply_remove_node_fct_node) 

9 

10 

def _hidden_input_names(node):
    """
    Collects the names referenced inside the subgraphs attached to a node.

    A subgraph held in a node attribute may read results of the enclosing
    graph without listing them in ``node.input``. The returned set is a
    deliberate over-approximation: it also contains names produced inside
    the subgraphs, which is harmless because such names only keep extra
    nodes alive, they never remove one.

    @param      node        node whose attributes are inspected
    @return                 set of non-empty names read by the input or
                            declared as output of any nested subgraph
    """
    names = set()
    pending = [node]
    while pending:
        current = pending.pop()
        for att in current.attribute:
            subgraphs = list(att.graphs)
            if att.HasField('g'):
                subgraphs.append(att.g)
            for sub in subgraphs:
                names.update(out.name for out in sub.output)
                for inner in sub.node:
                    names.update(inner.input)
                    # The inner node may itself hold subgraphs.
                    pending.append(inner)
    # An empty string stands for an omitted optional input, not a result.
    names.discard('')
    return names


def onnx_remove_node_unused(onnx_model, recursive=True, debug_info=None, **options):
    """
    Removes unused nodes of the graph. An unused node
    is not involved in the output computation.

    @param      onnx_model      onnx model, graph or function
    @param      recursive       looks into subgraphs
    @param      debug_info      debug information (private)
    @param      options         unused, forwarded when the function is
                                applied to the graph of a model
    @return                     new onnx model, a new object of the same
                                kind as the input (model, graph, function)

    Nodes are identified by their position in the graph and not by their
    name: node names are optional and may be empty or duplicated.
    Initializers which are not needed by any remaining node or by an output
    are removed as well.
    """
    type_name = str(type(onnx_model)).rsplit(
        '.', maxsplit=1)[-1].strip("'>")
    if debug_info is None:
        debug_info = [type_name]
    else:
        debug_info = debug_info + [type_name]

    if hasattr(onnx_model, 'graph'):
        # A model: the optimisation is applied to its main graph.
        return _apply_optimisation_on_graph(
            onnx_remove_node_unused, onnx_model,
            recursive=recursive, debug_info=debug_info,
            **options)

    graph = onnx_model
    is_function = isinstance(graph, FunctionProto)

    # producers: result name -> positions of the nodes producing it
    # consumed: for every node position, the set of names the node reads
    producers = {}
    consumed = []
    for idx, node in enumerate(graph.node):
        # Empty names mark omitted optional inputs or outputs.
        read = {name for name in node.input if name}
        if node.attribute:
            read |= _hidden_input_names(node)
        consumed.append(read)
        for out in node.output:
            if out:
                producers.setdefault(out, []).append(idx)

    # The outputs of a function are plain names, the outputs of a graph
    # are objects exposing a name.
    needed = {out if is_function else out.name for out in graph.output}

    # Walks the graph backwards from the outputs: a node is kept as soon as
    # one of its outputs is needed, and then all its inputs become needed.
    kept = set()
    pending = list(needed)
    while pending:
        for idx in producers.get(pending.pop(), ()):
            if idx in kept:
                continue
            kept.add(idx)
            fresh = consumed[idx] - needed
            needed |= fresh
            pending.extend(fresh)

    new_nodes = [node for idx, node in enumerate(graph.node) if idx in kept]
    if not is_function:
        # A function has no initializer.
        new_inits = [init for init in graph.initializer
                     if init.name in needed]

    if recursive:
        # Handles subgraphs.
        for pos, node in enumerate(new_nodes):
            if node is None or not node.attribute:
                continue
            new_nodes[pos] = _apply_remove_node_fct_node(
                onnx_remove_node_unused,
                node, recursive=True, debug_info=debug_info + [node.name])

    # Finally create the new graph.
    nodes = [node for node in new_nodes if node is not None]
    if is_function:
        return make_function(
            onnx_model.domain, onnx_model.name,
            onnx_model.input, onnx_model.output, nodes,
            opset_imports=onnx_model.opset_import,
            attributes=onnx_model.attribute,
            doc_string=onnx_model.doc_string)
    graph = make_graph(nodes, onnx_model.name,
                       onnx_model.input, onnx_model.output,
                       new_inits)
    graph.value_info.extend(onnx_model.value_info)  # pylint: disable=E1101
    return graph