Coverage for mlprodict/onnxrt/onnx_shape_inference.py: 99%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

90 statements  

1""" 

2@file 

3@brief Runtime to infer shapes. 

4 

5.. versionadded:: 0.9 

6""" 

7import numpy 

8from onnx import FunctionProto, ModelProto 

9from onnx.numpy_helper import to_array 

10from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE 

11from .ops_shape.shape_result import ShapeResult 

12from .ops_shape.shape_container import ShapeContainer 

13from .ops_shape import shape_dispatch 

14 

15 

class OnnxShapeInference:
    """
    Implements a micro runtime for ONNX graphs.
    It does not implement all the operator types.

    :param model_onnx: ONNX model (*ModelProto*) or ONNX function
        (*FunctionProto*)

    Other attributes:

    * `known_shapes_`: shapes which can be inferred without any input
    * `cache_`: keeps track of the function used to infer
      the shapes
    * `is_function`: tells if the graph is a function or a model

    .. runpython::
        :showcode:

        import pprint
        import numpy
        from mlprodict.onnxrt.onnx_shape_inference import OnnxShapeInference
        from mlprodict.npy.xop_variable import Variable
        from mlprodict.npy.xop import loadop

        opset = 15
        OnnxAdd = loadop('Add')
        dtype = numpy.float32

        cop = OnnxAdd('X', numpy.array(
            [[1]], dtype=dtype), op_version=opset)
        cop4 = OnnxAdd(cop, numpy.array([[2]], dtype=dtype),
                       output_names=['Y'])
        vari = Variable('X', numpy.float32, [None, 3])
        model_def = cop4.to_onnx([vari], run_shape=False)
        rt = OnnxShapeInference(model_def)
        out = rt.run()
        pprint.pprint(out.get())
    """

    def __init__(self, model_onnx):
        if not isinstance(model_onnx, (FunctionProto, ModelProto)):
            raise TypeError(  # pragma: no cover
                "model_onnx is not from FunctionProto or ModelProto "
                "%r." % type(model_onnx))
        self.is_function = isinstance(model_onnx, FunctionProto)
        self.model_onnx = model_onnx
        # Shared with every call to shape_dispatch.
        self.cache_ = {}
        # Computed once; `run` works on a deep copy of it.
        self.known_shapes_ = self._run_empty()

    @property
    def input_names(self):
        "Returns input names."
        if self.is_function:
            # A function stores its input names directly.
            return list(self.model_onnx.input)
        return [i.name for i in self.model_onnx.graph.input]

    @property
    def output_names(self):
        "Returns output names."
        if self.is_function:
            # A function stores its output names directly.
            return list(self.model_onnx.output)
        return [i.name for i in self.model_onnx.graph.output]

    @property
    def _graph_nodes(self):
        "Returns the nodes of the function or of the model graph."
        if self.is_function:
            return self.model_onnx.node
        return self.model_onnx.graph.node

    def __repr__(self):
        "Usual"
        return "%s(...)" % self.__class__.__name__

    @staticmethod
    def _get_shape(obj, known_shapes=None, result_name=None):
        """
        Extracts shape and type from an input or output description.

        :param obj: object exposing `type.tensor_type` (element type
            and dimensions) or None if nothing is known
        :param known_shapes: container asked for a new dimension name
            (`get_new_name`) when a dimension has neither a positive
            value nor a name
        :param result_name: name of the result, used to build
            the new dimension name
        :return: tuple `(shape, dtype, sparse)`, *sparse* is always
            False here, *dtype* is None if the element type is unknown
        :raises RuntimeError: if a dimension is unspecified and
            *known_shapes* or *result_name* is missing
        """
        if obj is None:
            return [], None, False
        tensor_type = obj.type.tensor_type
        dtype = TENSOR_TYPE_TO_NP_TYPE.get(tensor_type.elem_type, None)
        shape = []
        for dimi, d in enumerate(tensor_type.shape.dim):
            # A positive value wins, otherwise the symbolic name is used.
            v = d.dim_value if d.dim_value > 0 else d.dim_param
            if v in ('', None):
                if known_shapes is None or result_name is None:
                    raise RuntimeError(  # pragma: no cover
                        "known_shapes must be specified if "
                        "a dimension is not.")
                v = known_shapes.get_new_name(v, result_name, dimi)
            shape.append(v)
        return shape, dtype, False

    def _get_value_info(self, name, inputs):
        """
        Looks for the description of an input or an output.

        :param name: result name
        :param inputs: True to search among the graph inputs,
            False to search among the graph outputs
        :return: the matching object or None if the graph is a function
            or if no input/output has that name
        """
        if self.is_function:
            return None
        graph = self.model_onnx.graph
        candidates = graph.input if inputs else graph.output
        return next((o for o in candidates if o.name == name), None)

    def _propagate(self, known_shapes):
        """
        Calls `shape_dispatch` on every node, pass after pass,
        until a complete pass reports no update.

        :param known_shapes: container updated in place
        """
        nodes = self._graph_nodes
        updated = True
        while updated:
            updated = False
            for node in nodes:
                # The call must not be hidden behind a short-circuiting
                # ``or``: every node is visited during every pass.
                if shape_dispatch(self.cache_, known_shapes, node,
                                  rt_class=self.__class__):
                    updated = True

    def _run_empty(self):
        """
        Computes shape and types of all results.

        :return: all intermediates results and output as a dictionary
        """
        known_shapes = ShapeContainer()
        if not self.is_function:
            # Initializers are constants: shape and type are fully known.
            for init in self.model_onnx.graph.initializer:
                mat = to_array(init)
                known_shapes.update(init.name, ShapeResult(
                    init.name, mat.shape, mat.dtype, sparse=False))

        for name in self.input_names:
            if name in known_shapes:
                # An input sharing its name with an initializer.
                raise NotImplementedError(
                    "Optional inputs are not implemented yet. "
                    "(name=%r)" % name)
            shape, dtype, sparse = self._get_shape(
                self._get_value_info(name, True), known_shapes,
                result_name=name)
            known_shapes.update(name, ShapeResult(
                name, shape, dtype, sparse=sparse))

        for name in self.output_names:
            if name in known_shapes:
                raise RuntimeError(  # pragma: no cover
                    "Output %r is already present. Use Identity node."
                    "" % name)
            shape, dtype, sparse = self._get_shape(
                self._get_value_info(name, False), known_shapes,
                result_name=name)
            if dtype is None:
                # The onnx graph was created with named outputs
                # but with no type or shape.
                continue
            known_shapes.update(name, ShapeResult(
                name, shape, dtype, sparse=sparse))

        self._propagate(known_shapes)
        return known_shapes

    def run(self, inputs=None):
        """
        Runs shape inference and type given known inputs.

        :param inputs: None or a dictionary `{name: value}`, every value
            must expose `shape` and `dtype`, a value which is not
            a :class:`numpy.ndarray` is registered as sparse
        :return: all results, a deep copy of `known_shapes_` updated
            with the inputs, after a call to its method `resolve`
        """
        known_shapes = self.known_shapes_.copy(deep=True)
        if inputs is None:
            known_shapes.resolve()
            return known_shapes

        changed = False
        for name, obj in inputs.items():
            sparse = not isinstance(obj, numpy.ndarray)
            # Every input must be registered: the update is evaluated
            # first so that an earlier change never skips it.
            if known_shapes.update(
                    name, ShapeResult(
                        name, obj.shape, obj.dtype, sparse=sparse)):
                changed = True

        if changed:
            # Only propagate through the nodes if an input changed something.
            self._propagate(known_shapes)
        known_shapes.resolve()
        return known_shapes