Coverage for mlprodict/onnxrt/ops_cpu/op_softmax_cross_entropy_loss.py: 89%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Runtime operator.
4"""
5import numpy
6from ..shape_object import ShapeObject
7from ._op import OpRun
def softmaxcrossentropy(x, target, weight=None, reduction='mean',
                        ignore_index=None, get_log_prob=None):
    """
    Modified version of `softmaxcrossentropy.py
    <https://github.com/onnx/onnx/blob/main/onnx/backend/
    test/case/node/softmaxcrossentropy.py>`_ to handle other type
    than float32.

    :param x: scores, array of shape *(N, C)* or *(N, C, d1, ..., dk)*,
        the softmax is computed along axis 1
    :param target: integer class indices, one per row of *x* once
        axis 1 is removed, shape *(N,)* or *(N, d1, ..., dk)*
    :param weight: optional weights, gathered with *target*
        (indices are clipped to the valid range)
    :param reduction: ``'mean'``, ``'sum'``, any other value leaves
        the loss unreduced, the value can be given as *str* or *bytes*
        (ONNX attributes are received as bytes)
    :param ignore_index: targets equal to this value contribute
        a null loss and a null weight, None to keep every target
    :param get_log_prob: if True, the log-probabilities are returned
        as a second result
    :return: ``(loss, )`` or ``(loss, log_prob)`` if *get_log_prob* is True
    :raises RuntimeError: if *x* has less than two dimensions
    """
    input_shape = x.shape
    if len(input_shape) < 2:
        raise RuntimeError("Unsupported shape %r." % (input_shape, ))

    # The runtime passes the attribute as bytes whereas the default value
    # of the signature is a string: both spellings must select
    # the same reduction.
    if isinstance(reduction, bytes):
        reduction = reduction.decode('ascii', errors='replace')

    target_shape = target.shape
    N = input_shape[0]
    C = input_shape[1]

    # compute log_softmax, the maximum is subtracted before the
    # exponential to avoid overflows
    max_x = numpy.max(x, axis=1, keepdims=True)
    exp_x = numpy.exp(x - max_x)
    p = exp_x / numpy.sum(exp_x, axis=1, keepdims=True)
    inp = numpy.log(p)
    log_prob = None
    if get_log_prob is True:
        log_prob = numpy.copy(inp)

    # initialize the positional weights when required,
    # ignored targets always receive a null weight
    gather_weight = None
    if weight is not None:
        gather_weight = numpy.take(
            weight, numpy.array(target, dtype=numpy.int32), mode='clip')
        if ignore_index is not None:
            gather_weight = numpy.where(
                target == ignore_index, 0, gather_weight).astype(dtype=x.dtype)
    elif ignore_index is not None:
        gather_weight = numpy.where(
            target == ignore_index, 0, 1).astype(dtype=x.dtype)

    # if input is 4-d and above (or 2-d), make it 3-d:
    # [N, C, H, W] becomes [N, C, H * W] and the target [N, H * W]
    flat_target = target
    if len(input_shape) != 3:
        inp = inp.reshape((N, C, -1))
        flat_target = target.reshape((N, -1))

    # Gather -log_prob of the expected class for every position.
    # Ignored positions are redirected to class 0 before the gather
    # (ignore_index may be outside [0, C)) and zeroed afterwards.
    if ignore_index is None:
        keep = numpy.ones(flat_target.shape, dtype=bool)
    else:
        keep = flat_target != ignore_index
    safe_target = numpy.where(keep, flat_target, 0)
    picked = numpy.take_along_axis(
        inp, safe_target[:, numpy.newaxis, :], axis=1)[:, 0, :]
    loss = numpy.where(keep, -picked, 0).astype(x.dtype)

    # if the input was 4-d or above reshape to the right shape
    if len(input_shape) != 3:
        loss = loss.reshape(target_shape)

    # apply the weights when required
    if gather_weight is not None:
        loss = gather_weight * loss

    if reduction == 'mean':
        if gather_weight is not None:
            # weighted average: normalised by the sum of the weights,
            # not by the number of elements
            loss = loss.sum() / gather_weight.sum()
        else:
            loss = numpy.mean(loss)
    elif reduction == 'sum':
        loss = numpy.sum(loss)

    if get_log_prob is True:
        return loss, log_prob
    return (loss, )
class SoftmaxCrossEntropyLoss(OpRun):
    """
    Python runtime for function *SoftmaxCrossEntropyLoss*.

    The computation is delegated to function *softmaxcrossentropy*.
    The node produces one output (the loss) or two
    (the loss and the log-probabilities) depending on the number
    of outputs declared by the ONNX node.
    """

    # Expected attributes and their default values, handed over to OpRun.
    atts = {'reduction': b'mean', 'ignore_index': -1}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=SoftmaxCrossEntropyLoss.atts,
            **options)

    def _run(self, x, target, weight=None):  # pylint: disable=W0221
        # The log-probabilities are only requested when the node
        # declares a second output.
        with_log_prob = len(self.onnx_node.output) == 2
        # reduction and ignore_index are presumably set on the instance
        # by OpRun from the expected attributes -- confirm in OpRun.
        return softmaxcrossentropy(
            x, target, weight=weight,
            reduction=self.reduction,  # pylint: disable=E1101
            ignore_index=self.ignore_index,  # pylint: disable=E1101
            get_log_prob=with_log_prob)

    def _infer_shapes(self, x, target, weight=None):  # pylint: disable=W0221
        # Shapes are left unknown (None), only the type of x is propagated.
        loss_shape = ShapeObject(None, dtype=x.dtype)
        if len(self.onnx_node.output) == 1:
            return (loss_shape, )
        return (loss_shape, ShapeObject(None, dtype=x.dtype))

    def _infer_types(self, x, target, weight=None):  # pylint: disable=W0221
        # Every output shares the type of x.
        dtype = x.dtype
        if len(self.onnx_node.output) == 1:
            return (dtype, )
        return (dtype, dtype)

    def _infer_sizes(self, *args):  # pylint: disable=W0221
        # Runs the operator and prepends the size information
        # (no temporary buffer is reported).
        outputs = self.run(*args)
        return ({'temp': 0}, ) + outputs