Coverage for mlprodict/npy/xop_auto.py: 94%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

266 statements  

1""" 

2@file 

3@brief Automates the generation of operators for the 

4documentation for the Xop API. 

5 

6.. versionadded:: 0.9 

7""" 

8import os 

9import textwrap 

10import importlib 

11import inspect 

12import onnx 

13import onnx.defs 

14from onnx.backend.test.case.base import _Exporter 

15from onnx.defs import OpSchema 

16 

17 

def _get_doc_template():
    """
    Builds the template used to render the documentation of operators.

    The function tries to import :epkg:`jinja2`. If the import fails,
    a minimal local class named ``Template`` is defined instead:
    it ignores the template text and its ``render`` method only
    outputs, for every schema, the operator name, an underline made
    of ``=`` and the operator documentation.

    :return: instance of ``Template`` built from the RST template
        defined in this function
    """
    try:
        from jinja2 import Template
    except ImportError:  # pragma no cover
        class Template:
            "Docstring template"

            def __init__(self, *args):
                # The template text is ignored by this fallback.
                pass

            def render(self, **context):
                "render"
                # Only key 'schemas' of the context is used.
                schemas = context['schemas']
                rows = []
                for sch in schemas:
                    doc = sch.doc or ''
                    name = sch.name
                    if name is None:
                        raise RuntimeError("An operator must have a name.")
                    rows.extend([name, "=" * len(name),
                                 "", doc, ""])
                return "\n".join(rows)

    # The names called inside the template (format_name_with_domain,
    # build_doc_url, text_wrap, ...) must be supplied to ``render``.
    # NOTE(review): the template tests ``len(sch.versions)`` but iterates
    # over ``sch.version[:-1]``, the two attribute names differ - confirm
    # which one is intended.
    return Template(textwrap.dedent("""
        {% for sch in schemas %}

        .. tag-diff-insert.

        .. _l-onnx-op{{sch.domain.lower().replace(".", "-")}}-{{sch.name.lower()}}-{{str(sch.since_version)}}:

        {{format_name_with_domain(sch)}}
        {{'=' * len(format_name_with_domain(sch))}}

        **Version**

        * **name**: `{{sch.name}} (GitHub) <{{build_doc_url(sch)}}{{sch.name}}>`_
        * **domain**: **{% if sch.domain == '' %}main{% else %}{{sch.domain}}{% endif %}**
        * **since_version**: **{{sch.since_version}}**
        * **function**: {{sch.has_function}}
        * **support_level**: {{sch.support_level}}
        * **shape inference**: {{sch.has_type_and_shape_inference_function}}

        {% if sch.support_level == OpSchema.SupportType.EXPERIMENTAL %}
        No versioning maintained for experimental ops.
        {% else %}
        This version of the operator has been {% if
        sch.deprecated %}deprecated{% else %}available{% endif %}
        **since version {{sch.since_version}}{% if
        sch.domain %} of domain {{sch.domain}}{% endif %}**.
        {% if len(sch.versions) > 1 %}
        Other versions of this operator:
        {% for v in sch.version[:-1] %} {{v}} {% endfor %}
        {% endif %}
        {% endif %}

        **Summary**

        {{process_documentation(sch.doc)}}

        {% if sch.attributes %}
        **Attributes**

        {% for _, attr in sorted(sch.attributes.items()) %}* **{{attr.name}}**{%
        if attr.required %} (required){% endif %}:
        {{text_wrap(attr.description, 2)}} {%
        if attr.default_value %}{{clean_default_value(attr.default_value)}}{%
        endif %}
        {% endfor %}
        {% endif %}

        {% if sch.inputs %}
        **Inputs**

        {% if sch.min_input != sch.max_input %}Between {{sch.min_input
        }} and {{sch.max_input}} inputs.
        {% endif %}
        {% for ii, inp in enumerate(sch.inputs) %}
        * **{{getname(inp, ii)}}**{{format_option(inp)}} - **{{inp.typeStr}}**:
        {{text_wrap(inp.description, 2)}}{% endfor %}
        {% endif %}

        {% if sch.outputs %}
        **Outputs**

        {% if sch.min_output != sch.max_output %}Between {{sch.min_output
        }} and {{sch.max_output}} outputs.
        {% endif %}
        {% for ii, out in enumerate(sch.outputs) %}
        * **{{getname(out, ii)}}**{{format_option(out)}} - **{{out.typeStr}}**:
        {{text_wrap(out.description, 2)}}{% endfor %}
        {% endif %}

        {% if sch.type_constraints %}
        **Type Constraints**

        {% for ii, type_constraint in enumerate(sch.type_constraints)
        %}* {{get_constraint(type_constraint, ii)}}:
        {{text_wrap(type_constraint.description, 2)}}
        {% endfor %}
        {% endif %}

        {% if get_onnx_example and is_last_schema(sch): %}
        **Examples**

        {% for example, code in get_onnx_example(sch.name).items(): %}
        **{{ example }}**

        ::

        {{ format_example(code) }}

        {% endfor %}
        {% endif %}

        {% endfor %}
    """))

