Coverage for mlprodict/onnxrt/ops_shape/shape_container.py: 92%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Class ShapeContainer
4"""
5import pprint
6from .shape_result import ShapeResult
class ShapeContainer:
    """
    Stores all inferred shapes as @see cl ShapeResult.

    Attributes:

    * `shapes`: dictionary `{ result name: ShapeResult }`
    * `names`: some dimensions are unknown and represented as
        variables, this dictionary keeps track of them
        (`{ generated name: (requested name, result name, dim) }`)
    * `names_rev`: reverse dictionary of `names`
        (`{ requested name: [generated names] }`)
    """

    def __init__(self):
        self.shapes = {}
        self.names = {}
        self.names_rev = {}

    def __repr__(self):
        "usual"
        return "%s()" % self.__class__.__name__

    def __len__(self):
        "usual"
        return len(self.shapes)

    def __getitem__(self, key):
        "Retrieves one shape from its name (KeyError if unknown)."
        return self.shapes[key]

    def copy(self, deep=False):
        """
        Makes a copy.

        :param deep: forwarded to the `copy` method of every stored shape
        :return: a new @see cl ShapeContainer, `names` is copied and every
            list in `names_rev` is copied too so that registering new names
            in the copy does not alter this container; attribute
            `resolved_` is not copied
        """
        cont = ShapeContainer()
        cont.shapes = {k: v.copy(deep=deep) for k, v in self.shapes.items()}
        cont.names = self.names.copy()
        cont.names_rev = {k: v.copy() for k, v in self.names_rev.items()}
        return cont

    def update(self, key, value):
        """
        Updates one shape. Returns True if the shape was different.

        :param key: result name, must be a string
        :param value: @see cl ShapeResult
        :return: True if *key* was not stored yet, otherwise the value
            returned by the `merge` method of the stored shape
        :raises TypeError: if *key* is not a string or *value* is not
            a @see cl ShapeResult
        """
        if not isinstance(key, str):
            raise TypeError(  # pragma: no cover
                "key must be a string not %r." % type(key))
        if not isinstance(value, ShapeResult):
            # the message must describe the faulty argument: value, not key
            raise TypeError(  # pragma: no cover
                "value must be a ShapeResult not %r." % type(value))
        if key not in self.shapes:
            self.shapes[key] = value
            return True
        return self.shapes[key].merge(value)

    def __contains__(self, key):
        "Operator in."
        return key in self.shapes

    def __str__(self):
        """
        Displays all shapes, all variable names and, if there are any,
        the constraints.

        .. note::
            The constraints come from `get_all_constraints`, which calls
            `merge` on the first constraint of every group sharing the
            same name.
        """
        rows = ["ShapeContainer({"]
        for k, v in self.shapes.items():
            rows.append(" %r: %r" % (k, v))
        rows.append("}, names={")
        for k, v in self.names.items():
            rows.append(" %r: %r" % (k, v))
        cst = self.get_all_constraints()
        if cst:
            rows.append("}, constraint={")
            for c, v in cst.items():
                rows.append(" %r: %r" % (c, v))
        rows.append("})")
        return "\n".join(rows)

    def get_new_name(self, name, result_name, dim):
        """
        Returns a variable name when a dimension is not
        specified.

        :param name: requested name, a string or None (treated as `''`)
        :param result_name: name of the result the dimension belongs to,
            only stored in `names`
        :param dim: presumably the index of the dimension in that result,
            only stored in `names`
        :return: a generated name `'<name>_<i>'` where `i` is the first
            integer not already used in `names`
        :raises TypeError: if *name* is neither None nor a string
        :raises RuntimeError: if the fallback lookup finds several
            generated names for *name*
        """
        if name is not None and not isinstance(name, str):
            # (name, ) keeps the message intact even if name is a tuple
            raise TypeError(  # pragma: no cover
                "name must be string not %r." % (name, ))
        if name is None:
            name = ''
        # NOTE(review): `names` is keyed by generated names ('N_0'), not by
        # requested names, so a second request for 'N' generates 'N_1'
        # instead of reusing 'N_0'. The lookup below is only reached when
        # *name* is itself a generated name and then needs an entry in
        # `names_rev` (KeyError otherwise). Checking `names_rev` may have
        # been intended -- confirm before changing, callers may rely on it.
        if name == '' or name not in self.names:
            i = 0
            new_name = "%s_%d" % (name, i)
            while new_name in self.names:
                i += 1
                new_name = "%s_%d" % (name, i)
            self.names[new_name] = (name, result_name, dim)
            self.names_rev.setdefault(name, []).append(new_name)
            return new_name
        val = self.names_rev[name]
        if len(val) != 1:
            raise RuntimeError(  # pragma: no cover
                "Name %r has more than one correspondance (%r)." % (
                    name, val))
        return val[0]

    def get_all_constraints(self):
        """
        Gathers all constraints.

        Constraints are collected from every stored shape and grouped by
        their `name`. When a group holds several constraints, `merge` is
        called on the first one with the others and only the first one is
        kept (presumably `merge` works in place, the code relies on it).

        :return: dictionary `{ constraint name: [constraint] }`, every list
            holds exactly one constraint
        """
        cons = {}
        for v in self.shapes.values():
            if v.constraints is not None:
                for c in v.constraints:
                    cons.setdefault(c.name, []).append(c)
        for v in cons.values():
            if len(v) > 1:
                v[0].merge(v[1:])
                del v[1:]
        return cons

    def get(self):
        """
        Returns the value of attribute `resolved_`
        (method `resolve()` must have been called first).

        :raises AttributeError: if `resolve()` was not called
        """
        if not hasattr(self, 'resolved_') or self.resolved_ is None:
            raise AttributeError(  # pragma: no cover
                "Attribute 'resolved_' is missing. You must run "
                "method 'resolve()'.")
        return self.resolved_

    def resolve(self):
        """
        Resolves all constraints. It adds the attribute
        `resolved_`.

        Every string dimension found in the stored shapes is a variable.
        The resolution works in two steps:

        1. every constraint whose values are all integers directly gives
           the possible values of its variable,
        2. the remaining constraints are propagated until nothing changes,
           non-string values are used as they are and variable names are
           replaced by their possible values once known; when no progress
           is made, one unresolved variable receives a new symbolic
           dimension `'d<i>'` and the propagation continues.

        :return: dictionary `{ result name: resolved shape }` (the value
            returned by the `resolve` method of every stored shape),
            also stored in `resolved_`
        :raises RuntimeError: if a constraint refers to an unknown variable,
            if constraints are contradictory, or if no progress can be made
        """
        def vars_in_values(values):
            # splits constraint values into a set of non-string values
            # and a list of variable names
            i_vals, s_vals = [], []
            for v in values:
                if isinstance(v, str):
                    s_vals.append(v)
                else:
                    i_vals.append(v)
            return set(i_vals), s_vals

        # every symbolic dimension is a variable, None means unresolved
        variables = {}
        for v in self.shapes.values():
            for sh in v.shape:
                if isinstance(sh, str):
                    variables[sh] = None

        # first step: resolves all constraint with integer
        dcsts = self.get_all_constraints()
        csts = []
        for li in dcsts.values():
            csts.extend(li)
        new_csts = []
        for cst in csts:
            if cst.name in variables and variables[cst.name] is None:
                if all(isinstance(n, int) for n in cst.values):
                    variables[cst.name] = cst.values.copy()
                else:
                    new_csts.append(cst)
            else:
                raise RuntimeError(  # pragma: no cover
                    "Unable to find any correspondance for variable %r "
                    "in %r." % (cst.name, ", ".join(sorted(variables))))

        # second step: everything else, like a logic algorithm
        dim_names = set()
        csts = new_csts
        updates = 1
        while updates > 0 and len(new_csts) > 0:
            updates = 0
            new_csts = []
            for cst in csts:
                rvalues = variables[cst.name]
                ivalues, lvars = vars_in_values(cst.values)

                # replaces every resolved variable by its possible values,
                # miss counts the unresolved ones; miss must be defined even
                # when the constraint holds no variable name at all
                miss = 0
                for lv in lvars:
                    if lv in variables and variables[lv] is not None:
                        ivalues |= variables[lv]
                    else:
                        miss += 1

                if miss == 0:
                    # simple case: only integers
                    if rvalues is None:
                        inter = ivalues
                    else:
                        inter = rvalues.intersection(ivalues)
                    if len(inter) == 0:
                        raise RuntimeError(  # pragma: no cover
                            "Resolution failed for variable %r, "
                            "current possibilities %r does not match "
                            "constraint %r." % (cst.name, rvalues, cst))
                    if rvalues is None or len(inter) < len(rvalues):
                        variables[cst.name] = inter
                        updates += 1
                    else:
                        # nothing new, the constraint is satisfied: drop it
                        continue
                elif len(dim_names) > 0:
                    # more complex case: variables
                    if len(cst.values) == 1 and len(lvars) == 1:
                        # exact mapping between cst.name and lvars[0]
                        a, b = cst.name, lvars[0]
                        if variables[a] is None and variables[b] is not None:
                            if variables[b].intersection(dim_names):
                                variables[a] = variables[b]
                                updates += 1
                                continue
                        elif variables[b] is None and variables[a] is not None:
                            if variables[a].intersection(dim_names):
                                variables[b] = variables[a]
                                updates += 1
                                continue

                # the constraint is kept for the next round
                new_csts.append(cst)
            csts = new_csts

            if len(new_csts) > 0 and updates == 0:
                # It means that a dimension needs to be left unknown.
                # The last unresolved variable receives a new symbol.
                found = None
                for k, v in variables.items():
                    if v is None:
                        found = k
                if found is not None:
                    name = "d%d" % len(dim_names)
                    dim_names.add(name)
                    variables[found] = {name}
                    updates += 1
                else:
                    raise RuntimeError(  # pragma: no cover
                        "Inconsistency in %r with\n%r" % (
                            self, variables))

        # final
        results = {}
        for k, v in self.shapes.items():
            try:
                results[k] = v.resolve(variables)
            except RuntimeError as e:  # pragma: no cover
                raise RuntimeError(
                    "Unable to resolve shapes and constraints:\n%s"
                    "" % pprint.pformat(self.shapes)) from e
        self.resolved_ = results
        return self.resolved_