Coverage for mlprodict/onnxrt/doc/doc_helper.py: 96%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Documentation helper.
4"""
5import keyword
6import textwrap
7import re
8from onnx.defs import OpSchema
def type_mapping(name):
    """
    Mapping between types name and type integer value.

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        from mlprodict.onnxrt.doc.doc_helper import type_mapping
        import pprint
        pprint.pprint(type_mapping(None))
        print(type_mapping("INT"))
        print(type_mapping(2))

    :param name: None to get the whole mapping, a type name
        (such as ``"INT"``) to get its integer value, or an integer
        value (or any object convertible with ``int``) to get its name
    :return: the whole mapping (dict), an integer or a string
    :raises KeyError: if *name* is not a known type name or value
    """
    # Attribute type codes; the last three entries complete the table
    # so that every attribute type can be looked up.
    di = dict(FLOAT=1, FLOATS=6, GRAPH=5, GRAPHS=10, INT=2,
              INTS=7, STRING=3, STRINGS=8, TENSOR=4,
              TENSORS=9, UNDEFINED=0, SPARSE_TENSOR=11,
              SPARSE_TENSORS=12, TYPE_PROTO=13, TYPE_PROTOS=14)
    if name is None:
        return di
    if isinstance(name, str):
        return di[name]
    rev = {v: k for k, v in di.items()}
    if name in rev:
        return rev[name]
    # Enum-like objects (possibly OpSchema.AttrType depending on the
    # bindings) may not compare equal to plain integers: fall back on
    # their integer value.
    try:
        key = int(name)
    except (TypeError, ValueError):
        raise KeyError(name) from None
    return rev[key]
36def _get_doc_template():
38 from jinja2 import Template # delayed import
39 return Template(textwrap.dedent("""
40 {% for sch in schemas %}
42 {{format_name_with_domain(sch)}}
43 {{'=' * len(format_name_with_domain(sch))}}
45 {{process_documentation(sch.doc)}}
47 {% if sch.attributes %}
48 **Attributes**
50 {% for _, attr in sorted(sch.attributes.items()) %}* *{{attr.name}}*{%
51 if attr.required %} (required){% endif %}: {{
52 process_attribute_doc(attr.description)}} {%
53 if attr.default_value %} {{
54 process_default_value(attr.default_value)
55 }} ({{type_mapping(attr.type)}}){% endif %}
56 {% endfor %}
57 {% endif %}
59 {% if sch.inputs %}
60 **Inputs**
62 {% if sch.min_input != sch.max_input %}Between {{sch.min_input
63 }} and {{sch.max_input}} inputs.
64 {% endif %}
65 {% for ii, inp in enumerate(sch.inputs) %}
66 * *{{getname(inp, ii)}}*{{format_option(inp)}}{{inp.typeStr}}: {{
67 inp.description}}{% endfor %}
68 {% endif %}
70 {% if sch.outputs %}
71 **Outputs**
73 {% if sch.min_output != sch.max_output %}Between {{sch.min_output
74 }} and {{sch.max_output}} outputs.
75 {% endif %}
76 {% for ii, out in enumerate(sch.outputs) %}
77 * *{{getname(out, ii)}}*{{format_option(out)}}{{out.typeStr}}: {{
78 out.description}}{% endfor %}
79 {% endif %}
81 {% if sch.type_constraints %}
82 **Type Constraints**
84 {% for ii, type_constraint in enumerate(sch.type_constraints)
85 %}* {{getconstraint(type_constraint, ii)}}: {{
86 type_constraint.description}}
87 {% endfor %}
88 {% endif %}
90 **Version**
92 *Onnx name:* `{{sch.name}} <{{build_doc_url(sch)}}{{sch.name}}>`_
94 {% if sch.support_level == OpSchema.SupportType.EXPERIMENTAL %}
95 No versioning maintained for experimental ops.
96 {% else %}
97 This version of the operator has been {% if
98 sch.deprecated %}deprecated{% else %}available{% endif %} since
99 version {{sch.since_version}}{% if
100 sch.domain %} of domain {{sch.domain}}{% endif %}.
101 {% if len(sch.versions) > 1 %}
102 Other versions of this operator:
103 {% for v in sch.version[:-1] %} {{v}} {% endfor %}
104 {% endif %}
105 {% endif %}
107 **Runtime implementation:**
108 :class:`{{sch.name}}
109 <mlprodict.onnxrt.ops_cpu.op_{{change_style(sch.name)}}.{{sch.name}}>`
111 {% endfor %}
112 """))
# Template compiled once, when this module is imported; this means the
# jinja2 import inside _get_doc_template happens at import time too.
_template_operator = _get_doc_template()
class NewOperatorSchema:
    """
    Minimal schema describing an operator added in this package
    such as @see cl TreeEnsembleRegressorDouble, used when
    :epkg:`onnx` does not know the operator.

    :param name: operator name

    Only attributes ``name`` and ``domain`` are defined,
    the domain is always ``'mlprodict'``.
    """

    def __init__(self, name):
        self.name, self.domain = name, 'mlprodict'
def change_style(name):
    """
    Converts a *CamelCase* name (*AaBb*) into *snake_case* (*aa_bb*).
    A trailing underscore is appended when the result is a
    Python keyword.

    @param name     name to convert
    @return         converted name
    """
    # First split before every capitalized word, then between a
    # lowercase letter (or digit) and the uppercase letter following it.
    words_split = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
    snake = re.sub('([a-z0-9])([A-Z])', r'\1_\2', words_split).lower()
    if keyword.iskeyword(snake):
        return snake + "_"
    return snake
def get_rst_doc(op_name):
    """
    Returns a documentation in RST format
    for all :epkg:`OnnxOperator`.

    :param op_name: operator name or None for all operators
    :return: string

    The function relies on module :epkg:`jinja2`. An operator name
    the runtime does not know is documented with a minimal
    @see cl NewOperatorSchema.
    """
    from jinja2.runtime import Undefined
    from ..ops_cpu._op import _schemas
    if op_name is None:
        # Documents every operator known by the runtime.
        schemas = list(_schemas.values())
    elif op_name in _schemas:
        schemas = [_schemas[op_name]]
    else:
        schemas = [NewOperatorSchema(op_name)]

    def format_name_with_domain(sch):
        if sch.domain:
            return '{} ({})'.format(sch.name, sch.domain)
        return sch.name

    def format_option(obj):
        opts = []
        if OpSchema.FormalParameterOption.Optional == obj.option:
            opts.append('optional')
        elif OpSchema.FormalParameterOption.Variadic == obj.option:
            opts.append('variadic')
            # Only variadic parameters can mix types. Depending on the
            # onnx version the flag may be named is_homogeneous or
            # isHomogeneous; it is missing means homogeneous.
            homogeneous = getattr(obj, 'is_homogeneous', None)
            if homogeneous is None:
                homogeneous = getattr(obj, 'isHomogeneous', True)
            if not homogeneous:
                opts.append('heterogeneous')
        if opts:
            return " (%s)" % ", ".join(opts)
        return ""  # pragma: no cover

    def getconstraint(const, ii):
        if const.type_param_str:
            name = const.type_param_str
        else:
            name = str(ii)  # pragma: no cover
        if const.allowed_type_strs:
            name += " " + ", ".join(const.allowed_type_strs)
        return name

    def getname(obj, i):
        name = obj.name
        if len(name) == 0:
            return str(i)  # pragma: no cover
        return name

    def process_documentation(doc):
        if doc is None:
            doc = ''  # pragma: no cover
        if isinstance(doc, Undefined):
            doc = ''  # pragma: no cover
        if not isinstance(doc, str):
            raise TypeError(  # pragma: no cover
                "Unexpected type {} for {}".format(type(doc), doc))
        doc = textwrap.dedent(doc)
        main_docs_url = "https://github.com/onnx/onnx/blob/master/"
        # Fenced code blocks (```) are deliberately not replaced here:
        # the loop below turns them into RST literal blocks.
        rep = {
            '[the doc](IR.md)': '`ONNX <{0}docs/IR.md>`_',
            '[the doc](Broadcasting.md)':
                '`Broadcasting in ONNX <{0}docs/Broadcasting.md>`_',
            '<dl>': '',
            '</dl>': '',
            '<dt>': '* ',
            '<dd>': ' ',
            '</dt>': '',
            '</dd>': '',
            '<tt>': '``',
            '</tt>': '``',
            '<br>': '\n',
        }
        for k, v in rep.items():
            doc = doc.replace(k, v.format(main_docs_url))
        move = 0
        lines = []
        for line in doc.split('\n'):
            if line.startswith("```"):
                if move > 0:
                    # closing fence: end of the literal block
                    move -= 4
                    lines.append("\n")
                else:
                    # opening fence: start an indented literal block
                    lines.append("::\n")
                    move += 4
            elif move > 0:
                lines.append(" " * move + line)
            else:
                lines.append(line)
        return "\n".join(lines)

    def process_attribute_doc(doc):
        return doc.replace("<br>", " ")

    def build_doc_url(sch):
        # domain may be None, treated like the default domain
        domain = sch.domain or ''
        doc_url = "https://github.com/onnx/onnx/blob/master/docs/Operators"
        if "ml" in domain:
            doc_url += "-ml"
        doc_url += ".md"
        doc_url += "#"
        if domain not in ('', 'ai.onnx'):
            doc_url += domain + "."
        return doc_url

    def process_default_value(value):
        if value is None:
            return ''  # pragma: no cover
        res = []
        for c in str(value):
            if ((c >= 'A' and c <= 'Z') or (c >= 'a' and c <= 'z') or
                    (c >= '0' and c <= '9')):
                res.append(c)
                continue
            if c in '[]-+(),.?':
                res.append(c)
                continue
        if len(res) == 0:
            return "*default value cannot be automatically retrieved*"
        return "Default value is ``" + ''.join(res) + "``"

    docs = _template_operator.render(
        schemas=schemas, OpSchema=OpSchema,
        len=len, getattr=getattr, sorted=sorted,
        format_option=format_option,
        getconstraint=getconstraint,
        getname=getname, enumerate=enumerate,
        format_name_with_domain=format_name_with_domain,
        process_documentation=process_documentation,
        build_doc_url=build_doc_url, str=str,
        type_mapping=type_mapping,
        process_attribute_doc=process_attribute_doc,
        process_default_value=process_default_value,
        change_style=change_style)
    return docs.replace(" Default value is ````", "")
def debug_onnx_object(obj, depth=3):
    """
    ``__dict__`` is not in most of :epkg:`onnx` objects.
    This function uses function *dir* to explore this object.

    :param obj: object to explore
    :param depth: maximum recursion depth, nothing is returned
        once it reaches 0
    :return: a multi-line string, the first line is the type of *obj*,
        the following ones describe its attributes and, recursively,
        their content; None if *depth* <= 0
    """
    def iterable(o):
        try:
            iter(o)
            return True
        except TypeError:
            return False

    if depth <= 0:
        return None

    rows = [str(type(obj))]

    def add_nested(res):
        # Appends a nested description indented under the current row.
        if res is not None:
            rows.extend(" " + line for line in res.split("\n"))

    if not isinstance(obj, (int, str, float, bool)):

        for k in sorted(dir(obj)):
            try:
                val = getattr(obj, k)
                sval = str(val).replace("\n", " ")
            except (AttributeError, ValueError) as e:  # pragma: no cover
                sval = "ERRROR-" + str(e)
                val = None

            if 'method-wrapper' in sval or "built-in method" in sval:
                continue

            rows.append("- {}: {}".format(k, sval))
            if k.startswith('__') and k.endswith('__'):
                continue
            if val is None:
                continue

            if isinstance(val, dict):
                try:
                    sorted_list = sorted(val.items())
                except TypeError:  # pragma: no cover
                    sorted_list = list(val.items())
                for kk, vv in sorted_list:
                    rows.append(" - [%s]: %s" % (str(kk), str(vv)))
                    add_nested(debug_onnx_object(vv, depth - 1))
            elif isinstance(val, (str, bytes)):
                # Iterating a string yields characters and iterating bytes
                # yields integers: both are already fully shown in sval.
                pass
            elif iterable(val):
                if all(isinstance(o, (str, bytes)) and len(o) == 1
                       for o in val):
                    continue
                for i, vv in enumerate(val):
                    rows.append(" - [%d]: %s" % (i, str(vv)))
                    add_nested(debug_onnx_object(vv, depth - 1))
            elif not callable(val):
                add_nested(debug_onnx_object(val, depth - 1))

    return "\n".join(rows)
def visual_rst_template():
    """
    Returns the source of a :epkg:`jinja2` template (as a string)
    to display DOT graph for each converter from :epkg:`sklearn-onnx`.

    The template refers to the variables ``link``, ``title``, ``kind``,
    ``method``, ``output_index``, ``optim_param``, ``model``, ``table``,
    ``dot`` and to the helpers ``len`` and ``indent``, all of them must
    be supplied when the template is rendered.

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        from mlprodict.onnxrt.doc.doc_helper import visual_rst_template
        print(visual_rst_template())
    """
    # The string is not compiled here, presumably the caller builds the
    # jinja2 Template from it.
    # NOTE(review): RST requires the content of the literal block (``::``)
    # and of the ``.. gdot::`` directive to be indented; confirm the
    # ``indent`` helper given by the caller indents every line, including
    # the first one.
    return textwrap.dedent("""

        .. _l-{{link}}:

        {{ title }}
        {{ '=' * len(title) }}

        Fitted on a problem type *{{ kind }}*
        (see :func:`find_suitable_problem
        <mlprodict.onnxrt.validate.validate_problems.find_suitable_problem>`),
        method `{{ method }}` matches output {{ output_index }}.
        {{ optim_param }}

        ::

        {{ indent(model, " ") }}

        {{ table }}

        .. gdot::

        {{ indent(dot, " ") }}
        """)