134 

135 

# Template built once when the module is imported.
_template_operator = _get_doc_template()
# Cache holding the schemas, None until it is filled
# by the first call to _get_all_schemas_with_history.
__get_all_schemas_with_history = None

138 

139 

def _populate__get_all_schemas_with_history():
    """
    Indexes every schema returned by
    ``onnx.defs.get_all_schemas_with_history``.

    :return: nested dictionaries
        ``{domain: {operator name: {since_version: schema}}}``
    """
    by_domain = {}
    for sch in onnx.defs.get_all_schemas_with_history():
        by_name = by_domain.setdefault(sch.domain, {})
        by_version = by_name.setdefault(sch.name, {})
        by_version[sch.since_version] = sch
    return by_domain

152 

153 

def _get_all_schemas_with_history():
    """
    Returns the indexed schemas, the index is built during
    the first call and stored in a module variable for the next ones.

    :return: dictionary built by
        ``_populate__get_all_schemas_with_history``
    """
    global __get_all_schemas_with_history  # pylint: disable=W0603
    cached = __get_all_schemas_with_history
    if cached is not None:
        return cached
    cached = _populate__get_all_schemas_with_history()
    __get_all_schemas_with_history = cached
    return cached

159 

160 

def get_domain_list():
    """
    Returns the sorted list of distinct domains found among
    all the schemas known to :epkg:`onnx`.
    """
    domains = {sch.domain
               for sch in onnx.defs.get_all_schemas_with_history()}
    return sorted(domains)

167 

168 

def get_operator_schemas(op_name, version=None, domain=None):
    """
    Returns all schemas mapped to an operator name.

    :param op_name: name of the operator, None for all operators
    :param version: None for all versions, `'last'` for the most
        recent one, or a specific value of *since_version*
    :param domain: domain, None to look into every domain
    :return: list of schemas sorted by domain, name
        and decreasing *since_version*

    A specific version which does not exist for an operator is skipped.
    An unknown, explicitly given domain raises a *KeyError* unless
    *op_name* is given and *version* is `'last'`: this case is
    delegated to ``onnx.defs.get_schema``.
    """
    if version == 'last' and op_name is not None and domain is not None:
        return [onnx.defs.get_schema(op_name, domain=domain)]

    all_schemas = _get_all_schemas_with_history()
    if domain is None:
        # Only the domains defining the operator are kept.
        domains = [dom for dom, ops in all_schemas.items()
                   if op_name is None or op_name in ops]
    else:
        domains = [domain]

    # schemas
    sch = []
    for dom in domains:
        ops = all_schemas[dom]
        if op_name is None:
            selected = ops.items()
        elif op_name in ops:
            selected = [(op_name, ops[op_name])]
        else:
            selected = []
        for op, versions in selected:
            if version is None:
                sch.extend(versions.values())
            elif version == 'last':
                # 'last' is not a key of the dictionary of versions,
                # the most recent schema is requested from onnx
                # whether one operator or all of them are asked for.
                sch.append(onnx.defs.get_schema(op, domain=dom))
            elif version in versions:
                sch.append(versions[version])

    # sort, the key is unique for every schema
    return sorted(
        sch, key=lambda s: (s.domain, s.name, -s.since_version))

