Coverage for mlprodict/onnxrt/ops_cpu/op_softmax_cross_entropy_loss.py: 89%

70 statements  

"""
@file
@brief Runtime operator.
"""
import numpy
from ..shape_object import ShapeObject
from ._op import OpRun


def softmaxcrossentropy(x, target, weight=None, reduction='mean',
                        ignore_index=None, get_log_prob=None):
    """
    Modified version of `softmaxcrossentropy.py
    <https://github.com/onnx/onnx/blob/main/onnx/backend/
    test/case/node/softmaxcrossentropy.py>`_ to handle types
    other than float32.
    """
    input_shape = x.shape
    if len(input_shape) == 1:
        raise RuntimeError("Unsupported shape %r." % (input_shape, ))

    target_shape = target.shape
    N = input_shape[0]
    C = input_shape[1]

    # compute log_softmax
    max_x = numpy.max(x, axis=1, keepdims=True)
    exp_x = numpy.exp(x - max_x)
    p = exp_x / numpy.sum(exp_x, axis=1, keepdims=True)
    inp = numpy.log(p)
    log_prob = None
    if get_log_prob is True:
        log_prob = numpy.copy(inp)

    # initialize the positional weights when required
    gather_weight = None
    if weight is not None:
        gather_weight = numpy.take(
            weight, numpy.array(target, dtype=numpy.int32), mode='clip')
        if ignore_index is not None:
            gather_weight = numpy.where(
                target == ignore_index, 0, gather_weight).astype(dtype=x.dtype)
    elif ignore_index is not None:
        gather_weight = numpy.where(
            target == ignore_index, 0, 1).astype(dtype=x.dtype)

    # if the input is 4-d or above, make it 3-d
    if len(input_shape) != 3:
        inp = inp.reshape((N, C, -1))
        target = target.reshape((N, -1))

    # Get a dimension from the reshaped input.
    # If the original input shape is [N, C, H, W],
    # D here is H * W because [N, C, H, W] was
    # reshaped into [N, C, H * W].
    D = inp.shape[2]
    neg_gather_element_input = numpy.zeros((N, D), dtype=x.dtype)
    for i in range(N):
        for d in range(D):
            if target[i, d] != ignore_index:
                neg_gather_element_input[i, d] = -inp[i, target[i, d], d]

    loss = neg_gather_element_input

    # if the input was 4-d or above, reshape the loss back to the target shape
    if len(input_shape) != 3:
        loss = loss.reshape(target_shape)

    # apply the weights when required
    if gather_weight is not None:
        loss = gather_weight * loss
        if reduction == b'mean':
            loss = loss.sum() / gather_weight.sum()
            if get_log_prob is True:
                return loss, log_prob
            return (loss, )

    if reduction == b'mean':
        loss = numpy.mean(loss)
    elif reduction == b'sum':
        loss = numpy.sum(loss)

    if get_log_prob is True:
        return loss, log_prob
    return (loss, )
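A quick sanity check, not part of the original module: the helper above can be called directly on small numpy arrays. Note that the runtime passes the reduction attribute as bytes (b'mean', b'sum', b'none'), so a direct call should do the same; the expected value is recomputed here with a plain log-softmax for comparison.

    import numpy

    x = numpy.array([[0.1, 0.5, 0.4],
                     [0.7, 0.2, 0.1]], dtype=numpy.float32)
    target = numpy.array([1, 0], dtype=numpy.int64)

    # loss computed by the runtime helper defined above
    loss = softmaxcrossentropy(x, target, reduction=b'mean')[0]

    # manual cross-check: mean of -log(softmax(x)[i, target[i]])
    log_p = x - numpy.log(numpy.exp(x).sum(axis=1, keepdims=True))
    expected = -log_p[numpy.arange(2), target].mean()
    assert numpy.isclose(loss, expected)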

class SoftmaxCrossEntropyLoss(OpRun):
    """
    Python runtime for function *SoftmaxCrossEntropyLoss*.
    """

    atts = {'reduction': b'mean', 'ignore_index': -1}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=SoftmaxCrossEntropyLoss.atts,
                       **options)

    def _run(self, x, target, weight=None):  # pylint: disable=W0221
        n_outputs = len(self.onnx_node.output)
        return softmaxcrossentropy(
            x, target, weight=weight, reduction=self.reduction,  # pylint: disable=E1101
            ignore_index=self.ignore_index,  # pylint: disable=E1101
            get_log_prob=n_outputs == 2)

    def _infer_shapes(self, x, target, weight=None):  # pylint: disable=W0221
        n_outputs = len(self.onnx_node.output)
        if n_outputs == 1:
            return (ShapeObject(None, dtype=x.dtype), )
        return (ShapeObject(None, dtype=x.dtype),
                ShapeObject(None, dtype=x.dtype))

    def _infer_types(self, x, target, weight=None):  # pylint: disable=W0221
        n_outputs = len(self.onnx_node.output)
        if n_outputs == 1:
            return (x.dtype, )
        return (x.dtype, x.dtype)

    def _infer_sizes(self, *args):  # pylint: disable=W0221
        res = self.run(*args)
        return (dict(temp=0), ) + res
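For context, a minimal sketch (not taken from this file) of how the class is reached in practice: mlprodict's OnnxInference loads an ONNX graph and dispatches a SoftmaxCrossEntropyLoss node to the Python runtime above. The graph below is built with onnx.helper; the tensor names 'x', 'target' and 'loss' are arbitrary, and a recent onnx/mlprodict pair is assumed.

    import numpy
    from onnx import TensorProto
    from onnx.helper import (
        make_graph, make_model, make_node, make_tensor_value_info)
    from mlprodict.onnxrt import OnnxInference

    # a one-node graph: loss = SoftmaxCrossEntropyLoss(x, target)
    node = make_node('SoftmaxCrossEntropyLoss', ['x', 'target'], ['loss'],
                     reduction='mean')
    graph = make_graph(
        [node], 'sce_example',
        [make_tensor_value_info('x', TensorProto.FLOAT, [None, 3]),
         make_tensor_value_info('target', TensorProto.INT64, [None])],
        [make_tensor_value_info('loss', TensorProto.FLOAT, None)])
    model = make_model(graph)

    # the 'python' runtime maps the node onto SoftmaxCrossEntropyLoss._run
    oinf = OnnxInference(model, runtime='python')
    res = oinf.run({'x': numpy.random.rand(2, 3).astype(numpy.float32),
                    'target': numpy.array([1, 0], dtype=numpy.int64)})
    print(res['loss'])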