Coverage for mlprodict/npy/numpy_onnx_impl_body.py: 74%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Design to implement graph as parameter.
5.. versionadded:: 0.8
6"""
7import logging
8import numpy
9from .onnx_variable import OnnxVar
10from .xop import loadop
# Module-level logger shared by the classes below; it is registered
# under the name 'xop'.
logger = logging.getLogger('xop')
class AttributeGraph:
    """
    Class wrapping a function to make it simple as
    a parameter.

    :param fct: function taking the list of inputs defined
        as @see cl OnnxVar, the function returns an @see cl OnnxVar,
        it can also be a ``numpy.ndarray`` when no input is given,
        in that case the instance holds a constant
    :param inputs: list of input as @see cl OnnxVar

    .. versionadded:: 0.8
    """

    def __init__(self, fct, *inputs):
        logger.debug('AttributeGraph(%r, %d in)', type(fct), len(inputs))
        if isinstance(fct, numpy.ndarray) and not inputs:
            # An array without any input is stored as a constant,
            # there is no function to call in that case.
            self.cst = fct
            fct = None
        else:
            self.cst = None
        self.fct = fct
        self.inputs = inputs
        # Both attributes are filled by method to_algebra,
        # alg_inputs_ is defined here so that it always exists.
        self.alg_ = None
        self.alg_inputs_ = None

    def __repr__(self):
        "Returns the class name followed by ``(...)``."
        return "%s(...)" % self.__class__.__name__

    def _graph_guess_dtype(self, i, var):
        """
        Guesses the graph inputs.

        :param i: attribute index (integer)
        :param var: the input (@see cl OnnxVar)
        :return: a pair ``(input description, variable)`` where
            the input description is ``(input_name, dtype)`` and
            the variable is a new @see cl OnnxVar named *input_name*
        """
        dtype = var._guess_dtype(None)
        if dtype is None:
            # Default type when nothing can be guessed.
            dtype = numpy.float32

        input_name = 'graph_%d_%d' % (id(self), i)
        # The result is a pair because every consumer indexes it:
        # item 0 is used as an input of ``to_onnx``,
        # item 1 is the variable given to the wrapped function.
        # NOTE(review): ``(name, dtype)`` is assumed to be a valid
        # input description for ``to_onnx`` -- confirm.
        return (input_name, dtype), OnnxVar(input_name, dtype=dtype)

    def to_algebra(self, op_version=None):
        """
        Converts the variable into an operator.

        :param op_version: opset, given to the created operators
        :return: the operator, it is cached in ``self.alg_``
            and returned as is by the following calls
        :raises RuntimeError: if the wrapped function does not
            return an @see cl OnnxVar
        """
        if self.alg_ is not None:
            return self.alg_

        logger.debug('AttributeGraph.to_algebra(op_version=%r)',
                     op_version)
        if self.cst is not None:
            # Constant case: the array goes through an Identity node,
            # the graph has no input.
            OnnxIdentity = loadop('Identity')
            self.alg_ = OnnxIdentity(self.cst, op_version=op_version)
            self.alg_inputs_ = None
            logger.debug('AttributeGraph.to_algebra:end:1:%r', type(self.alg_))
            return self.alg_

        # One pair (input description, OnnxVar) per input.
        new_inputs = [self._graph_guess_dtype(i, inp)
                      for i, inp in enumerate(self.inputs)]
        self.alg_inputs_ = new_inputs
        graph_vars = [pair[1] for pair in new_inputs]
        var = self.fct(*graph_vars)
        if not isinstance(var, OnnxVar):
            raise RuntimeError(  # pragma: no cover
                "var is not from type OnnxVar but %r." % type(var))

        self.alg_ = var.to_algebra(op_version=op_version)
        logger.debug('AttributeGraph.to_algebra:end:2:%r', type(self.alg_))
        return self.alg_
class OnnxVarGraph(OnnxVar):
    """
    Specialization of @see cl OnnxVar for operators receiving
    graphs as attributes.

    :param inputs: variable name or object
    :param op: :epkg:`ONNX` operator
    :param select_output: if multiple output are returned by
        ONNX operator *op*, it takes only one specifed by this
        argument
    :param dtype: specifies the type of the variable
        held by this class (*op* is None) in that case
    :param kwargs: addition argument to give operator *op*,
        every value of ``self.onnx_op_kwargs`` which is an
        @see cl AttributeGraph is converted into a graph
        by method *to_algebra*

    .. versionadded:: 0.8
    """

    def __init__(self, *inputs, op=None, select_output=None,
                 dtype=None, **kwargs):
        # Nothing specific, everything is given to the base class.
        OnnxVar.__init__(
            self, *inputs, op=op, select_output=select_output,
            dtype=dtype, **kwargs)

    def to_algebra(self, op_version=None):
        """
        Converts the variable into an operator.
        Attributes of type @see cl AttributeGraph are first replaced
        by their :epkg:`ONNX` graph, the conversion is then done
        by the base class.

        :param op_version: opset
        :return: the operator, cached in ``self.alg_``
        """
        if self.alg_ is not None:
            return self.alg_

        logger.debug('OnnxVarGraph.to_algebra(op_version=%r)',
                     op_version)
        # Every AttributeGraph found among the operator arguments
        # becomes an ONNX graph.
        graph_attrs = {}
        self.alg_hidden_var_ = {}
        self.alg_hidden_var_inputs = {}
        for name, value in self.onnx_op_kwargs.items():
            if isinstance(value, AttributeGraph):
                sub_alg = value.to_algebra(op_version=op_version)
                # No input means the attribute holds a constant.
                sub_inputs = (
                    [] if value.alg_inputs_ is None
                    else [pair[0] for pair in value.alg_inputs_])
                model = sub_alg.to_onnx(sub_inputs, target_opset=op_version)
                graph_attrs[name] = model.graph
                key = id(value)
                self.alg_hidden_var_[key] = value
                self.alg_hidden_var_inputs[key] = sub_inputs

        # The original values are kept before being overwritten.
        self.onnx_op_kwargs_before = {
            name: self.onnx_op_kwargs[name] for name in graph_attrs}
        self.onnx_op_kwargs.update(graph_attrs)
        self.alg_ = OnnxVar.to_algebra(self, op_version=op_version)
        logger.debug('OnnxVarGraph.to_algebra:end:%r', type(self.alg_))
        return self.alg_
class if_then_else(AttributeGraph):
    """
    Overloads class @see cl AttributeGraph.
    It does not add any attribute or method, the constructor
    only calls the constructor of the base class.

    :param fct: see @see cl AttributeGraph
    :param inputs: see @see cl AttributeGraph
    """

    def __init__(self, fct, *inputs):
        # Same signature and same behaviour as the base class.
        AttributeGraph.__init__(self, fct, *inputs)