Coverage for mlprodict/onnxrt/ops_shape/shape_container.py: 92%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

156 statements  

1""" 

2@file 

3@brief Class ShapeContainer 

4""" 

5import pprint 

6from .shape_result import ShapeResult 

7 

8 

class ShapeContainer:
    """
    Stores all inferred shapes as @see cl ShapeResult.

    Attributes:

    * `shapes`: dictionary `{ result name: ShapeResult }`
    * `names`: some dimensions are unknown and represented as
        variables, this dictionary keeps track of them
        (`{ variable name: (base name, result name, dim) }`,
        filled by method `get_new_name`)
    * `names_rev`: reverse dictionary of `names`
        (`{ base name: [variable names] }`)
    """

    def __init__(self):
        self.shapes = dict()
        self.names = dict()
        self.names_rev = dict()

    def __repr__(self):
        "usual"
        return "%s()" % self.__class__.__name__

    def __len__(self):
        "usual"
        return len(self.shapes)

    def __getitem__(self, key):
        "Retrieves one shape from its name."
        return self.shapes[key]

    def copy(self, deep=False):
        """
        Makes a copy.

        :param deep: forwarded to `ShapeResult.copy` for every stored shape
        :return: a new @see cl ShapeContainer; `names` is copied shallowly,
            the lists stored in `names_rev` are copied so that the new
            container can register new names without altering this one
        """
        cont = ShapeContainer()
        cont.shapes = {k: v.copy(deep=deep) for k, v in self.shapes.items()}
        cont.names = self.names.copy()
        cont.names_rev = {k: v.copy() for k, v in self.names_rev.items()}
        return cont

    def update(self, key, value):
        """
        Updates one shape. Returns True if the shape was different.

        :param key: result name, must be a string
        :param value: a @see cl ShapeResult
        :return: True if *key* was not registered yet, otherwise
            whatever `ShapeResult.merge` returns when merging *value*
            into the stored shape
        :raises TypeError: if *key* is not a string or *value*
            is not a @see cl ShapeResult
        """
        if not isinstance(key, str):
            raise TypeError(  # pragma: no cover
                "key must be a string not %r." % type(key))
        if not isinstance(value, ShapeResult):
            # Report the type of the offending value (not of the key).
            raise TypeError(  # pragma: no cover
                "value must be a ShapeResult not %r." % type(value))
        if key not in self.shapes:
            self.shapes[key] = value
            return True
        return self.shapes[key].merge(value)

    def __contains__(self, key):
        "Operator in."
        return key in self.shapes

    def __str__(self):
        """
        Displays the shapes, the variable names and,
        if any, the merged constraints.
        """
        rows = ["ShapeContainer({"]
        for k, v in self.shapes.items():
            rows.append(" %r: %r" % (k, v))
        rows.append("}, names={")
        for k, v in self.names.items():
            rows.append(" %r: %r" % (k, v))
        cst = self.get_all_constraints()
        if len(cst) > 0:
            rows.append("}, constraint={")
            for c, v in cst.items():
                rows.append(" %r: %r" % (c, v))
        rows.append("})")
        return "\n".join(rows)

    def get_new_name(self, name, result_name, dim):
        """
        Returns a variable name when a dimension is not
        specified.

        :param name: base name of the dimension, a string or None
            (None is treated as an empty string)
        :param result_name: name of the result the dimension belongs to
        :param dim: dimension, stored with the new variable in `names`
        :return: a variable name `'<name>_<i>'` where `i` is the smallest
            integer giving a name not yet registered in `names`
        :raises TypeError: if *name* is neither None nor a string
        :raises RuntimeError: if *name* is already a registered variable
            name mapped to more than one variable in `names_rev`
        """
        if name is not None and not isinstance(name, str):
            # Format the type: formatting *name* itself with `%` breaks
            # when *name* is a tuple.
            raise TypeError(  # pragma: no cover
                "name must be string not %r." % type(name))
        if name is None:
            name = ''
        if name == '' or name not in self.names:
            i = 0
            new_name = "%s_%d" % (name, i)
            while new_name in self.names:
                i += 1
                new_name = "%s_%d" % (name, i)
            self.names[new_name] = (name, result_name, dim)
            if name not in self.names_rev:
                self.names_rev[name] = []
            self.names_rev[name].append(new_name)
            return new_name
        # NOTE(review): *name* is a key of `names` (a generated variable
        # name) but is looked up in `names_rev`, which is keyed by base
        # names; this raises KeyError unless the same string was also
        # used as a base name -- confirm the intended lookup.
        val = self.names_rev[name]
        if len(val) != 1:
            raise RuntimeError(  # pragma: no cover
                "Name %r has more than one correspondance (%r)." % (
                    name, val))
        return val[0]

    def get_all_constraints(self):
        """
        Gathers all constraints.

        :return: dictionary `{ constraint name: [constraint] }`, every
            list holds a single constraint into which all the
            constraints sharing that name were merged
        """
        cons = {}
        for v in self.shapes.values():
            if v.constraints is not None:
                for c in v.constraints:
                    if c.name not in cons:
                        cons[c.name] = []
                    cons[c.name].append(c)
        for v in cons.values():
            if len(v) > 1:
                # NOTE(review): v[0] is the very object stored in a
                # shape's `constraints`; merge's return value is ignored,
                # so it presumably updates that stored constraint in place.
                v[0].merge(v[1:])
                del v[1:]
        return cons

    def get(self):
        """
        Returns the value of attribute `resolved_`
        (method `resolve()` must have been called first).

        :raises AttributeError: if `resolve()` was not called
        """
        if not hasattr(self, 'resolved_') or self.resolved_ is None:
            raise AttributeError(  # pragma: no cover
                "Attribute 'resolved_' is missing. You must run "
                "method 'resolve()'.")
        return self.resolved_

    def resolve(self):
        """
        Resolves all constraints. It adds the attribute
        `resolved_`.

        The algorithm has three stages:

        1. every dimension stored as a string in a shape becomes a
           variable; a constraint whose values are all integers directly
           gives the possible values of its variable,
        2. the remaining constraints are propagated until nothing changes;
           when the propagation is stuck, one still unknown variable is
           bound to a new symbolic dimension (`d0`, `d1`, ...) and the
           propagation resumes,
        3. every shape is resolved with the final variable values
           through `ShapeResult.resolve`.

        :return: dictionary `{ result name: resolved shape }`,
            also stored in `resolved_`
        :raises RuntimeError: if a constraint refers to an unknown
            variable, if the constraints are inconsistent or if a shape
            cannot be resolved
        """
        def vars_in_values(values):
            # Splits constraint values into a set of non-string values
            # and a list of variable names (strings).
            i_vals, s_vals = [], []
            for v in values:
                if isinstance(v, str):
                    s_vals.append(v)
                else:
                    i_vals.append(v)
            return set(i_vals), s_vals

        # Every string dimension is a variable; None means no value is
        # known yet.
        variables = {}
        for v in self.shapes.values():
            for sh in v.shape:
                if isinstance(sh, str):
                    variables[sh] = None

        # first step: resolves all constraint with integer
        dcsts = self.get_all_constraints()
        csts = []
        for li in dcsts.values():
            csts.extend(li)
        new_csts = []
        for cst in csts:
            if cst.name in variables and variables[cst.name] is None:
                if all(map(lambda n: isinstance(n, int), cst.values)):
                    variables[cst.name] = cst.values.copy()
                else:
                    new_csts.append(cst)
            else:
                raise RuntimeError(  # pragma: no cover
                    "Unable to find any correspondance for variable %r "
                    "in %r." % (cst.name, ", ".join(sorted(variables))))

        # second step: everything else, like a logic algorithm
        dim_names = set()
        csts = new_csts
        updates = 1
        while updates > 0 and len(new_csts) > 0:
            updates = 0
            new_csts = []
            for cst in csts:
                rvalues = variables[cst.name]
                ivalues, lvars = vars_in_values(cst.values)

                if len(lvars) > 0:
                    # Replaces every known variable by its possible values.
                    miss = 0
                    for lv in lvars:
                        if lv in variables and variables[lv] is not None:
                            ivalues |= variables[lv]
                        else:
                            miss += 1

                    if miss == 0:
                        # simple case: only integers
                        if rvalues is None:
                            inter = ivalues
                        else:
                            inter = rvalues.intersection(ivalues)
                        if len(inter) == 0:
                            raise RuntimeError(  # pragma: no cover
                                "Resolution failed for variable %r, "
                                "current possibilities %r does not match "
                                "constraint %r." % (cst.name, rvalues, cst))
                        if rvalues is None or len(inter) < len(rvalues):
                            # The set of possible values shrank: keep the
                            # constraint for another pass.
                            variables[cst.name] = inter
                            updates += 1
                        else:
                            # Nothing changed: the constraint is dropped.
                            continue
                    elif len(dim_names) > 0:
                        # more complex case: variables
                        if len(cst.values) == 1 and len(lvars) == 1:
                            # exact mapping between cst.name and lvars[0]
                            # NOTE(review): variables[b] raises KeyError if
                            # lvars[0] never appears in any shape -- confirm.
                            a, b = cst.name, lvars[0]
                            if variables[a] is None and variables[b] is not None:
                                if variables[b].intersection(dim_names):
                                    variables[a] = variables[b]
                                    updates += 1
                                    continue
                            elif variables[b] is None and variables[a] is not None:
                                if variables[a].intersection(dim_names):
                                    variables[b] = variables[a]
                                    updates += 1
                                    continue

                new_csts.append(cst)
            csts = new_csts

            if len(new_csts) > 0 and updates == 0:
                # It means that a dimension needs to be left unknown:
                # the last still unknown variable becomes a new symbolic
                # dimension so that the propagation can resume.
                found = None
                for k, v in variables.items():
                    if v is None:
                        found = k
                if found is not None:
                    name = "d%d" % len(dim_names)
                    dim_names.add(name)
                    variables[found] = {name}
                    updates += 1
                else:
                    raise RuntimeError(  # pragma: no cover
                        "Inconsistency in %r with\n%r" % (
                            self, variables))

        # final
        results = {}
        for k, v in self.shapes.items():
            try:
                results[k] = v.resolve(variables)
            except RuntimeError as e:  # pragma: no cover
                raise RuntimeError(
                    "Unable to resolve shapes and constraints:\n%s"
                    "" % pprint.pformat(self.shapes)) from e
        self.resolved_ = results
        return self.resolved_