Coverage for mlprodict/npy/xop_auto.py: 94%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Automates the generation of operators for the
4documentation for the Xop API.
6.. versionadded:: 0.9
7"""
8import os
9import textwrap
10import importlib
11import inspect
12import onnx
13import onnx.defs
14from onnx.backend.test.case.base import _Exporter
15from onnx.defs import OpSchema
def _get_doc_template():
    """
    Builds the template used by :func:`get_rst_doc` to render
    operator schemas into RST.

    When :epkg:`jinja2` is installed, the returned object is a
    ``jinja2.Template`` built from the RST template below. Otherwise a
    minimal replacement class is used. Its ``render`` only writes, for
    every schema, the operator name underlined with ``=`` followed by
    the raw documentation of the schema.

    :return: object exposing a ``render(**context)`` method
    """
    try:
        from jinja2 import Template
    except ImportError:  # pragma no cover
        class Template:
            "Docstring template"

            def __init__(self, *args):
                # the template text is ignored by this fallback
                pass

            def render(self, **context):
                "render"
                schemas = context['schemas']
                rows = []
                for sch in schemas:
                    doc = sch.doc or ''
                    name = sch.name
                    if name is None:
                        raise RuntimeError("An operator must have a name.")
                    rows.extend([name, "=" * len(name),
                                 "", doc, ""])
                return "\n".join(rows)

    # The marker '.. tag-diff-insert.' is looked for by _insert_diff
    # when get_rst_doc is called with diff=True.
    # NOTE(review): onnx's OpSchema does not seem to expose a ``versions``
    # attribute, so ``len(sch.versions) > 1`` presumably never holds and
    # ``sch.version[:-1]`` (different spelling) is never evaluated;
    # confirm before relying on the "Other versions" paragraph.
    return Template(textwrap.dedent("""
        {% for sch in schemas %}

        .. tag-diff-insert.

        .. _l-onnx-op{{sch.domain.lower().replace(".", "-")}}-{{sch.name.lower()}}-{{str(sch.since_version)}}:

        {{format_name_with_domain(sch)}}
        {{'=' * len(format_name_with_domain(sch))}}

        **Version**

        * **name**: `{{sch.name}} (GitHub) <{{build_doc_url(sch)}}{{sch.name}}>`_
        * **domain**: **{% if sch.domain == '' %}main{% else %}{{sch.domain}}{% endif %}**
        * **since_version**: **{{sch.since_version}}**
        * **function**: {{sch.has_function}}
        * **support_level**: {{sch.support_level}}
        * **shape inference**: {{sch.has_type_and_shape_inference_function}}

        {% if sch.support_level == OpSchema.SupportType.EXPERIMENTAL %}
        No versioning maintained for experimental ops.
        {% else %}
        This version of the operator has been {% if
        sch.deprecated %}deprecated{% else %}available{% endif %}
        **since version {{sch.since_version}}{% if
        sch.domain %} of domain {{sch.domain}}{% endif %}**.
        {% if len(sch.versions) > 1 %}
        Other versions of this operator:
        {% for v in sch.version[:-1] %} {{v}} {% endfor %}
        {% endif %}
        {% endif %}

        **Summary**

        {{process_documentation(sch.doc)}}

        {% if sch.attributes %}
        **Attributes**

        {% for _, attr in sorted(sch.attributes.items()) %}* **{{attr.name}}**{%
        if attr.required %} (required){% endif %}:
        {{text_wrap(attr.description, 2)}} {%
        if attr.default_value %}{{clean_default_value(attr.default_value)}}{%
        endif %}
        {% endfor %}
        {% endif %}

        {% if sch.inputs %}
        **Inputs**

        {% if sch.min_input != sch.max_input %}Between {{sch.min_input
        }} and {{sch.max_input}} inputs.
        {% endif %}
        {% for ii, inp in enumerate(sch.inputs) %}
        * **{{getname(inp, ii)}}**{{format_option(inp)}} - **{{inp.typeStr}}**:
        {{text_wrap(inp.description, 2)}}{% endfor %}
        {% endif %}

        {% if sch.outputs %}
        **Outputs**

        {% if sch.min_output != sch.max_output %}Between {{sch.min_output
        }} and {{sch.max_output}} outputs.
        {% endif %}
        {% for ii, out in enumerate(sch.outputs) %}
        * **{{getname(out, ii)}}**{{format_option(out)}} - **{{out.typeStr}}**:
        {{text_wrap(out.description, 2)}}{% endfor %}
        {% endif %}

        {% if sch.type_constraints %}
        **Type Constraints**

        {% for ii, type_constraint in enumerate(sch.type_constraints)
        %}* {{get_constraint(type_constraint, ii)}}:
        {{text_wrap(type_constraint.description, 2)}}
        {% endfor %}
        {% endif %}

        {% if get_onnx_example and is_last_schema(sch): %}
        **Examples**

        {% for example, code in get_onnx_example(sch.name).items(): %}
        **{{ example }}**

        ::

        {{ format_example(code) }}

        {% endfor %}
        {% endif %}

        {% endfor %}
    """))
# Template built once at import time and reused by every get_rst_doc call.
_template_operator = _get_doc_template()
# Lazily filled cache of {domain: {op_name: {since_version: schema}}},
# populated on first use by _get_all_schemas_with_history().
__get_all_schemas_with_history = None
def _populate__get_all_schemas_with_history():
    """
    Indexes every schema returned by
    :func:`onnx.defs.get_all_schemas_with_history`.

    :return: nested dictionary
        ``{domain: {op_name: {since_version: schema}}}``
    """
    indexed = {}
    for sch in onnx.defs.get_all_schemas_with_history():
        by_name = indexed.setdefault(sch.domain, {})
        by_version = by_name.setdefault(sch.name, {})
        by_version[sch.since_version] = sch
    return indexed
def _get_all_schemas_with_history():
    """
    Returns the schema index built by
    :func:`_populate__get_all_schemas_with_history`.
    The index is computed on the first call, stored in a module-level
    variable and returned as is by later calls.
    """
    global __get_all_schemas_with_history  # pylint: disable=W0603
    if __get_all_schemas_with_history is not None:
        return __get_all_schemas_with_history
    __get_all_schemas_with_history = _populate__get_all_schemas_with_history()
    return __get_all_schemas_with_history
def get_domain_list():
    """
    Returns the list of available domains, sorted alphabetically.
    The main ONNX domain is the empty string.

    :return: list of domain names
    """
    domains = {schema.domain
               for schema in onnx.defs.get_all_schemas_with_history()}
    return sorted(domains)
def get_operator_schemas(op_name, version=None, domain=None):
    """
    Returns all schemas mapped to an operator name.

    :param op_name: name of the operator, None for all operators
    :param version: None for all versions, ``'last'`` for the most
        recent schema of every selected operator, or an integer to
        select the schema introduced exactly at that version
    :param domain: domain, None for all domains
    :return: list of schemas sorted by domain, name and decreasing
        version; an unknown operator, domain or version yields no schema
    """
    if version == 'last' and op_name is not None and domain is not None:
        # Fast path: onnx returns the most recent schema directly.
        return [onnx.defs.get_schema(op_name, domain=domain)]

    all_schemas = _get_all_schemas_with_history()
    if domain is None:
        domains = [dom for dom, ops in all_schemas.items()
                   if op_name is None or op_name in ops]
    else:
        domains = [domain]

    # schemas
    sch = []
    for dom in domains:
        # An unknown domain contributes nothing instead of raising KeyError.
        ops = all_schemas.get(dom, {})
        if op_name is None:
            selected = list(ops.items())
        elif op_name in ops:
            selected = [(op_name, ops[op_name])]
        else:
            continue
        for op, versions in selected:
            if version is None:
                sch.extend(versions.values())
            elif version == 'last':
                # Previously missed when op_name was given without a domain:
                # 'last' was looked up among integer versions.
                sch.append(onnx.defs.get_schema(op, domain=dom))
            elif version in versions:
                # Operators without this exact version are skipped.
                sch.append(versions[version])

    # Sort on an explicit key so that OpSchema objects are never compared.
    sch.sort(key=lambda s: (s.domain, s.name, -s.since_version))
    return sch
215def get_rst_doc(op_name=None, domain=None, version='last', clean=True,
216 diff=False, example=False):
217 """
218 Returns a documentation in RST format
219 for all :class:`OnnxOperator`.
221 :param op_name: operator name of None for all
222 :param domain: domain
223 :param version: version, None for all, `'last'` for the most recent one
224 :param clean: clean empty lines
225 :param diff: highlights differences between two versions
226 :param example: add example to the documentation
227 :return: string
229 The function relies on module :epkg:`jinja2` or replaces it
230 with a simple rendering if not present.
231 """
232 from ..onnx_tools.onnx2py_helper import _var_as_dict
233 schemas = get_operator_schemas(op_name, domain=domain, version=version)
235 # from onnx.backend.sample.ops import collect_sample_implementations
236 # from onnx.backend.test.case import collect_snippets
237 # SNIPPETS = collect_snippets()
238 # SAMPLE_IMPLEMENTATIONS = collect_sample_implementations()
239 def format_name_with_domain(sch):
240 if version == 'last':
241 if sch.domain:
242 return '{} ({})'.format(sch.name, sch.domain)
243 return sch.name
244 if sch.domain:
245 return '{} - {} ({})'.format(sch.name, sch.since_version, sch.domain)
246 return '%s - %d' % (sch.name, sch.since_version)
248 def format_option(obj):
249 opts = []
250 if OpSchema.FormalParameterOption.Optional == obj.option:
251 opts.append('optional')
252 elif OpSchema.FormalParameterOption.Variadic == obj.option:
253 opts.append('variadic')
254 if getattr(obj, 'isHomogeneous', False):
255 opts.append('heterogeneous')
256 if opts:
257 return " (%s)" % ", ".join(opts)
258 return ""
260 def format_example(code):
261 code = textwrap.indent(code, ' ')
262 return code
264 def get_constraint(const, ii):
265 if const.type_param_str:
266 name = const.type_param_str
267 else:
268 name = str(ii)
269 name = "**%s** in (" % name
270 if const.allowed_type_strs:
271 text = ",\n ".join(sorted(const.allowed_type_strs))
272 name += "\n " + text + "\n )"
273 return name
275 def getname(obj, i):
276 name = obj.name
277 if len(name) == 0:
278 return str(i)
279 return name
281 def process_documentation(doc):
282 if doc is None:
283 doc = ''
284 doc = textwrap.dedent(doc)
285 main_docs_url = "https://github.com/onnx/onnx/blob/master/"
286 rep = {
287 '[the doc](IR.md)': '`ONNX <{0}docs/IR.md>`_',
288 '[the doc](Broadcasting.md)':
289 '`Broadcasting in ONNX <{0}docs/Broadcasting.md>`_',
290 '<dl>': '',
291 '</dl>': '',
292 '<dt>': '* ',
293 '<dd>': ' ',
294 '</dt>': '',
295 '</dd>': '',
296 '<tt>': '``',
297 '</tt>': '``',
298 '<br>': '\n',
299 }
300 for k, v in rep.items():
301 doc = doc.replace(k, v.format(main_docs_url))
302 move = 0
303 lines = []
304 for line in doc.split('\n'):
305 if line.startswith("```"):
306 if move > 0:
307 move -= 4
308 lines.append("\n")
309 else:
310 lines.append("::\n")
311 move += 4
312 elif move > 0:
313 lines.append(" " * move + line)
314 else:
315 lines.append(line)
316 return "\n".join(lines)
318 def build_doc_url(sch):
319 doc_url = "https://github.com/onnx/onnx/blob/main/docs/Operators"
320 if "ml" in sch.domain:
321 doc_url += "-ml"
322 doc_url += ".md"
323 doc_url += "#"
324 if sch.domain not in (None, '', 'ai.onnx'):
325 doc_url += sch.domain + "."
326 return doc_url
328 def clean_default_value(value):
329 dvar = _var_as_dict(value)
330 if 'value' in dvar:
331 v = dvar['value']
332 if isinstance(v, bytes):
333 return "Default value is ``'%s'``." % v.decode('ascii')
334 return "Default value is ``{}``.".format(v)
335 else:
336 res = str(value).replace('\n', ' ').strip()
337 if len(res) > 0:
338 return "Default value is ``%s``." % res
339 return ""
341 def text_wrap(text, indent):
342 s = ' ' * indent
343 lines = textwrap.wrap(text, initial_indent=s, subsequent_indent=s)
344 return '\n'.join(lines)
346 fnwd = format_name_with_domain
347 tmpl = _template_operator
348 docs = tmpl.render(schemas=schemas, OpSchema=OpSchema,
349 len=len, getattr=getattr, sorted=sorted,
350 format_option=format_option,
351 get_constraint=get_constraint,
352 getname=getname, enumerate=enumerate,
353 format_name_with_domain=fnwd,
354 process_documentation=process_documentation,
355 build_doc_url=build_doc_url, text_wrap=text_wrap,
356 str=str, clean_default_value=clean_default_value,
357 get_onnx_example=get_onnx_example if example else None,
358 format_example=format_example,
359 is_last_schema=is_last_schema)
360 if diff:
361 lines = docs.split('\n')
362 new_lines = ['']
363 for line in lines:
364 line = line.rstrip('\r\t ')
365 if len(line) == 0 and len(new_lines[-1]) == 0:
366 continue
367 new_lines.append(line)
368 docs = '\n'.join(new_lines)
369 docs = _insert_diff(docs, '.. tag-diff-insert.')
371 if clean:
372 lines = docs.split('\n')
373 new_lines = ['']
374 for line in lines:
375 line = line.rstrip('\r\t ')
376 if len(line) == 0 and len(new_lines[-1]) == 0:
377 continue
378 new_lines.append(line)
379 docs = '\n'.join(new_lines)
381 return docs
384def _insert_diff(docs, split='.. tag-diff-insert.'):
385 """
386 Splits a using `split`, insert HTML differences between pieces.
387 The function relies on package :epkg:`pyquickhelper`.
388 """
389 spl = docs.split(split)
390 if len(spl) <= 1:
391 return docs
393 from pyquickhelper.texthelper.edit_text_diff import (
394 edit_distance_text, diff2html)
396 pieces = [spl[0]]
397 for i in range(1, len(spl)):
398 spl1 = spl[i - 1].strip('\n ')
399 spl2 = spl[i].strip('\n ')
400 spl1 = spl1.split('**Examples**')[0].replace('`', '')
401 spl2 = spl2.split('**Examples**')[0].replace('`', '')
402 spl1 = spl1.split('**Summary**')[-1].strip('\n ')
403 spl2 = spl2.split('**Summary**')[-1].strip('\n ')
404 if len(spl1) < 5 or len(spl2) < 5:
405 pieces.append(spl[i])
406 continue
408 _, aligned, final = edit_distance_text( # pylint: disable=W0632
409 spl2, spl1, threshold=0.5)
410 ht = diff2html(spl2, spl1, aligned, final, two_columns=True)
411 ht = ht.replace(">``<", "><")
412 ht = ' ' + '\n '.join(ht.split('\n'))
413 pieces.extend(['', '**Differences**', '', '.. raw:: html',
414 '', ht, '', spl[i]])
416 return '\n'.join(pieces)
def get_onnx_example(op_name):
    """
    Retrieves the examples associated to one operator and stored in
    the onnx package, in module
    ``onnx.backend.test.case.node.<op_name in lower case>``.

    :param op_name: operator name
    :return: dictionary ``{example name: source code}``, empty if the
        module cannot be imported
    """
    module_name = 'onnx.backend.test.case.node.%s' % op_name.lower()
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return {}

    examples = {}
    for obj in module.__dict__.values():
        if not isinstance(obj, _Exporter):
            continue
        source = inspect.getsource(obj)
        chunks = source.split('@staticmethod')
        for attr in obj.__dict__:
            if not attr.startswith('export_'):
                continue
            needle = ' %s()' % attr
            # the last chunk mentioning the method signature is kept
            matching = [chunk for chunk in chunks if needle in chunk]
            if not matching:
                raise RuntimeError(
                    "Unable to find %r in\n%s" % (needle, source))
            rows = textwrap.dedent(matching[-1]).split('\n')
            # keep only what follows the last line starting with 'def '
            start = 0
            for position, row in enumerate(rows):
                if row.startswith('def '):
                    start = position + 1
            examples[attr[len('export_'):]] = textwrap.dedent(
                '\n'.join(rows[start:]))
    return examples
def is_last_schema(sch):
    """
    Tells if *sch* has the same version as the most recent schema
    registered in onnx for its operator and domain.

    :param sch: schema
    :return: boolean
    """
    latest = onnx.defs.get_schema(sch.name, domain=sch.domain)
    return sch.since_version == latest.since_version
def onnx_documentation_folder(folder, ops=None, title='ONNX operators',
                              fLOG=None):
    """
    Creates documentation in a folder for all known
    ONNX operators or a subset.

    One page is written per operator, covering all its versions with
    examples and differences between versions, plus an ``index.rst``
    page with one toctree per domain.

    :param folder: folder where to write the documentation,
        created if it does not exist
    :param ops: None for all operators or a subset of them
    :param title: index title
    :param fLOG: logging function
    :return: list of created files
    """
    all_schemas = _get_all_schemas_with_history()
    # exist_ok avoids a race between checking and creating the folder
    os.makedirs(folder, exist_ok=True)
    index = ['', title, '=' * len(title), '', '.. contents::', ' :local:',
             '']
    pages = []

    if ops is not None:
        ops = set(ops)
    for dom in sorted(all_schemas):
        sub = all_schemas[dom]
        # operators of this domain to document, in alphabetical order
        do = sorted(sub if ops is None else set(sub) & ops)
        if not do:
            continue

        sdom = 'main' if dom == '' else dom
        index_dom = [sdom, '+' * len(sdom), '', '.. toctree::',
                     ' :maxdepth: 1', '']
        for op in do:
            if fLOG is not None:
                fLOG('generate page for onnx %r - %r' % (dom, op))
            page_name = "onnx_%s_%s" % (dom.replace('.', ''), op)
            index_dom.append(' %s' % page_name)
            doc = get_rst_doc(op, domain=dom, version=None, example=True,
                              diff=True)
            main = op if dom == '' else '%s - %s' % (dom, op)
            rows = ['', '.. _l-onnx-doc%s-%s:' % (dom, op), '',
                    '=' * len(main), main, '=' * len(main), '',
                    '.. contents::', ' :local:', '', doc]

            full = os.path.join(folder, page_name + '.rst')
            with open(full, 'w', encoding='utf-8') as f:
                f.write("\n".join(rows))
            pages.append(full)
        index.extend(index_dom)
        index.append('')

    page_name = os.path.join(folder, 'index.rst')
    with open(page_name, 'w', encoding='utf-8') as f:
        f.write('\n'.join(index))
    pages.append(page_name)
    return pages