213 

214 

def _collapse_blank_lines(docs):
    """
    Removes trailing spaces, tabulations and carriage returns from
    every line and merges consecutive empty lines into a single one.

    :param docs: text
    :return: cleaned text
    """
    new_lines = ['']
    for line in docs.split('\n'):
        line = line.rstrip('\r\t ')
        if not line and not new_lines[-1]:
            continue
        new_lines.append(line)
    return '\n'.join(new_lines)


def get_rst_doc(op_name=None, domain=None, version='last', clean=True,
                diff=False, example=False):
    """
    Returns a documentation in RST format
    for all :class:`OnnxOperator`.

    :param op_name: operator name of None for all
    :param domain: domain
    :param version: version, None for all, `'last'` for the most recent one
    :param clean: clean empty lines
    :param diff: highlights differences between two versions
    :param example: add example to the documentation
    :return: string

    The function relies on module :epkg:`jinja2` or replaces it
    with a simple rendering if not present.
    """
    from ..onnx_tools.onnx2py_helper import _var_as_dict
    schemas = get_operator_schemas(op_name, domain=domain, version=version)

    def format_name_with_domain(sch):
        # The version is only part of the title
        # when several versions may be rendered.
        if version == 'last':
            if sch.domain:
                return '{} ({})'.format(sch.name, sch.domain)
            return sch.name
        if sch.domain:
            return '{} - {} ({})'.format(sch.name, sch.since_version, sch.domain)
        return '%s - %d' % (sch.name, sch.since_version)

    def format_option(obj):
        opts = []
        if OpSchema.FormalParameterOption.Optional == obj.option:
            opts.append('optional')
        elif OpSchema.FormalParameterOption.Variadic == obj.option:
            opts.append('variadic')
        # The label must appear when the parameter is *not* homogeneous,
        # a parameter without this attribute is left unlabelled.
        if not getattr(obj, 'isHomogeneous', True):
            opts.append('heterogeneous')
        if opts:
            return " (%s)" % ", ".join(opts)
        return ""

    def format_example(code):
        code = textwrap.indent(code, ' ')
        return code

    def get_constraint(const, ii):
        # The position is used when the constraint has no name.
        if const.type_param_str:
            name = const.type_param_str
        else:
            name = str(ii)
        name = "**%s** in (" % name
        if const.allowed_type_strs:
            text = ",\n ".join(sorted(const.allowed_type_strs))
            name += "\n " + text + "\n )"
        return name

    def getname(obj, i):
        # The position is used when the input or output has no name.
        name = obj.name
        if len(name) == 0:
            return str(i)
        return name

    def process_documentation(doc):
        # Converts the markdown / HTML markers of the operator
        # documentation into RST.
        if doc is None:
            doc = ''
        doc = textwrap.dedent(doc)
        # NOTE(review): build_doc_url below points to branch 'main'
        # while this url uses 'master' - confirm which one is intended.
        main_docs_url = "https://github.com/onnx/onnx/blob/master/"
        rep = {
            '[the doc](IR.md)': '`ONNX <{0}docs/IR.md>`_',
            '[the doc](Broadcasting.md)':
                '`Broadcasting in ONNX <{0}docs/Broadcasting.md>`_',
            '<dl>': '',
            '</dl>': '',
            '<dt>': '* ',
            '<dd>': ' ',
            '</dt>': '',
            '</dd>': '',
            '<tt>': '``',
            '</tt>': '``',
            '<br>': '\n',
        }
        for k, v in rep.items():
            doc = doc.replace(k, v.format(main_docs_url))
        # A line starting with ``` opens or closes a block of code:
        # it becomes '::' and the lines inside the block are indented.
        move = 0
        lines = []
        for line in doc.split('\n'):
            if line.startswith("```"):
                if move > 0:
                    move -= 4
                    lines.append("\n")
                else:
                    lines.append("::\n")
                    move += 4
            elif move > 0:
                lines.append(" " * move + line)
            else:
                lines.append(line)
        return "\n".join(lines)

    def build_doc_url(sch):
        doc_url = "https://github.com/onnx/onnx/blob/main/docs/Operators"
        if "ml" in sch.domain:
            doc_url += "-ml"
        doc_url += ".md"
        doc_url += "#"
        if sch.domain not in (None, '', 'ai.onnx'):
            doc_url += sch.domain + "."
        return doc_url

    def clean_default_value(value):
        dvar = _var_as_dict(value)
        if 'value' in dvar:
            v = dvar['value']
            if isinstance(v, bytes):
                return "Default value is ``'%s'``." % v.decode('ascii')
            return "Default value is ``{}``.".format(v)
        res = str(value).replace('\n', ' ').strip()
        if len(res) > 0:
            return "Default value is ``%s``." % res
        return ""

    def text_wrap(text, indent):
        s = ' ' * indent
        lines = textwrap.wrap(text, initial_indent=s, subsequent_indent=s)
        return '\n'.join(lines)

    docs = _template_operator.render(
        schemas=schemas, OpSchema=OpSchema,
        len=len, getattr=getattr, sorted=sorted,
        format_option=format_option,
        get_constraint=get_constraint,
        getname=getname, enumerate=enumerate,
        format_name_with_domain=format_name_with_domain,
        process_documentation=process_documentation,
        build_doc_url=build_doc_url, text_wrap=text_wrap,
        str=str, clean_default_value=clean_default_value,
        # The template skips the examples when this value is None.
        get_onnx_example=get_onnx_example if example else None,
        format_example=format_example,
        is_last_schema=is_last_schema)

    if diff:
        # Empty lines are merged before the differences are computed.
        docs = _collapse_blank_lines(docs)
        docs = _insert_diff(docs, '.. tag-diff-insert.')

    if clean:
        docs = _collapse_blank_lines(docs)

    return docs

