Coverage for mlprodict/onnxrt/ops_cpu/op_broadcast_gradient_args.py: 98%


53 statements  

# -*- encoding: utf-8 -*-
# pylint: disable=E0203,E1101,C0111
"""
@file
@brief Runtime operator.
"""
import numpy
from ..shape_object import ShapeObject
from ._op import OpRun
from ._new_ops import OperatorSchema


class BroadcastGradientArgs(OpRun):

    atts = {}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       **options)

    def _find_custom_operator_schema(self, op_name):
        if op_name == "BroadcastGradientArgs":
            return BroadcastGradientArgsSchema()
        raise RuntimeError(  # pragma: no cover
            "Unable to find a schema for operator '{}'.".format(op_name))

    def _run(self, a_shape, b_shape):  # pylint: disable=W0221
        # Returns, for each input, the axes of the broadcast output along
        # which a gradient must be reduced to recover that input's shape.
        A_dims = a_shape
        B_dims = b_shape
        a_size = len(a_shape)
        b_size = len(b_shape)

        ndim = max(a_size, b_size)

        i = a_size - 1
        j = b_size - 1
        k = ndim - 1

        a_axes = []
        b_axes = []

        # Walk the trailing dimensions of both shapes together; where they
        # differ, the dimension equal to 1 marks a broadcast axis.
        while i >= 0 and j >= 0:
            A_dim = A_dims[i]
            B_dim = B_dims[j]

            if A_dim != B_dim:
                if A_dim == 1:
                    a_axes.append(k)
                elif B_dim == 1:
                    b_axes.append(k)
                else:
                    a = A_dims[:a_size]
                    b = B_dims[:b_size]
                    raise RuntimeError(
                        "Broadcast is not possible between inputs of "
                        "shapes: %r and %r." % (a, b))
            i -= 1
            j -= 1
            k -= 1

        # Leading axes missing from the shorter shape are broadcast axes
        # for that input as well.
        if i < 0:
            while k >= 0:
                a_axes.append(k)
                k -= 1
        else:
            while k >= 0:
                b_axes.append(k)
                k -= 1

        return (numpy.array(a_axes, dtype=numpy.int64),
                numpy.array(b_axes, dtype=numpy.int64))

    def _infer_shapes(self, a, b):  # pylint: disable=W0221,W0237
        return (ShapeObject(None, dtype=numpy.int64),
                ShapeObject(None, dtype=numpy.int64))

    def _infer_types(self, a, b):  # pylint: disable=W0221,W0237
        return (a.dtype, b.dtype)


class BroadcastGradientArgsSchema(OperatorSchema):
    """
    Defines a schema for operators added in this package
    such as @see cl BroadcastGradientArgs.
    """

    def __init__(self):
        OperatorSchema.__init__(self, 'BroadcastGradientArgs')
        self.attributes = BroadcastGradientArgs.atts
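To illustrate what _run computes, here is a minimal standalone sketch. The helper name broadcast_gradient_args is hypothetical and not part of mlprodict; it simply mirrors the axis computation above so the expected outputs can be checked without building an ONNX node.

import numpy


def broadcast_gradient_args(a_shape, b_shape):
    # Hypothetical mirror of BroadcastGradientArgs._run: returns, for each
    # input, the axes of the broadcast output along which a gradient must
    # be summed to fold it back to that input's shape.
    i, j = len(a_shape) - 1, len(b_shape) - 1
    k = max(len(a_shape), len(b_shape)) - 1
    a_axes, b_axes = [], []
    while i >= 0 and j >= 0:
        if a_shape[i] != b_shape[j]:
            if a_shape[i] == 1:
                a_axes.append(k)
            elif b_shape[j] == 1:
                b_axes.append(k)
            else:
                raise RuntimeError(
                    "Broadcast is not possible between shapes "
                    "%r and %r." % (a_shape, b_shape))
        i, j, k = i - 1, j - 1, k - 1
    # Leading axes missing from the shorter shape are broadcast axes too.
    axes = a_axes if i < 0 else b_axes
    while k >= 0:
        axes.append(k)
        k -= 1
    return (numpy.array(a_axes, dtype=numpy.int64),
            numpy.array(b_axes, dtype=numpy.int64))


print(broadcast_gradient_args((2, 3, 4, 5), (1, 5)))
# expected: (array([], dtype=int64), array([2, 1, 0]))

For shapes (2, 3, 4, 5) and (1, 5), the second input is broadcast along axes 0, 1 and 2, so its gradient must be summed over those axes before being reshaped back to (1, 5); the first input needs no reduction.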