Coverage for mlprodict/onnxrt/ops_shape/_op_shape_op.py: 89%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

19 statements  

1""" 

2@file 

3@brief Computes shape inference for onnx operators. 

4""" 

5from .shape_excs import ShapeInferenceException 

6from .shape_result import ( 

7 ShapeResult, OnnxKind, ShapeConstraintList, ShapeConstraint) 

8 

9 

def shape_det(known_shapes, node):
    """
    Infers shape for operator Det.

    Det computes the determinant over the last two dimensions of its
    input. The result therefore keeps only the leading (batch) dimensions.

    :param known_shapes: container of already inferred results, indexed by
        result name and exposing ``update(name, result)``
        (presumably a ShapeContainer -- confirm against callers)
    :param node: ONNX node for Det; only ``node.input[0]`` and
        ``node.output[0]`` are read
    :return: whatever ``known_shapes.update`` returns for the output
    :raises ShapeInferenceException: if the input is not a tensor, has
        fewer than two dimensions, has trailing dimensions that are known
        and different, or has trailing dimensions of an unexpected kind
    """
    x = known_shapes[node.input[0]]
    if x.mtype != OnnxKind.Tensor:
        raise ShapeInferenceException(  # pragma: no cover
            "Result %r must be a tensor." % x)
    if x.n_dims() < 2:
        raise ShapeInferenceException(  # pragma: no cover
            "Operator Det requires at least two dimensions not %r." % x.n_dims())
    name = node.output[0]

    # The last two dimensions must describe square matrices. When both are
    # known integers they are checked directly. When one is symbolic (str),
    # a constraint binds it to the other dimension.
    constraints = ShapeConstraintList()
    a, b = x.shape[-2:]
    if isinstance(a, int) and isinstance(b, int):
        if a != b:
            # Report the offending shape, not the number of dimensions.
            # The 1-tuple keeps '%' safe if the shape is itself a tuple.
            raise ShapeInferenceException(  # pragma: no cover
                "Operator Det only applies on square matrices not %r." % (x.shape, ))
    elif isinstance(a, str):
        constraints.append(ShapeConstraint(a, {b}))
    elif isinstance(b, str):
        constraints.append(ShapeConstraint(b, {a}))
    else:
        raise ShapeInferenceException(  # pragma: no cover
            "Unexpected case for operator Det (%r)." % x)

    # A 2-D input yields a scalar (empty shape); otherwise keep the batch dims.
    batch_shape = [] if x.n_dims() == 2 else x.shape[:-2]
    r = ShapeResult(name, batch_shape, x.dtype, False,
                    x.mtype, constraints)
    return known_shapes.update(name, r)