Coverage for mlprodict/onnxrt/onnx_shape_inference.py: 99%
1"""
2@file
3@brief Runtime to infer shapes.
5.. versionadded:: 0.9
6"""
7import numpy
8from onnx import FunctionProto, ModelProto
9from onnx.numpy_helper import to_array
10from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
11from .ops_shape.shape_result import ShapeResult
12from .ops_shape.shape_container import ShapeContainer
13from .ops_shape import shape_dispatch


class OnnxShapeInference:
    """
    Implements a micro runtime for ONNX graphs.
    It does not implement all the operator types.

    :param model_onnx: ONNX model

    Other attributes:

    * `known_shapes_`: shapes which can be inferred without any input
    * `cache_`: keeps track of the functions used to infer
      the shapes
    * `is_function`: tells if the graph is a function or a model

    .. runpython::
        :showcode:

        import pprint
        import numpy
        from mlprodict.onnxrt.onnx_shape_inference import OnnxShapeInference
        from mlprodict.npy.xop_variable import Variable
        from mlprodict.npy.xop import loadop

        opset = 15
        OnnxAdd = loadop('Add')
        dtype = numpy.float32

        cop = OnnxAdd('X', numpy.array(
            [[1]], dtype=dtype), op_version=opset)
        cop4 = OnnxAdd(cop, numpy.array([[2]], dtype=dtype),
                       output_names=['Y'])
        vari = Variable('X', numpy.float32, [None, 3])
        model_def = cop4.to_onnx([vari], run_shape=False)
        rt = OnnxShapeInference(model_def)
        out = rt.run()
        pprint.pprint(out.get())
    """

    def __init__(self, model_onnx):
        if not isinstance(model_onnx, (FunctionProto, ModelProto)):
            raise TypeError(  # pragma: no cover
                "model_onnx is not a FunctionProto or ModelProto "
                "but %r." % type(model_onnx))
        self.is_function = isinstance(model_onnx, FunctionProto)
        self.model_onnx = model_onnx
        self.cache_ = {}
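        # Pre-compute every shape that can be inferred from the graph alone
        # (initializers, declared inputs and outputs); run() later refines
        # this container when concrete inputs are given.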
        self.known_shapes_ = self._run_empty()

    @property
    def input_names(self):
        "Returns input names."
        if self.is_function:
            return list(self.model_onnx.input)
        return [i.name for i in self.model_onnx.graph.input]

    @property
    def output_names(self):
        "Returns output names."
        if self.is_function:
            return list(self.model_onnx.output)
        return [i.name for i in self.model_onnx.graph.output]

    def __repr__(self):
        "Usual"
        return "%s(...)" % self.__class__.__name__

    @staticmethod
    def _get_shape(obj, known_shapes=None, result_name=None):
        """Extracts a `(shape, dtype, sparse)` tuple from a ValueInfoProto,
        naming every unknown dimension through *known_shapes*."""
        if obj is None:
            return [], None, False
        dtype = TENSOR_TYPE_TO_NP_TYPE.get(
            obj.type.tensor_type.elem_type, None)
        shape = []
        for dimi, d in enumerate(obj.type.tensor_type.shape.dim):
            v = d.dim_value if d.dim_value > 0 else d.dim_param
            if v in ('', None):
                if known_shapes is None or result_name is None:
                    raise RuntimeError(  # pragma: no cover
                        "known_shapes must be specified if "
                        "a dimension is not.")
                v = known_shapes.get_new_name(v, result_name, dimi)
            shape.append(v)
        return shape, dtype, False
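
    # Illustration (hypothetical values): for an input declared as
    # Variable('X', numpy.float32, [None, 3]), _get_shape would return
    # something like (['d0', 3], numpy.float32, False), where 'd0' stands for
    # whatever name ShapeContainer.get_new_name assigns to the unknown first
    # dimension (the exact generated name is an implementation detail).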

    def _run_empty(self):
        """
        Computes the shape and type of every result.

        :return: all intermediate results and outputs as a dictionary
        """
        def get_obj(name, inputs):
            if self.is_function:
                return None
            if inputs:
                for o in self.model_onnx.graph.input:
                    if o.name == name:
                        return o
            else:
                for o in self.model_onnx.graph.output:
                    if o.name == name:
                        return o
            return None
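
        # Seed the container with everything known statically: initializers
        # first, then the declared inputs and outputs.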
        known_shapes = ShapeContainer()
        if not self.is_function:
            for init in self.model_onnx.graph.initializer:
                mat = to_array(init)
                known_shapes.update(init.name, ShapeResult(
                    init.name, mat.shape, mat.dtype, sparse=False))

        for name in self.input_names:
            if name in known_shapes:
                raise NotImplementedError(
                    "Optional inputs are not implemented yet. "
                    "(name=%r)" % name)
            shape, dtype, sparse = self._get_shape(
                get_obj(name, True), known_shapes, result_name=name)
            known_shapes.update(name, ShapeResult(
                name, shape, dtype, sparse=sparse))

        for name in self.output_names:
            if name in known_shapes:
                raise RuntimeError(  # pragma: no cover
                    "Output %r is already present. Use an Identity node."
                    "" % name)
            shape, dtype, sparse = self._get_shape(
                get_obj(name, False), known_shapes, result_name=name)
            if dtype is None:
                # The onnx graph was created with named outputs
                # but with no type or shape.
                continue
            known_shapes.update(name, ShapeResult(
                name, shape, dtype, sparse=sparse))

        nodes = (
            self.model_onnx.node if self.is_function
            else self.model_onnx.graph.node)
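        # Fixed-point iteration: keep dispatching shape functions over the
        # nodes until no entry of the container changes anymore.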
        cont = True
        while cont:
            cont = False
            for node in nodes:
                cont = cont or shape_dispatch(
                    self.cache_, known_shapes, node, rt_class=self.__class__)
        return known_shapes

    def run(self, inputs=None):
        """
        Runs shape and type inference given known inputs.

        :param inputs: inputs as a dictionary `{name: array}`
        :return: all results
        """
        known_shapes = self.known_shapes_.copy(deep=True)
        if inputs is None:
            known_shapes.resolve()
            return known_shapes

        cont = False
        for name, obj in inputs.items():
            shape, dtype, sparse = (
                obj.shape, obj.dtype, not isinstance(obj, numpy.ndarray))
            cont = cont or known_shapes.update(
                name, ShapeResult(name, shape, dtype, sparse=sparse))

        nodes = (
            self.model_onnx.node if self.is_function
            else self.model_onnx.graph.node)
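        # Propagate the refined input shapes through the graph until the
        # container reaches a fixed point, then resolve the constraints.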
        while cont:
            cont = False
            for node in nodes:
                updated = shape_dispatch(
                    self.cache_, known_shapes, node, rt_class=self.__class__)
                cont = cont or updated
        known_shapes.resolve()
        return known_shapes
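

# A minimal usage sketch, not part of the original module: it rebuilds the
# model from the class docstring above and assumes a dictionary of numpy
# arrays is an acceptable value for the *inputs* argument of run(), as the
# method's code suggests.
if __name__ == "__main__":  # pragma: no cover
    import pprint
    from mlprodict.npy.xop_variable import Variable
    from mlprodict.npy.xop import loadop

    OnnxAdd = loadop('Add')
    cop = OnnxAdd('X', numpy.array([[1]], dtype=numpy.float32), op_version=15)
    cop4 = OnnxAdd(cop, numpy.array([[2]], dtype=numpy.float32),
                   output_names=['Y'])
    model_def = cop4.to_onnx(
        [Variable('X', numpy.float32, [None, 3])], run_shape=False)

    rt = OnnxShapeInference(model_def)
    # Without inputs, the first dimension of 'X' and 'Y' stays symbolic.
    pprint.pprint(rt.run().get())
    # With a concrete input, that dimension should resolve to 4.
    inputs = {'X': numpy.zeros((4, 3), dtype=numpy.float32)}
    pprint.pprint(rt.run(inputs=inputs).get())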