Coverage for mlprodict/onnxrt/ops_cpu/op_broadcast_gradient_args.py: 98%
Shortcuts on this page
r, m, x: toggle display of run, missing, and excluded lines
j, k: next/previous highlighted chunk
0 (zero): top of page
1 (one): first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ..shape_object import ShapeObject
9from ._op import OpRun
10from ._new_ops import OperatorSchema
class BroadcastGradientArgs(OpRun):
    """
    Runtime for the custom operator *BroadcastGradientArgs*.

    Given the shapes of two inputs, it returns, for each input, the
    output axes along which that input is broadcast against the other.
    NOTE(review): presumably used to know which axes a gradient must be
    reduced over; confirm with the training graphs using it.
    """

    # The operator declares no attributes.
    atts = {}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       **options)

    def _find_custom_operator_schema(self, op_name):
        """
        Returns the schema of this custom operator.

        :param op_name: operator name
        :return: a @see cl BroadcastGradientArgsSchema instance
        :raises RuntimeError: for any other operator name
        """
        if op_name == "BroadcastGradientArgs":
            return BroadcastGradientArgsSchema()
        raise RuntimeError(  # pragma: no cover
            "Unable to find a schema for operator '{}'.".format(op_name))

    def _run(self, a_shape, b_shape):  # pylint: disable=W0221
        """
        Computes the broadcast axes of two shapes.

        :param a_shape: shape of the first input (sequence of integers)
        :param b_shape: shape of the second input (sequence of integers)
        :return: tuple of two int64 arrays ``(a_axes, b_axes)``, each
            listing in decreasing order the output axes along which the
            corresponding input is broadcast
        :raises RuntimeError: if two aligned dimensions differ and
            neither of them is 1
        """
        a_size = len(a_shape)
        b_size = len(b_shape)
        ndim = max(a_size, b_size)

        # Shapes are aligned on their trailing dimensions: i walks
        # a_shape, j walks b_shape, k walks the output axes, all from the
        # last dimension backwards.
        i = a_size - 1
        j = b_size - 1
        k = ndim - 1

        a_axes = []
        b_axes = []

        while i >= 0 and j >= 0:
            a_dim = a_shape[i]
            b_dim = b_shape[j]
            if a_dim != b_dim:
                if a_dim == 1:
                    a_axes.append(k)
                elif b_dim == 1:
                    b_axes.append(k)
                else:
                    raise RuntimeError(
                        "Broadcast is not possible between inputs of "
                        "shapes: %r and %r." % (a_shape, b_shape))
            i -= 1
            j -= 1
            k -= 1

        # Remaining leading axes exist only in the longer shape, so the
        # shorter input is broadcast along every one of them (k down to 0).
        leading = range(k, -1, -1)
        if i < 0:
            a_axes.extend(leading)
        else:
            b_axes.extend(leading)

        return (numpy.array(a_axes, dtype=numpy.int64),
                numpy.array(b_axes, dtype=numpy.int64))

    def _infer_shapes(self, a, b):  # pylint: disable=W0221,W0237
        """
        Both outputs are 1-D int64 arrays whose length depends on the
        values of the inputs, so no static shape is returned.
        """
        return (ShapeObject(None, dtype=numpy.int64),
                ShapeObject(None, dtype=numpy.int64))

    def _infer_types(self, a, b):  # pylint: disable=W0221,W0237
        """
        Both outputs are always int64 arrays (see `_run`), whatever the
        input types, consistent with `_infer_shapes`.
        """
        return (numpy.int64, numpy.int64)
class BroadcastGradientArgsSchema(OperatorSchema):
    """
    Schema describing @see cl BroadcastGradientArgs,
    a custom operator implemented in this package.
    """

    def __init__(self):
        super().__init__('BroadcastGradientArgs')
        # The schema shares the attribute mapping declared on the
        # runtime class instead of defining its own.
        self.attributes = BroadcastGradientArgs.atts