Coverage for mlprodict/onnxrt/ops_cpu/op_concat_from_sequence.py: 88%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

24 statements  

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ._op import OpRun 

9from ..shape_object import ShapeObject 

10 

11 

class ConcatFromSequence(OpRun):
    """
    Runtime implementation of the ONNX operator *ConcatFromSequence*:
    concatenates the tensors of a sequence into a single tensor.

    Attributes (read from the ONNX node, defaults given in ``atts``):

    * *axis*: axis along which the tensors are concatenated or, when
      *new_axis* is 1, the position of the inserted axis
    * *new_axis*: 1 to insert a new axis at position *axis* and stack
      the tensors along it, any other value to concatenate along the
      existing axis *axis*
    """

    atts = {'axis': 0, 'new_axis': 0}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=ConcatFromSequence.atts,
                       **options)

    def _run(self, seq):  # pylint: disable=W0221
        """
        Concatenates (or stacks, if *new_axis* is 1) the tensors of *seq*.

        :param seq: sequence of arrays accepted by numpy (presumably a
            list of numpy arrays produced by a previous sequence operator)
        :return: tuple containing the resulting array
        :raises RuntimeError: if *seq* is None
        """
        if seq is None:
            raise RuntimeError(  # pragma: no cover
                "A sequence cannot be null.")
        if self.new_axis == 1:
            # ONNX inserts the new axis at position *axis* (accepted range
            # [-r - 1, r]) and concatenates along it, which is exactly
            # numpy.stack. Always appending the axis at the end ignored
            # *axis* and gave wrong results for any axis other than -1.
            res = numpy.stack(seq, axis=self.axis)
        else:
            res = numpy.concatenate(seq, axis=self.axis)
        return (res, )

    def _infer_shapes(self, seq):  # pylint: disable=W0221
        """
        Returns the output shape: it is left unknown (None) and only
        the dtype of *seq* is propagated.
        """
        return (ShapeObject(None, seq.dtype), )

    def _infer_types(self, seq):  # pylint: disable=W0221
        """
        Returns the output type, which is the input type *seq* unchanged.
        """
        return (seq, )

    def _infer_sizes(self, seq):  # pylint: disable=W0221
        """
        Runs the operator and returns an estimate of temporary memory
        followed by the outputs. When *new_axis* is 1 the total number
        of input elements is reported as temporary, otherwise 0.
        """
        res = self.run(seq)
        if self.new_axis == 1:
            return (dict(temp=sum(o.size for o in seq)), ) + res
        return (dict(temp=0), ) + res