Coverage for mlprodict/sklapi/onnx_pipeline.py: 95%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief A pipeline which serializes into ONNX steps by steps.
4"""
5import numpy
6from sklearn.base import clone
7from sklearn.pipeline import Pipeline, _fit_transform_one
8from sklearn.utils.validation import check_memory
9from sklearn.utils import _print_elapsed_time
10from ..onnx_conv import to_onnx
11from .onnx_transformer import OnnxTransformer
class OnnxPipeline(Pipeline):
    """
    The pipeline overwrites method *fit*, it trains and converts
    every step into ONNX before training the next step
    in order to minimize discrepancies. By default,
    ONNX is using float and not double which is the default
    for :epkg:`scikit-learn`. It may introduce discrepancies
    when a non-continuous model (mathematical definition) such
    as a tree ensemble is part of the pipeline.

    :param steps:
        List of (name, transform) tuples (implementing fit/transform) that are
        chained, in the order in which they are chained, with the last object
        an estimator.
    :param memory: str or object with the joblib.Memory interface, default=None
        Used to cache the fitted transformers of the pipeline. By default,
        no caching is performed. If a string is given, it is the path to
        the caching directory. Enabling caching triggers a clone of
        the transformers before fitting. Therefore, the transformer
        instance given to the pipeline cannot be inspected
        directly. Use the attribute ``named_steps`` or ``steps`` to
        inspect estimators within the pipeline. Caching the
        transformers is advantageous when fitting is time consuming.
    :param verbose: bool, default=False
        If True, the time elapsed while fitting each step will be printed as it
        is completed.
    :param output_name: string
        requested output name or None to request all and
        have method *transform* to store all of them in a dataframe
    :param enforce_float32: boolean
        :epkg:`onnxruntime` only supports *float32*,
        :epkg:`scikit-learn` usually uses double floats, this parameter
        ensures that every array of double floats is converted into
        single floats
    :param runtime: string, defined the runtime to use
        as described in @see cl OnnxInference.
    :param options: see @see fn to_onnx
    :param white_op: see @see fn to_onnx
    :param black_op: see @see fn to_onnx
    :param final_types: see @see fn to_onnx
    :param op_version: ONNX targeted opset

    The class stores transformers before converting them into ONNX
    in attributes ``raw_steps_``.

    See notebook :ref:`onnxdiscrepenciesrst` to see it can
    be used to reduce discrepancies after it was converted into
    *ONNX*.
    """

    def __init__(self, steps, *, memory=None, verbose=False,
                 output_name=None, enforce_float32=True,
                 runtime='python', options=None,
                 white_op=None, black_op=None, final_types=None,
                 op_version=None):
        self.output_name = output_name
        self.enforce_float32 = enforce_float32
        self.runtime = runtime
        self.options = options
        # fix: white_op was assigned twice in the original code,
        # the redundant duplicate assignment was removed
        self.white_op = white_op
        self.black_op = black_op
        self.final_types = final_types
        self.op_version = op_version
        # The base constructor calls _validate_steps and it checks the
        # value of black_op, hence the attributes above must be set first.
        Pipeline.__init__(
            self, steps, memory=memory, verbose=verbose)

    def fit(self, X, y=None, **fit_params):
        """
        Fits the model, fits all the transforms one after the
        other and transform the data, then fit the transformed
        data using the final estimator.

        :param X: iterable
            Training data. Must fulfill input requirements of first step of the
            pipeline.
        :param y: iterable, default=None
            Training targets. Must fulfill label requirements for all steps of
            the pipeline.
        :param fit_params: dict of string -> object
            Parameters passed to the ``fit`` method of each step, where
            each parameter name is prefixed such that parameter ``p`` for step
            ``s`` has key ``s__p``.
        :return: self, Pipeline, this estimator
        """
        fit_params_steps = self._check_fit_params(**fit_params)
        # _fit trains every intermediate step and replaces it by its
        # ONNX counterpart before fitting the final estimator.
        Xt = self._fit(X, y, **fit_params_steps)
        with _print_elapsed_time('OnnxPipeline',
                                 self._log_message(len(self.steps) - 1)):
            if self._final_estimator != 'passthrough':
                fit_params_last_step = fit_params_steps[self.steps[-1][0]]
                self._final_estimator.fit(Xt, y, **fit_params_last_step)

        return self

    def _fit(self, X, y=None, **fit_params_steps):
        # Fits every transformer (except the final estimator) and replaces
        # each fitted step in ``self.steps`` by its ONNX equivalent, keeping
        # the raw fitted transformers in ``self.raw_steps_``.
        # shallow copy of steps - this should really be steps_
        if hasattr(self, 'raw_steps_') and self.raw_steps_ is not None:  # pylint: disable=E0203
            # Let's reuse the previous training.
            self.steps = list(self.raw_steps_)  # pylint: disable=E0203
            self.raw_steps_ = list(self.raw_steps_)
        else:
            self.steps = list(self.steps)
            self.raw_steps_ = list(self.steps)

        self._validate_steps()
        # Setup the memory
        memory = check_memory(self.memory)

        fit_transform_one_cached = memory.cache(_fit_transform_one)

        for (step_idx,
             name,
             transformer) in self._iter(with_final=False,
                                        filter_passthrough=False):
            if (transformer is None or transformer == 'passthrough'):
                with _print_elapsed_time('Pipeline',
                                         self._log_message(step_idx)):
                    continue

            if hasattr(memory, 'location'):
                # joblib >= 0.12
                if memory.location is None:
                    # we do not clone when caching is disabled to
                    # preserve backward compatibility
                    cloned_transformer = transformer
                else:
                    cloned_transformer = clone(transformer)
            else:
                cloned_transformer = clone(transformer)

            # Fit or load from cache the current transformer.
            # x_train keeps the step input, needed to convert the
            # fitted transformer into ONNX below.
            x_train = X
            X, fitted_transformer = fit_transform_one_cached(
                cloned_transformer, X, y, None,
                message_clsname='Pipeline',
                message=self._log_message(step_idx),
                **fit_params_steps[name])
            # Replace the transformer of the step with the fitted
            # transformer. This is necessary when loading the transformer
            # from the cache.
            self.raw_steps_[step_idx] = (name, fitted_transformer)
            self.steps[step_idx] = (
                name, self._to_onnx(name, fitted_transformer, x_train))
        return X

    def _to_onnx(self, name, fitted_transformer, x_train, rewrite_ops=True,
                 verbose=0):
        """
        Converts a transformer into ONNX.

        :param name: model name
        :param fitted_transformer: fitted transformer
        :param x_train: training dataset
        :param rewrite_ops: use rewritten converters
        :param verbose: display some information
        :return: corresponding @see cl OnnxTransformer
        :raises RuntimeError: if *x_train* is not a :epkg:`numpy` array
            or if the converted graph has more than one output
        """
        if not isinstance(x_train, numpy.ndarray):
            raise RuntimeError(  # pragma: no cover
                "The pipeline only handle numpy arrays not {}.".format(
                    type(x_train)))
        atts = {'options', 'white_op', 'black_op', 'final_types'}
        kwargs = {k: getattr(self, k) for k in atts}
        # NOTE(review): this also casts non-float64 arrays (e.g. integers)
        # to float32 even when enforce_float32 is False — presumably
        # intended so the ONNX graph always works on floats; confirm.
        if self.enforce_float32 or x_train.dtype != numpy.float64:
            x_train = x_train.astype(numpy.float32)
        # 'options' always belongs to atts, the former
        # "if 'options' in kwargs" guard was always true.
        kwargs['options'] = self._preprocess_options(
            name, kwargs['options'])
        kwargs['target_opset'] = self.op_version
        onx = to_onnx(fitted_transformer, x_train,
                      rewrite_ops=rewrite_ops, verbose=verbose,
                      **kwargs)
        if len(onx.graph.output) != 1:
            raise RuntimeError(  # pragma: no cover
                "Only one output is allowed in the ONNX graph not %d. "
                "Model=%r" % (len(onx.graph.output), fitted_transformer))
        tr = OnnxTransformer(
            onx.SerializeToString(), output_name=self.output_name,
            enforce_float32=self.enforce_float32, runtime=self.runtime)
        return tr.fit()

    def _preprocess_options(self, name, options):
        """
        Preprocesses the options: removes prefix ``<name>__`` from every
        string key so that the options addressed to this step are
        forwarded to the converter without the step prefix.

        @param      name        step name
        @param      options     conversion options
        @return                 new options (None if *options* is None)
        """
        if options is None:
            return None
        prefix = name + '__'
        new_options = {}
        for k, v in options.items():
            # only string keys may carry the step prefix,
            # any other key is kept unchanged
            if isinstance(k, str) and k.startswith(prefix):
                new_options[k[len(prefix):]] = v
            else:
                new_options[k] = v
        return new_options