382 

383 

384def _insert_diff(docs, split='.. tag-diff-insert.'): 

385 """ 

386 Splits a using `split`, insert HTML differences between pieces. 

387 The function relies on package :epkg:`pyquickhelper`. 

388 """ 

389 spl = docs.split(split) 

390 if len(spl) <= 1: 

391 return docs 

392 

393 from pyquickhelper.texthelper.edit_text_diff import ( 

394 edit_distance_text, diff2html) 

395 

396 pieces = [spl[0]] 

397 for i in range(1, len(spl)): 

398 spl1 = spl[i - 1].strip('\n ') 

399 spl2 = spl[i].strip('\n ') 

400 spl1 = spl1.split('**Examples**')[0].replace('`', '') 

401 spl2 = spl2.split('**Examples**')[0].replace('`', '') 

402 spl1 = spl1.split('**Summary**')[-1].strip('\n ') 

403 spl2 = spl2.split('**Summary**')[-1].strip('\n ') 

404 if len(spl1) < 5 or len(spl2) < 5: 

405 pieces.append(spl[i]) 

406 continue 

407 

408 _, aligned, final = edit_distance_text( # pylint: disable=W0632 

409 spl2, spl1, threshold=0.5) 

410 ht = diff2html(spl2, spl1, aligned, final, two_columns=True) 

411 ht = ht.replace(">``<", "><") 

412 ht = ' ' + '\n '.join(ht.split('\n')) 

413 pieces.extend(['', '**Differences**', '', '.. raw:: html', 

414 '', ht, '', spl[i]]) 

415 

416 return '\n'.join(pieces) 

417 

418 

