Coverage for mlprodict/onnx_conv/operator_converters/conv_transfer_transformer.py: 96%
Shortcuts on this page
r, m, x: toggle line displays
j, k: next/previous highlighted chunk
0 (zero): top of page
1 (one): first highlighted chunk
1"""
2@file
3@brief Converters for models from :epkg:`mlinsights`.
4"""
5from sklearn.base import is_classifier
6from skl2onnx import get_model_alias
7from skl2onnx.common._registration import (
8 get_shape_calculator, _converter_pool, _shape_calculator_pool)
9from skl2onnx._parse import _parse_sklearn
10from skl2onnx.common._apply_operation import apply_identity
11from skl2onnx.common._topology import Scope, Variable # pylint: disable=E0611,E0001
12from skl2onnx._supported_operators import sklearn_operator_name_map
def parser_transfer_transformer(scope, model, inputs, custom_parsers=None):
    """
    Parser for :epkg:`TransferTransformer`.

    Declares a single output variable and one operator bound to *model*.
    The output is named ``'probabilities'`` when ``model.method`` is
    ``'predict_proba'`` and ``'variable'`` when it is ``'transform'``;
    its type is a fresh instance of the class of the input's type.

    :param scope: scope used to declare the output variable and operator
    :param model: a :epkg:`TransferTransformer` instance
    :param inputs: list of input variables, exactly one is expected
    :param custom_parsers: optional mapping ``{model: parser}``; when
        *model* is a key, parsing is delegated to that parser
    :return: the list of outputs of the declared operator (or whatever
        the custom parser returns)
    :raises RuntimeError: if *inputs* does not hold exactly one variable
    :raises NotImplementedError: if ``model.method`` is neither
        ``'predict_proba'`` nor ``'transform'``
    """
    if len(inputs) != 1:
        raise RuntimeError(  # pragma: no cover
            "Only one input (not %d) is allowed for model type %r."
            "" % (len(inputs), type(model)))
    if custom_parsers is not None and model in custom_parsers:
        return custom_parsers[model](
            scope, model, inputs, custom_parsers=custom_parsers)

    # The output name only depends on which method of the wrapped
    # estimator the transformer exposes.
    if model.method == 'predict_proba':
        name = 'probabilities'
    elif model.method == 'transform':
        name = 'variable'
    else:
        raise NotImplementedError(  # pragma: no cover
            "Unable to define the output for method='{}' and model='{}'."
            "".format(model.method, model.__class__.__name__))

    prob = scope.declare_local_variable(name, inputs[0].type.__class__())
    alias = get_model_alias(type(model))
    this_operator = scope.declare_local_operator(alias, model)
    this_operator.inputs = inputs
    this_operator.outputs.append(prob)
    return this_operator.outputs
def shape_calculator_transfer_transformer(operator):
    """
    Shape calculator for :epkg:`TransferTransformer`.

    Parses the wrapped estimator ``operator.raw_operator.estimator_`` in a
    temporary scope, runs the shape calculator registered for that
    estimator, then copies the type of the relevant estimator output onto
    ``operator.outputs[0]``: output 1 for ``predict_proba``, output 0 for
    ``transform``.

    :param operator: operator whose ``raw_operator`` is a
        :epkg:`TransferTransformer`
    :raises RuntimeError: if the operator does not have exactly one input
    :raises NotImplementedError: if ``raw_operator.method`` is neither
        ``'predict_proba'`` nor ``'transform'``
    """
    if len(operator.inputs) != 1:
        raise RuntimeError(  # pragma: no cover
            "Only one input (not %d) is allowed for model %r."
            "" % (len(operator.inputs), operator))
    op = operator.raw_operator

    # Resolve which estimator output to mirror before doing any parsing
    # work, so an unsupported method fails fast.
    if op.method == 'predict_proba':
        output_index = 1
    elif op.method == 'transform':
        output_index = 0
    else:
        raise NotImplementedError(  # pragma: no cover
            "Unable to define the output for method='{}' and model='{}'.".format(
                op.method, op.__class__.__name__))

    alias = get_model_alias(type(op.estimator_))
    calc = get_shape_calculator(alias)

    options = (None if not hasattr(operator.scope, 'options')
               else operator.scope.options)
    if is_classifier(op.estimator_):
        # Option 'zipmap' is disabled for the wrapped classifier.
        # NOTE(review): this replaces the options inherited from
        # operator.scope instead of merging with them (same as the
        # original implementation) -- confirm this is intended.
        options = {id(op.estimator_): {'zipmap': False}}
    registered_models = dict(
        conv=_converter_pool, shape=_shape_calculator_pool,
        aliases=sklearn_operator_name_map)
    scope = Scope('temp', options=options,
                  registered_models=registered_models)
    inputs = [
        Variable(v.onnx_name, v.onnx_name, type=v.type, scope=scope)
        for v in operator.inputs]
    res = _parse_sklearn(scope, op.estimator_, inputs)
    this_operator = res[0]._parent
    calc(this_operator)

    operator.outputs[0].type = this_operator.outputs[output_index].type
def convert_transfer_transformer(scope, operator, container):
    """
    Converter for :epkg:`TransferTransformer`.

    Converts the wrapped estimator ``operator.raw_operator.estimator_``
    inside *scope*, then links the relevant estimator output to
    ``operator.outputs[0]`` with an identity operator: output 1 for
    ``predict_proba``, output 0 for ``transform``. The options attached
    to the transformer are forwarded to the wrapped estimator, with
    ``zipmap`` forced to False when that estimator is a classifier.

    :param scope: conversion scope
    :param operator: operator whose ``raw_operator`` is a
        :epkg:`TransferTransformer`
    :param container: container receiving the converted nodes
    :raises NotImplementedError: if ``raw_operator.method`` is neither
        ``'predict_proba'`` nor ``'transform'``
    """
    op = operator.raw_operator

    # Validate the method first: an unsupported one must fail before any
    # option is registered or any estimator node is added.
    if op.method == 'predict_proba':
        index = 1
    elif op.method == 'transform':
        index = 0
    else:
        raise NotImplementedError(  # pragma: no cover
            "Unable to define the output for method='{}' and model='{}'."
            "".format(op.method, op.__class__.__name__))

    opts = scope.get_options(op)
    if opts is None:
        opts = {}
    if is_classifier(op.estimator_):
        opts['zipmap'] = False
    # Register the same options in both places so the nested conversion
    # of the estimator sees them.
    container.add_options(id(op.estimator_), opts)
    scope.add_options(id(op.estimator_), opts)

    outputs = _parse_sklearn(scope, op.estimator_, operator.inputs)

    apply_identity(scope, outputs[index].onnx_name,
                   operator.outputs[0].full_name, container,
                   operator_name=scope.get_unique_operator_name("IdentityTT"))