# Source code for onnx

# SPDX-License-Identifier: Apache-2.0

# isort:skip_file
import os

from .onnx_cpp2py_export import ONNX_ML  # noqa
from onnx.external_data_helper import (
    load_external_data_for_model,
    write_external_data_tensors,
    convert_model_to_external_data,
)
from .onnx_pb import *  # noqa
from .onnx_operators_pb import *  # noqa
from .onnx_data_pb import *  # noqa
from .version import version as __version__  # noqa

# Import common subpackages so they're available when you 'import onnx'
import onnx.checker  # noqa
import onnx.defs  # noqa
import onnx.helper  # noqa
import onnx.utils  # noqa
import onnx.compose  # noqa

import google.protobuf.message

from typing import Union, IO, Optional, cast, TypeVar, Any


def _load_bytes(f: Union[IO[bytes], str]) -> bytes:
    """Return the full contents of ``f`` as bytes.

    Arguments:
        f: either an object exposing a callable ``read`` method (whose
            ``read()`` result is returned as-is), or a file path that is
            opened in binary mode and read to the end.

    Returns:
        The bytes read from ``f``.
    """
    read_method = getattr(f, "read", None)
    if callable(read_method):
        return read_method()
    # Not readable itself, so treat it as a path on disk.
    with open(cast(str, f), "rb") as handle:
        return handle.read()


def _save_bytes(content: bytes, f: Union[IO[bytes], str]) -> None:
    """Write ``content`` to ``f``.

    Arguments:
        content: the serialized bytes to write.
        f: either an object exposing a callable ``write`` method (written to
            directly and not closed here), or a file path that is opened in
            binary write mode ("wb"), written, and closed.
    """
    write_method = getattr(f, "write", None)
    if not callable(write_method):
        # No usable write method: treat f as a path on disk.
        with open(cast(str, f), "wb") as handle:
            handle.write(content)
        return
    write_method(content)


def _get_file_path(f: Union[IO[bytes], str, os.PathLike]) -> Optional[str]:
    """Return the absolute path that ``f`` refers to, or None if unknown.

    Arguments:
        f: a path (``str``, ``bytes`` or ``os.PathLike`` such as
            ``pathlib.Path``) or a file-like object. For file-like objects
            the ``name`` attribute is used, but only when it is itself a path;
            file objects without a name (e.g. ``io.BytesIO``) or whose name is
            not a path (e.g. an integer file descriptor) yield None.

    Returns:
        The absolute path as ``str``, or None when no path can be determined.
    """
    path_types = (str, bytes, os.PathLike)
    if isinstance(f, path_types):
        # Checked before the ``name`` lookup: ``pathlib.Path.name`` is only
        # the final component, which would resolve against the wrong directory.
        return os.path.abspath(os.fsdecode(f))
    name = getattr(f, "name", None)
    if isinstance(name, path_types):
        return os.path.abspath(os.fsdecode(name))
    # e.g. fd-backed files expose an int ``name``; there is no path to report.
    return None


def _serialize(proto: Union[bytes, google.protobuf.message.Message]) -> bytes:
    """
    Serialize an in-memory proto to bytes.

    Arguments:
        proto: an in-memory proto, such as a ModelProto, TensorProto, etc.,
            or bytes, which are returned unchanged

    Returns:
        Serialized proto in bytes

    Raises:
        ValueError: if ``SerializeToString`` raises ValueError. When the proto
            reports a ``ByteSize()`` of at least
            ``onnx.checker.MAXIMUM_PROTOBUF``, the error is replaced by one
            explaining the 2 GB limit; otherwise the original is re-raised.
        TypeError: if ``proto`` is neither bytes nor has a callable
            ``SerializeToString`` method.
    """
    if isinstance(proto, bytes):
        return proto
    if not (hasattr(proto, "SerializeToString") and callable(proto.SerializeToString)):
        raise TypeError(
            f"No SerializeToString method is detected. Neither proto is bytes.\ntype is {type(proto)}"
        )
    try:
        return proto.SerializeToString()
    except ValueError as e:
        # Only consult ByteSize when the object provides it, so an object
        # without it re-raises its own ValueError instead of an AttributeError
        # raised from inside this handler.
        byte_size = getattr(proto, "ByteSize", None)
        if callable(byte_size) and byte_size() >= onnx.checker.MAXIMUM_PROTOBUF:
            raise ValueError(
                "The proto size is larger than the 2 GB limit. "
                "Please use save_as_external_data to save tensors separately from the model file."
            ) from e
        raise


# Type variable for the proto object passed to _deserialize. It is bound to
# google.protobuf.message.Message so that _deserialize's return type matches
# the type of the proto instance it was given.
_Proto = TypeVar("_Proto", bound=google.protobuf.message.Message)


def _deserialize(s: bytes, proto: _Proto) -> _Proto:
    """
    Parse the serialized bytes ``s`` into the given proto object.

    Arguments:
        s: bytes containing a serialized proto
        proto: an in-memory proto object whose ``ParseFromString`` is called

    Returns:
        The same ``proto`` instance that was passed in, filled in by ``s``

    Raises:
        ValueError: if ``s`` is not bytes, or ``proto`` has no callable
            ``ParseFromString`` method.
        google.protobuf.message.DecodeError: if ``ParseFromString`` returns a
            byte count that differs from ``len(s)``.
    """
    if not isinstance(s, bytes):
        raise ValueError(f"Parameter s must be bytes, but got type: {type(s)}")

    parse_from_string = getattr(proto, "ParseFromString", None)
    if not callable(parse_from_string):
        raise ValueError(
            f"No ParseFromString method is detected. Type is {type(proto)}"
        )

    # The return value may be None; the length check only applies when a
    # byte count is actually reported.
    consumed = cast(Optional[int], parse_from_string(s))
    if consumed is None or consumed == len(s):
        return proto
    raise google.protobuf.message.DecodeError(
        f"Protobuf decoding consumed too few bytes: {consumed} out of {len(s)}"
    )