def get_onnx_example(op_name):
    """
    Retrieves the examples associated to one operator,
    they are extracted from module
    ``onnx.backend.test.case.node.<operator name in lower case>``.

    :param op_name: operator name
    :return: dictionary `{ example name: code }`, the name is the
        name of a method ``export_*`` without its prefix,
        the dictionary is empty if the module cannot be imported

    An exception *RuntimeError* is raised if the code of a method
    ``export_*`` cannot be found in the source of its class.
    """
    module_name = 'onnx.backend.test.case.node.%s' % op_name.lower()
    try:
        mod = importlib.import_module(module_name)
    except ImportError:
        return {}

    examples = {}
    for obj in mod.__dict__.values():
        if not isinstance(obj, _Exporter):
            continue
        source = inspect.getsource(obj)
        chunks = source.split('@staticmethod')
        for attr in obj.__dict__:
            if not attr.startswith('export_'):
                continue
            needle = ' %s()' % attr
            # The last piece of code mentioning the method is kept.
            matches = [chunk for chunk in chunks if needle in chunk]
            if not matches:
                raise RuntimeError(
                    "Unable to find %r in\n%s" % (needle, source))
            body = textwrap.dedent(matches[-1]).split('\n')
            # The example starts after the last line beginning with 'def '.
            start = 0
            for pos, text in enumerate(body):
                if text.startswith('def '):
                    start = pos + 1
            code = textwrap.dedent('\n'.join(body[start:]))
            examples[attr[len('export_'):]] = code
    return examples

459 

460 

def is_last_schema(sch):
    """
    Tells if a schema is the one ``onnx.defs.get_schema`` returns
    for the same operator name and the same domain.

    :param sch: schema
    :return: True if both schemas share the same *since_version*
    """
    latest = onnx.defs.get_schema(sch.name, domain=sch.domain)
    latest_version = latest.since_version
    return sch.since_version == latest_version

470 

471 

def onnx_documentation_folder(folder, ops=None, title='ONNX operators',
                              fLOG=None):
    """
    Writes one RST page per operator in a folder, for all known
    ONNX operators or a subset, and an index page ``index.rst``.

    :param folder: folder where to write the documentation,
        it is created if it does not exist
    :param ops: None for all operators or a subset of them
    :param title: index title
    :param fLOG: logging function, called once per generated page
    :return: list of created files, the index is the last one
    """
    schemas_by_domain = _get_all_schemas_with_history()
    if not os.path.exists(folder):
        os.makedirs(folder)
    index = ['', title, '=' * len(title), '', '.. contents::', ' :local:',
             '']
    pages = []

    wanted = None if ops is None else set(ops)
    for dom in sorted(schemas_by_domain):
        dom_title = dom if dom != '' else 'main'
        index_dom = [dom_title, '+' * len(dom_title), '', '.. toctree::',
                     ' :maxdepth: 1', '']
        known = schemas_by_domain[dom]
        if wanted is None:
            selected = list(known)
        else:
            selected = list(set(known).intersection(wanted))
        if not selected:
            # A domain without any operator to document is not indexed.
            continue

        for op in sorted(selected):
            if fLOG is not None:
                fLOG('generate page for onnx %r - %r' % (dom, op))
            page_name = "onnx_%s_%s" % (dom.replace('.', ''), op)
            index_dom.append(' %s' % page_name)
            doc = get_rst_doc(op, domain=dom, version=None, example=True,
                              diff=True)
            main = op if dom == '' else '%s - %s' % (dom, op)
            underline = '=' * len(main)
            rows = ['', '.. _l-onnx-doc%s-%s:' % (dom, op), '',
                    underline, main, underline, '',
                    '.. contents::', ' :local:', '', doc]

            full = os.path.join(folder, page_name + '.rst')
            with open(full, 'w', encoding='utf-8') as f:
                f.write("\n".join(rows))
            pages.append(full)
        index.extend(index_dom)
        index.append('')

    index_path = os.path.join(folder, 'index.rst')
    with open(index_path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(index))
    pages.append(index_path)
    return pages