# SPDX-License-Identifier: Apache-2.0
from collections import namedtuple
from typing import Any, Dict, NewType, Optional, Sequence, Tuple, Type
import numpy # type: ignore
import onnx.checker
import onnx.onnx_cpp2py_export.checker as c_checker
from onnx import IR_VERSION, ModelProto, NodeProto
class DeviceType:
    """Integer identifiers for the device kinds known to this module.

    The attribute names (``CPU``, ``CUDA``) are what ``Device`` looks up by
    name when it parses the device-type part of a device string.
    """

    # Distinct static type for the identifiers; at runtime NewType is a
    # no-op wrapper, so the values below are plain ``int`` objects.
    _Type = NewType("_Type", int)
    CPU: _Type = _Type(0)
    CUDA: _Type = _Type(1)
class Device:
    """
    Describes device type and device id
    syntax: device_type:device_id(optional)
    example: 'CPU', 'CUDA', 'CUDA:1'
    """

    def __init__(self, device: str) -> None:
        """Parse *device* into the ``type`` and ``device_id`` attributes.

        Args:
            device: Device string of the form ``device_type[:device_id]``.

        Raises:
            AttributeError: If the device-type part is not the name of an
                attribute of ``DeviceType``.
            ValueError: If a device id is given but is not an integer literal.
        """
        # Only the first two ":"-separated fields are used; any further
        # fields are ignored.
        device_type, *rest = device.split(":")
        # The type is resolved by attribute name, so it must be looked up
        # before the id is converted (keeps the original error precedence).
        self.type = getattr(DeviceType, device_type)
        # The device id defaults to 0 when the string carries no ":id" part.
        self.device_id = int(rest[0]) if rest else 0
def namedtupledict(
    typename: str, field_names: Sequence[str], *args: Any, **kwargs: Any
) -> Type[Tuple[Any, ...]]:
    """Create a namedtuple class whose instances can also be indexed by field name.

    Instances of the returned class behave like ordinary namedtuple
    instances, and additionally accept the *original* field name as a
    subscript (``t["name"]``), even when that name is not a valid Python
    identifier and was therefore renamed by ``namedtuple``.

    Args:
        typename: Name of the generated class.
        field_names: Field names, either a sequence of strings or a single
            string separated by whitespace and/or commas (the two forms
            ``collections.namedtuple`` accepts).
        *args: Extra positional arguments forwarded to ``namedtuple``.
        **kwargs: Extra keyword arguments forwarded to ``namedtuple``;
            ``rename`` defaults to ``True`` unless given explicitly.

    Returns:
        The generated namedtuple class.
    """
    if isinstance(field_names, str):
        # Split a single-string spec the way namedtuple does, so the map
        # below is keyed by field names rather than by single characters.
        field_names = field_names.replace(",", " ").split()
    else:
        # Materialise once: the names are consumed twice (map + namedtuple).
        field_names = list(field_names)
    field_names_map = {n: i for i, n in enumerate(field_names)}
    # Some output names are invalid python identifier, e.g. "0"
    kwargs.setdefault("rename", True)
    data = namedtuple(typename, field_names, *args, **kwargs)  # type: ignore

    def getitem(self: Any, key: Any) -> Any:
        """Index by position/slice as usual, or by original field name."""
        if isinstance(key, str):
            key = field_names_map[key]
        # Call tuple's implementation directly. ``super(type(self), self)``
        # would resolve back to this very function for any subclass of
        # ``data`` and recurse without end.
        return tuple.__getitem__(self, key)

    data.__getitem__ = getitem  # type: ignore
    return data
class BackendRep:
    """Object a backend's ``prepare`` hands back and on which ``run`` is called.

    This base class carries no state; ``run`` is a placeholder for
    subclasses to override.
    """

    def run(self, inputs: Any, **kwargs: Any) -> Tuple[Any, ...]:
        """Placeholder for executing the prepared model on *inputs*.

        The base implementation ignores its arguments, does nothing and
        implicitly returns ``None``, despite the annotated return type.
        """
class Backend:
    """Base class for backends; every entry point is a classmethod.

    The implementations here only validate their input with ``onnx.checker``
    and return placeholder values; they are written to be overridden.
    """

    @classmethod
    def is_compatible(
        cls, model: ModelProto, device: str = "CPU", **kwargs: Any
    ) -> bool:
        """Return whether the model is compatible with the backend.

        The base implementation does not inspect its arguments and always
        returns ``True``.
        """
        return True

    @classmethod
    def prepare(
        cls, model: ModelProto, device: str = "CPU", **kwargs: Any
    ) -> Optional[BackendRep]:
        """Validate *model* and return a representation that can run it.

        The base implementation only calls ``onnx.checker.check_model`` and
        then returns ``None``; ``device`` and ``kwargs`` are unused here.

        Raises:
            Whatever ``onnx.checker.check_model`` raises for *model*.
        """
        # TODO Remove Optional from return type
        onnx.checker.check_model(model)
        return None

    @classmethod
    def run_model(
        cls, model: ModelProto, inputs: Any, device: str = "CPU", **kwargs: Any
    ) -> Tuple[Any, ...]:
        """Prepare *model* via ``cls.prepare`` and run it on *inputs*.

        ``kwargs`` are forwarded to ``prepare`` only; ``run`` receives just
        *inputs*.

        Raises:
            AssertionError: If ``cls.prepare`` returned ``None``, which is
                what the base ``prepare`` does.
        """
        backend = cls.prepare(model, device, **kwargs)
        assert backend is not None, (
            f"{cls.__name__}.prepare() returned None instead of a BackendRep"
        )
        return backend.run(inputs)

    @classmethod
    def run_node(
        cls,
        node: NodeProto,
        inputs: Any,
        device: str = "CPU",
        outputs_info: Optional[Sequence[Tuple[numpy.dtype, Tuple[int, ...]]]] = None,
        **kwargs: Any,
    ) -> Optional[Tuple[Any, ...]]:
        """Simple run one operator and return the results.

        The base implementation only validates *node* with
        ``onnx.checker.check_node`` and returns ``None``.

        Args:
            node: The operator to check.
            inputs: Unused by the base implementation.
            device: Unused by the base implementation.
            outputs_info: a list of tuples, which contains the element type and
                shape of each output. First element of the tuple is the dtype, and
                the second element is the shape. More use case can be found in
                https://github.com/onnx/onnx/blob/main/onnx/backend/test/runner/__init__.py
            **kwargs: If ``opset_version`` is present, the node is checked
                against that version of the default (``""``) domain.

        Raises:
            Whatever ``onnx.checker.check_node`` raises for *node*.
        """
        # TODO Remove Optional from return type
        if "opset_version" in kwargs:
            # Build a checker context pinned to the requested opset of the
            # default domain instead of using the checker's default context.
            special_context = c_checker.CheckerContext()
            special_context.ir_version = IR_VERSION
            special_context.opset_imports = {"": kwargs["opset_version"]}
            onnx.checker.check_node(node, special_context)
        else:
            onnx.checker.check_node(node)
        return None

    @classmethod
    def supports_device(cls, device: str) -> bool:
        """
        Checks whether the backend is compiled with particular device support.
        In particular it's used in the testing suite.

        The base implementation always returns ``True``.
        """
        return True