Coverage for mlprodict/onnxrt/ops_shape/shape_result.py: 92%

165 statements  

1""" 

2@file 

3@brief Class ShapeResult 

4""" 

5from enum import Enum 

6import numpy 

7from .shape_excs import ShapeInferenceException 

8 

9 

class OnnxKind(Enum):
    """
    Describes a result type.
    """
    Tensor = 0
    Sequence = 1
    Map = 2

class ShapeConstraint:
    """
    One constraint.

    :param name: variable name
    :param values: set of possible values
    """

    def __init__(self, name, values):
        if name == '?':
            raise ValueError(  # pragma: no cover
                "Name cannot be '?'.")
        if not isinstance(values, set):
            raise TypeError(  # pragma: no cover
                "values must be a set not %r." % type(values))
        self.name = name
        self.values = values

    def __eq__(self, other):
        "usual"
        if self.name != other.name:
            return False
        if self.values != other.values:
            return False
        return True

    def __repr__(self):
        "usual"
        return "%s(%r, %r)" % (
            self.__class__.__name__, self.name, self.values)

    def merge(self, cst):
        """
        Merges this constraint with *cst* into this one.
        """
        if isinstance(cst, list):
            for c in cst:
                self.merge(c)
            return
        self.values = self.values.intersection(cst.values)

    def copy(self, deep=False):
        """
        Makes a copy of the object.
        """
        return ShapeConstraint(self.name, self.values.copy())
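A short usage sketch (assuming the module is importable as mlprodict.onnxrt.ops_shape.shape_result): merging two constraints on the same variable keeps the intersection of their value sets.

from mlprodict.onnxrt.ops_shape.shape_result import ShapeConstraint

# dimension 'N' is believed to be 1, 2 or 3
c1 = ShapeConstraint('N', {1, 2, 3})
# a second observation narrows it down to 2, 3 or 4
c2 = ShapeConstraint('N', {2, 3, 4})
c1.merge(c2)
print(c1)  # ShapeConstraint('N', {2, 3})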

class ShapeConstraintList:
    """
    A list of ShapeConstraint.
    """

    def __init__(self):
        self.csts = []

    def __contains__(self, cst):
        for a in self.csts:
            if cst == a:
                return True
        return False

    def append(self, cst):
        "Appends a new constraint to the list."
        self.csts.append(cst)

    def __repr__(self):
        return "ShapeConstraintList(%r)" % self.csts

    def __iter__(self):
        for c in self.csts:
            yield c

    def __len__(self):
        return len(self.csts)

    def copy(self, deep=False):
        """
        Copies the object.
        """
        cp = ShapeConstraintList()
        if deep:
            cp.csts = [v.copy(deep=deep) for v in self]
        else:
            cp.csts = self.csts.copy()
        return cp
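Another small sketch: membership and length of a ShapeConstraintList rely on ShapeConstraint.__eq__, not on object identity.

from mlprodict.onnxrt.ops_shape.shape_result import (
    ShapeConstraint, ShapeConstraintList)

csts = ShapeConstraintList()
csts.append(ShapeConstraint('N', {1, 2}))
print(ShapeConstraint('N', {1, 2}) in csts)  # True, equality-based lookup
print(len(csts))                             # 1
deep = csts.copy(deep=True)                  # also copies each constraint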

class ShapeResult:
    """
    Contains information about the shape and type of a result
    in an ONNX graph.

    :param name: result name
    :param shape: shape if the result is a tensor
    :param dtype: element type if the result is a tensor
    :param sparse: is the tensor sparse
    :param mtype: kind of the result (see class @see cl OnnxKind)
    :param constraints: list of constraints applying to variables
    """

    def __init__(self, name, shape=None, dtype=None, sparse=False,
                 mtype=OnnxKind.Tensor, constraints=None):
        if not isinstance(name, str):
            raise TypeError(  # pragma: no cover
                "name must be a string not %r." % type(name))
        if not isinstance(sparse, bool):
            raise TypeError(  # pragma: no cover
                "sparse must be a boolean not %r." % sparse)
        if not isinstance(mtype, OnnxKind):
            raise TypeError(  # pragma: no cover
                "mtype must be of type OnnxKind not %r." % type(mtype))
        # shape=None means the shape is unknown and is stored as an empty list.
        self.shape = [] if shape is None else list(shape)
        for i in range(0, len(self.shape)):  # pylint: disable=C0200
            if self.shape[i] in ('', None, '?'):
                raise ValueError(  # pragma: no cover
                    "All dimensions must be an int or a variable name, "
                    "%s is not." % (shape, ))
        self.name = name
        self.mtype = mtype
        self.dtype = dtype
        self.sparse = sparse
        if constraints is None:
            self.constraints = ShapeConstraintList()
        elif isinstance(constraints, ShapeConstraintList):
            self.constraints = constraints
        else:
            raise TypeError(  # pragma: no cover
                "constraints must be of type ShapeConstraintList.")

    def is_compatible(self, shape):
        """
        Tells if this shape is compatible with the given tuple.

        :param shape: tuple
        :return: boolean
        """
        if isinstance(shape, numpy.ndarray):
            shape = shape.shape
        if all(map(lambda x: isinstance(x, int), self.shape)):
            return tuple(self.shape) == tuple(shape)
        raise NotImplementedError("%r ? %r" % (self, shape))
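A minimal sketch of building a ShapeResult and checking it against a concrete array (the result names below are only illustrative):

import numpy
from mlprodict.onnxrt.ops_shape.shape_result import ShapeResult

res = ShapeResult('X', [3, 4], numpy.float32)
print(res.is_compatible(numpy.zeros((3, 4))))  # True
print(res.is_compatible((3, 5)))               # False

# symbolic dimensions are plain strings; is_compatible is not
# implemented while 'N' remains unresolved
sym = ShapeResult('Y', ['N', 4], numpy.float32)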

    def copy(self, deep=False):
        """
        Returns a copy of the result.
        """
        return ShapeResult(self.name, self.shape, self.dtype, self.sparse,
                           self.mtype, self.constraints.copy(deep=deep))

    def __repr__(self):
        """
        Usual
        """
        if len(self.constraints) > 0:
            return "%s(%r, %r, %r, sparse=%r, mtype=%r, constraints=%r)" % (
                self.__class__.__name__, self.name, self.shape, self.dtype,
                self.sparse, self.mtype, self.constraints)
        if self.mtype != OnnxKind.Tensor:
            return "%s(%r, %r, %r, sparse=%r, mtype=%r)" % (
                self.__class__.__name__, self.name, self.shape, self.dtype,
                self.sparse, self.mtype)
        if self.sparse:
            return "%s(%r, %r, %r, sparse=%r)" % (
                self.__class__.__name__, self.name, self.shape, self.dtype,
                self.sparse)
        return "%s(%r, %r, %r)" % (
            self.__class__.__name__, self.name, self.shape, self.dtype)

    def __eq__(self, shape):
        """
        Tells if two shapes are identical.
        """
        return (self.mtype == shape.mtype and self.shape == shape.shape and
                self.dtype == shape.dtype and self.sparse == shape.sparse)

    def n_dims(self):
        """
        Returns the number of dimensions if it is a tensor.
        Raises an exception otherwise.
        """
        if self.mtype != OnnxKind.Tensor:
            raise ShapeInferenceException(  # pragma: no cover
                "This shape is not a tensor %r." % self)
        return len(self.shape)
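For illustration, equality compares kind, shape, dtype and sparsity, and copy returns an independent result (a sketch, not taken from the library's tests):

import numpy
from mlprodict.onnxrt.ops_shape.shape_result import ShapeResult

r = ShapeResult('X', [3, 4], numpy.float32)
c = r.copy(deep=True)
print(r == c)   # True: same kind, shape, dtype and sparsity
c.shape[0] = 5
print(r == c)   # False: the copy was modified independently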

    def merge(self, other_result):
        """
        Merges constraints from *other_result* into *self*.
        """
        if self.mtype != other_result.mtype:
            raise RuntimeError(  # pragma: no cover
                "Unable to merge %r and %r." % (self, other_result))
        if (len(self.shape) != 0 and len(other_result.shape) != 0 and
                len(self.shape) != len(other_result.shape)):
            raise RuntimeError(  # pragma: no cover
                "Length mismatch, unable to merge %r and %r." % (
                    self, other_result))
        updated = False
        if other_result.constraints is not None:
            for c in other_result.constraints:
                if c not in self.constraints:
                    self.constraints.append(c)
                    updated = True

        if len(self.shape) == 0 and len(other_result.shape) > 0:
            # self.shape is unknown but the other one is known.
            self.shape = other_result.shape.copy()
            return True

        for a, b in zip(self.shape, other_result.shape):
            if a == b:
                continue
            if isinstance(a, int) and isinstance(b, int):
                raise RuntimeError(
                    "Inconsistency between %r and %r." % (
                        self, other_result))
            elif isinstance(a, str):
                c = ShapeConstraint(a, {b})
                if c not in self.constraints:
                    updated = True
                    self.constraints.append(c)
            elif isinstance(b, str):
                c = ShapeConstraint(b, {a})
                if c not in self.constraints:
                    updated = True
                    self.constraints.append(c)
            else:
                raise NotImplementedError(  # pragma: no cover
                    "Merge not implemented between %r and %r." % (
                        self, other_result))
        return updated
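As a rough illustration of merge, combining a symbolic shape with a concrete one records a constraint on the symbolic dimension:

import numpy
from mlprodict.onnxrt.ops_shape.shape_result import ShapeResult

r1 = ShapeResult('X', ['N', 4], numpy.float32)
r2 = ShapeResult('X', [5, 4], numpy.float32)
print(r1.merge(r2))    # True, a new constraint was added
print(r1.constraints)  # ShapeConstraintList([ShapeConstraint('N', {5})])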

    def resolve(self, variables):
        """
        Resolves variables in a shape using the values stored
        in *variables*. It does not copy any constraints.

        :param variables: dictionary `{ name: values }`
        :return: new ShapeResult
        """
        res = ShapeResult(self.name, shape=self.shape, dtype=self.dtype,
                          sparse=self.sparse, mtype=self.mtype)
        for i in range(len(res.shape)):  # pylint: disable=C0200
            v = res.shape[i]
            if isinstance(v, str):
                if v in variables:
                    vals = variables[v]
                    if vals is None:
                        # size unknown
                        continue
                    if len(vals) == 1:
                        res.shape[i] = list(vals)[0]
                    else:
                        res.shape[i] = set(vals)
                else:
                    raise RuntimeError(  # pragma: no cover
                        "Unable to resolve shape %r due to missing "
                        "%r." % (self, v))
        return res
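A small sketch of resolve, assuming the variable values were collected elsewhere during shape inference:

import numpy
from mlprodict.onnxrt.ops_shape.shape_result import ShapeResult

res = ShapeResult('X', ['N', 4], numpy.float32)
print(res.resolve({'N': {10}}))    # ShapeResult('X', [10, 4], <class 'numpy.float32'>)
print(res.resolve({'N': {2, 3}}))  # the first dimension becomes the set {2, 3}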

    @staticmethod
    def broadcast(sh1, sh2, name=None):
        """
        Broadcasts dimensions for an element-wise operator.

        :param sh1: ShapeResult
        :param sh2: ShapeResult
        :param name: name of the output ShapeResult
        :return: ShapeResult
        """
        if not isinstance(sh1, ShapeResult):
            raise TypeError(  # pragma: no cover
                "Unexpected type for sh1 %r." % type(sh1))
        if not isinstance(sh2, ShapeResult):
            raise TypeError(  # pragma: no cover
                "Unexpected type for sh2 %r." % type(sh2))
        if sh1.mtype != OnnxKind.Tensor:
            raise TypeError(  # pragma: no cover
                "sh1 must be a tensor not %r." % sh1.mtype)
        if sh2.mtype != OnnxKind.Tensor:
            raise TypeError(  # pragma: no cover
                "sh2 must be a tensor not %r." % sh2.mtype)
        if sh1.n_dims() != sh2.n_dims():
            if sh1.n_dims() == 1 and sh1.shape[0] == 1:
                return ShapeResult(
                    name, sh2.shape, sh2.dtype, sh2.sparse, sh2.mtype)
            if sh2.n_dims() == 1 and sh2.shape[0] == 1:
                return ShapeResult(
                    name, sh1.shape, sh1.dtype, sh1.sparse, sh1.mtype)
            raise ShapeInferenceException(  # pragma: no cover
                "Broadcasting is only implemented for shapes of the same "
                "size, shapes are %r and %r." % (sh1, sh2))
        if sh1.dtype != sh2.dtype:
            raise ShapeInferenceException(  # pragma: no cover
                "Cannot broadcast shapes %r and %r (dtypes)."
                "" % (sh1, sh2))

        constraints = ShapeConstraintList()
        shape = []
        for a, b in zip(sh1.shape, sh2.shape):
            if isinstance(a, int) and isinstance(b, int):
                if a != b:
                    if min(a, b) == 1:
                        d = max(a, b)
                    else:
                        raise ShapeInferenceException(  # pragma: no cover
                            "Cannot broadcast shapes %r and %r (dimensions)."
                            "" % (sh1, sh2))
                else:
                    d = a
            elif isinstance(a, int):
                if a != 1:
                    d = a
                    constraints.append(ShapeConstraint(b, {1, a}))
                else:
                    d = b
            elif isinstance(b, int):
                if b != 1:
                    d = b
                    constraints.append(ShapeConstraint(a, {1, b}))
                else:
                    d = a
            elif a == b:
                d = a
            else:
                raise ShapeInferenceException(  # pragma: no cover
                    "Cannot broadcast shapes %r and %r." % (sh1, sh2))
            shape.append(d)
        if name in (None, ''):
            raise ValueError(  # pragma: no cover
                "name cannot be empty.")
        res = ShapeResult(name, shape, sh1.dtype, sh1.sparse or sh2.sparse,
                          sh1.mtype, constraints)
        return res
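Finally, a small sketch of broadcast on two compatible shapes (the names are illustrative):

import numpy
from mlprodict.onnxrt.ops_shape.shape_result import ShapeResult

a = ShapeResult('A', [1, 4], numpy.float32)
b = ShapeResult('B', [3, 4], numpy.float32)
print(ShapeResult.broadcast(a, b, name='C'))
# ShapeResult('C', [3, 4], <class 'numpy.float32'>)

# with a symbolic dimension, the constraint 'N' in {1, 3} is recorded
s = ShapeResult('S', ['N', 4], numpy.float32)
out = ShapeResult.broadcast(s, b, name='D')
print(out.constraints)  # ShapeConstraintList([ShapeConstraint('N', {1, 3})])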