def load_model(
    f: Union[IO[bytes], str],
    format: Optional[Any] = None,
    load_external_data: bool = True,
) -> ModelProto:
    """
    Loads a serialized ModelProto into memory.

    Arguments:
        f: can be a file-like object (has "read" function) or a string containing a file name
        format: for future use
        load_external_data: if True and a file path can be determined for ``f``,
            external data is loaded via load_external_data_for_model using the
            directory that contains the model file. If False, or if no path is
            known for ``f``, external data is not loaded; call
            load_external_data_for_model with the data directory yourself.

    Returns:
        Loaded in-memory ModelProto
    """
    model = load_model_from_string(_load_bytes(f), format=format)
    if not load_external_data:
        return model

    model_filepath = _get_file_path(f)
    if model_filepath:
        # The model file's own directory is used as the base directory.
        load_external_data_for_model(model, os.path.dirname(model_filepath))
    return model


def load_tensor(f: Union[IO[bytes], str], format: Optional[Any] = None) -> TensorProto:
    """
    Loads a serialized TensorProto into memory.

    Arguments:
        f: can be a file-like object (has "read" function) or a string containing a file name
        format: for future use; forwarded to load_tensor_from_string

    Returns:
        Loaded in-memory TensorProto
    """
    return load_tensor_from_string(_load_bytes(f), format=format)


def load_model_from_string(s: bytes, format: Optional[Any] = None) -> ModelProto:
    """
    Loads a binary string (bytes) that contains a serialized ModelProto.

    Arguments:
        s: bytes containing a serialized ModelProto
        format: for future use

    Returns:
        Loaded in-memory ModelProto

    Raises:
        ValueError: if ``s`` is not bytes (raised by _deserialize).
    """
    return _deserialize(s, ModelProto())
def load_tensor_from_string(s: bytes, format: Optional[Any] = None) -> TensorProto:
    """
    Loads a binary string (bytes) that contains a serialized TensorProto.

    Arguments:
        s: bytes containing a serialized TensorProto
        format: for future use

    Returns:
        Loaded in-memory TensorProto

    Raises:
        ValueError: if ``s`` is not bytes (raised by _deserialize).
    """
    return _deserialize(s, TensorProto())
def save_model(
    proto: Union[ModelProto, bytes],
    f: Union[IO[bytes], str],
    format: Optional[Any] = None,
    save_as_external_data: bool = False,
    all_tensors_to_one_file: bool = True,
    location: Optional[str] = None,
    size_threshold: int = 1024,
    convert_attribute: bool = False,
) -> None:
    """
    Saves the ModelProto to the specified path and, optionally, serializes
    tensors with raw data as external data before saving.

    Arguments:
        proto: should be an in-memory ModelProto; serialized bytes are parsed
            into a ModelProto first
        f: can be a file-like object (has "write" function) or a string containing a file name
        format: for future use
        save_as_external_data: If true, convert tensors to external data
            (controlled by the arguments below) before saving. This requires a
            destination whose path can be determined: a file name, or a file
            object with a path-valued ``name`` attribute.
        all_tensors_to_one_file: If true, save all tensors to one external file specified by location.
            If false, save each tensor to a file named with the tensor name.
        location: specify the external file that all tensors are saved to.
            If not specified, will use the model name.
        size_threshold: Threshold for size of data. Only when tensor's data is >= the size_threshold
            it will be converted to external data. To convert every tensor with raw data to external data set size_threshold=0.
        convert_attribute: If true, convert all tensors to external data.
            If false, convert only non-attribute tensors to external data.

    Raises:
        ValueError: if save_as_external_data is true but no file path can be
            determined for ``f``. This is checked before ``proto`` is modified.
    """
    if isinstance(proto, bytes):
        proto = _deserialize(proto, ModelProto())

    model_filepath = _get_file_path(f)
    if save_as_external_data:
        # Without a path there is no directory in which to write the external
        # tensors. Converting anyway would emit a model that references files
        # that were never written, so fail before mutating the proto.
        if not model_filepath:
            raise ValueError(
                "save_as_external_data=True requires f to be a file path or a file "
                "object with a 'name' attribute, so external data can be written "
                "next to the model."
            )
        convert_model_to_external_data(
            proto, all_tensors_to_one_file, location, size_threshold, convert_attribute
        )

    if model_filepath:
        basepath = os.path.dirname(model_filepath)
        proto = write_external_data_tensors(proto, basepath)

    s = _serialize(proto)
    _save_bytes(s, f)


def save_tensor(proto: TensorProto, f: Union[IO[bytes], str]) -> None:
    """
    Saves the TensorProto to the specified path.

    Arguments:
        proto: should be an in-memory TensorProto
        f: can be a file-like object (has "write" function) or a string containing a file name
    """
    s = _serialize(proto)
    _save_bytes(s, f)


# For backward compatibility
load = load_model
load_from_string = load_model_from_string
save = save_model