# Source code for embedl_hub._internal.core.invoke.onnxruntime

# Copyright (C) 2026 Embedl AB

"""ONNX Runtime invoker component.

Supports model invocation via:
- **``qai_hub``** devices: Qualcomm AI Hub cloud inference.
- **``embedl_onnxruntime``** devices: Remote invocation using
  ``embedl-onnxruntime`` over SSH.
"""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from tempfile import TemporaryDirectory

import numpy as np
import qai_hub as hub

from embedl_hub._internal.core.compile.onnxruntime import (
    ONNXRuntimeCompiledModel,
)
from embedl_hub._internal.core.component.abc import Component, NoProviderError
from embedl_hub._internal.core.component.output import ComponentOutput
from embedl_hub._internal.core.component.provider_type import ProviderType
from embedl_hub._internal.core.context import HubContext
from embedl_hub._internal.core.device.abc import CommandRunner
from embedl_hub._internal.core.device.commands import run_remote_command
from embedl_hub._internal.core.device.config.embedl_onnxruntime import (
    EmbedlONNXRuntimeConfig,
)
from embedl_hub._internal.core.device.transfer import get_directory
from embedl_hub._internal.core.invoke import InvokeError
from embedl_hub._internal.core.types import LocalPath, RemotePath
from embedl_hub._internal.core.utils.onnx_utils import (
    maybe_package_onnx_folder_to_file,
)
from embedl_hub._internal.core.utils.qai_hub_utils import (
    log_runtime_info,
    resolve_qai_hub_device,
)
from embedl_hub._internal.tracking.rest_api import RunType
from embedl_hub._internal.tracking.run_log import LoggedArtifact


def _save_input_npz(
    input_data: dict[str, np.ndarray],
    local_dir: LocalPath,
) -> LocalPath:
    """Save input data as a compressed ``.npz`` file.

    :param input_data: Mapping of input names to NumPy arrays.
    :param local_dir: Local directory to write the file into.
    :return: Path to the saved ``.npz`` file.
    """
    npz_path = local_dir / "input.npz"
    np.savez_compressed(npz_path, **input_data)
    return npz_path


def _load_output_dir(output_dir: Path) -> dict[str, np.ndarray]:
    """Load inference output from a directory.

    Supports two formats produced by ``embedl-onnxruntime``:

    * **``.npz`` mode** (``--use-npz``): one or more compressed
      ``.npz`` files, each containing named arrays keyed by output
      tensor name.  All arrays from every ``.npz`` file are merged
      into a single dict.
    * **``.npy`` mode** (default): one ``.npy`` file per output,
      keyed by file stem.

    :param output_dir: Directory containing the output files.
    :return: Mapping of output names to NumPy arrays.
    :raises InvokeError: If no output files are found.
    """
    # Prefer .npz files (written when --use-npz is passed)
    npz_files = sorted(output_dir.glob("*.npz"))
    if npz_files:
        result: dict[str, np.ndarray] = {}
        for npz_file in npz_files:
            with np.load(npz_file) as data:
                result.update(data)
        return result

    # Fall back to individual .npy files
    npy_files = sorted(output_dir.glob("*.npy"))
    if npy_files:
        return {f.stem: np.load(f) for f in npy_files}

    raise InvokeError(
        f"No output files (.npz or .npy) found in '{output_dir}'."
    )


@dataclass(frozen=True)
class ONNXRuntimeInvocationResult(ComponentOutput):
    """Result produced by the :class:`ONNXRuntimeInvoker` component.

    :param output: Model outputs keyed by output tensor name, as
        NumPy arrays.
    :param output_file: Logged artifact holding the serialised
        ``.npz`` output file.
    """

    output: dict[str, np.ndarray]
    output_file: LoggedArtifact
class ONNXRuntimeInvoker(Component):
    """Component that runs ONNX-model inference via ``embedl-onnxruntime``.

    Executes ``embedl-onnxruntime run-inference`` on a remote device
    over SSH and exports the resulting output tensors.  Device-specific
    parameters (``embedl_onnxruntime_path``, ``cli_args``) come from
    :class:`~embedl_hub._internal.core.device.config.embedl_onnxruntime.EmbedlONNXRuntimeConfig`
    configured on the device.
    """

    run_type = RunType.INFERENCE

    def __init__(
        self,
        *,
        name: str | None = None,
        device: str | None = None,
        use_compiled_names: bool = False,
    ) -> None:
        """Create a new ONNX Runtime invoker.

        Values given here are the defaults for every call; each one
        can be overridden per call via :meth:`run`.

        :param name: Display name for this component instance.
        :param device: Name of the device to use from the context.
            ``None`` auto-selects the first compatible device.
        :param use_compiled_names: If ``False`` (default) and the
            compiler renamed input/output tensors, both original and
            compiled names are accepted for inputs and outputs are
            returned under original names.  ``True`` skips all
            remapping and accepts compiled names only.  *Only used
            on* ``qai_hub`` *devices.*
        """
        super().__init__(
            name=name,
            device=device,
            use_compiled_names=use_compiled_names,
        )

    def run(
        self,
        ctx: HubContext,
        model: ONNXRuntimeCompiledModel,
        input_data: dict[str, np.ndarray],
        *,
        device: str | None = None,
        use_compiled_names: bool = False,
    ) -> ONNXRuntimeInvocationResult:
        """Run inference on an ONNX model via ``embedl-onnxruntime``.

        Keyword arguments override the constructor defaults; omitted
        ones fall back to the values given at construction time.

        :param ctx: The execution context with device configuration.
        :param model: An :class:`ONNXRuntimeCompiledModel` whose
            ``path`` artifact points to an ONNX model.
        :param input_data: Mapping of input tensor names to NumPy
            arrays, uploaded as a compressed ``.npz`` file.
        :param use_compiled_names: If ``False`` (default) and the
            compiler renamed input/output tensors, both original and
            compiled names are accepted for inputs and outputs are
            returned under original names.  ``True`` skips all
            remapping and accepts compiled names only.  *Only used
            on* ``qai_hub`` *devices.*
        :return: An :class:`ONNXRuntimeInvocationResult` with the
            output artifact.
        """
        raise NoProviderError  # Replaced by the component system.
# --------------------------------------------------------------------- #
# embedl_onnxruntime provider
# --------------------------------------------------------------------- #


@ONNXRuntimeInvoker.provider(ProviderType.EMBEDL_ONNXRUNTIME)
def _invoke_ort_cli(
    ctx: HubContext,
    model: ONNXRuntimeCompiledModel,
    input_data: dict[str, np.ndarray],
    *,
    device: str | None = None,
    use_compiled_names: bool = False,
) -> ONNXRuntimeInvocationResult:
    """Run inference on an ONNX model via ``embedl-onnxruntime``.

    All keyword arguments arrive pre-resolved by the component system.

    :param ctx: The execution context with device configuration.
    :param model: An :class:`ONNXRuntimeCompiledModel`.
    :param input_data: Mapping of input names to NumPy arrays.
    :param device: Name of the target device.
    :param use_compiled_names: Accepted for API consistency but has
        no effect — this provider does not rename tensors.
    :return: An :class:`ONNXRuntimeInvocationResult`.
    """
    # Validate the device and its SSH command runner.
    assert device is not None
    target = ctx.devices[device]
    runner = target.runner
    if runner is None:
        raise ValueError(
            "embedl_onnxruntime devices require a device with a "
            "command runner."
        )
    remote_artifact_dir = target.artifact_dir
    if remote_artifact_dir is None:
        raise ValueError("No remote artifact directory available.")

    # Resolve the model path, transferring it to the device if needed.
    remote_model = model.path.to_remote(ctx, device_name=device).remote()

    # Provider configuration; fall back to defaults when absent.
    cfg = target.get_provider_config(EmbedlONNXRuntimeConfig, ONNXRuntimeInvoker)
    if cfg is None:
        cfg = EmbedlONNXRuntimeConfig()
    ort_path = str(cfg.embedl_onnxruntime_path)
    extra_args: list[str] = list(cfg.cli_args)

    # Record the model as this run's input artifact.
    ctx.client.log_artifact(
        model.path,
        name="input",
        file_name=f"input_{model.path.file_name}",
        ctx=ctx,
    )

    # Execute remotely and collect the outputs.
    output_data = _run_remote_invocation(
        ctx,
        runner=runner,
        remote_model=remote_model,
        input_data=input_data,
        remote_artifact_dir=remote_artifact_dir,
        ort_path=ort_path,
        cli_args=extra_args,
    )

    # Persist and log the serialised outputs.
    output_npz_path = ctx.artifact_dir / "output.npz"
    np.savez_compressed(output_npz_path, **output_data)
    ctx.client.log_artifact(output_npz_path, name="output_file")

    return ONNXRuntimeInvocationResult.from_current_run(
        ctx, output=output_data
    )


# --------------------------------------------------------------------- #
# QAI Hub provider
# --------------------------------------------------------------------- #


@ONNXRuntimeInvoker.provider(ProviderType.QAI_HUB)
def _invoke_ort_qai_hub(
    ctx: HubContext,
    model: ONNXRuntimeCompiledModel,
    input_data: dict[str, np.ndarray],
    *,
    device: str | None = None,
    use_compiled_names: bool = False,
) -> ONNXRuntimeInvocationResult:
    """Run inference on an ONNX model using Qualcomm AI Hub.

    Submits an inference job to QAI Hub and downloads the output
    data.  The ``cli_args`` parameter is ignored — QAI Hub controls
    the inference configuration.  All keyword arguments arrive
    pre-resolved by the component system.

    :param ctx: The execution context with device configuration.
    :param model: An :class:`ONNXRuntimeCompiledModel`.
    :param input_data: Mapping of input names to NumPy arrays.
    :param device: Name of the target device.
    :param use_compiled_names: If ``False`` (default), input names
        are remapped to the compiled model's names before submission
        and output names are mapped back to the originals.
    :return: An :class:`ONNXRuntimeInvocationResult`.
    """
    # Resolve the model path, transferring it locally if needed.
    model_local_path = model.path.to_local(ctx).local()
    input_name_mapping = model.input_name_mapping
    output_name_mapping = model.output_name_mapping

    # Validate the device against QAI Hub.
    hub_device, _ = resolve_qai_hub_device(ctx, device)

    # Translate input names to the compiled model's names if needed.
    if not use_compiled_names and input_name_mapping:
        input_data = _remap_input_names(input_data, input_name_mapping)

    # Record the model as this run's input artifact.
    ctx.client.log_artifact(
        model.path,
        name="input",
        file_name=f"input_{model.path.file_name}",
        ctx=ctx,
    )

    # Submit the job and collect the outputs.
    output_data = _invoke_via_qai_hub(
        ctx, model_local_path, hub_device, input_data
    )

    # Translate output names back to the originals if needed.
    if not use_compiled_names and output_name_mapping:
        reverse_output = {v: k for k, v in output_name_mapping.items()}
        output_data = {
            reverse_output.get(name, name): array
            for name, array in output_data.items()
        }

    # Persist and log the serialised outputs.
    output_npz_path = ctx.artifact_dir / "output.npz"
    np.savez_compressed(output_npz_path, **output_data)
    ctx.client.log_artifact(output_npz_path, name="output_file")

    return ONNXRuntimeInvocationResult.from_current_run(
        ctx, output=output_data
    )
------------------------------------------------------- def _remap_input_names( input_data: dict[str, np.ndarray], name_mapping: dict[str, str], ) -> dict[str, np.ndarray]: """Remap input tensor names from original to compiled names. :param input_data: Original input data keyed by pre-compile names. :param name_mapping: Mapping from original to compiled names. :return: New dict with keys replaced according to `name_mapping`. :raises InvokeError: If an input name is not found in the mapping and does not match any compiled name either. """ compiled_names = set(name_mapping.values()) remapped: dict[str, np.ndarray] = {} for name, array in input_data.items(): if name in name_mapping: remapped[name_mapping[name]] = array elif name in compiled_names: # Already using compiled names — pass through. remapped[name] = array else: raise InvokeError( f"Input name '{name}' not found in the name mapping " f"and is not a compiled model input name. Known " f"original names: {sorted(name_mapping.keys())}. " f"Compiled names: {sorted(compiled_names)}." ) return remapped def _invoke_via_qai_hub( ctx: HubContext, model_path: Path, hub_device: hub.Device, input_data: dict[str, np.ndarray], ) -> dict[str, np.ndarray]: """Submit an inference job to Qualcomm AI Hub and download results. :param ctx: The execution context. :param model_path: Local path to the ONNX model. :param hub_device: Target QAI Hub device. :param input_data: Mapping of input names to NumPy arrays. :return: Mapping of output names to NumPy arrays. :raises InvokeError: If the inference job fails. 
""" # QAI Hub expects inputs as dict[str, list[np.ndarray]] hub_inputs: dict[str, list[np.ndarray]] = { name: [array] for name, array in input_data.items() } # Package directory models into a single .onnx file if needed with TemporaryDirectory() as tmpdir: packaged_model = maybe_package_onnx_folder_to_file(model_path, tmpdir) # Submit inference job try: job: hub.InferenceJob = hub.submit_inference_job( model=packaged_model, device=hub_device, inputs=hub_inputs, ) except Exception as error: raise InvokeError( f"Failed to submit inference job to Qualcomm AI Hub: {error}" ) from error ctx.client.log_param("$qai_hub_job_id", job.job_id) ctx.client.log_param("qai_hub_inference_job_id", job.job_id) # Download inference output (blocks until the job finishes) try: output_dataset = job.download_output_data() except Exception as error: raise InvokeError( f"Failed to download inference output from Qualcomm AI Hub: {error}" ) from error if output_dataset is None: raise InvokeError( f"Inference job returned no output data. Job ID: {job.job_id}" ) # Parse and log runtime info (safe now — the job has finished) log_runtime_info(ctx.client, job.job_id, error_class=InvokeError) # download_output_data() returns Mapping[str, list[np.ndarray]]. # Flatten to dict[str, np.ndarray] by taking the first (and only) # sample from each output, matching the single-sample input we sent. 
output_data: dict[str, np.ndarray] = {} for name, arrays in output_dataset.items(): if isinstance(arrays, list) and len(arrays) > 0: output_data[name] = arrays[0] else: output_data[name] = np.asarray(arrays) return output_data # --------------------------------------------------------------------- # # internal helpers # --------------------------------------------------------------------- # def _run_remote_invocation( ctx: HubContext, *, runner: CommandRunner, remote_model: RemotePath, input_data: dict[str, np.ndarray], remote_artifact_dir: RemotePath, ort_path: str, cli_args: list[str], ) -> dict[str, np.ndarray]: """Run ``embedl-onnxruntime run-inference`` on a remote device and download + parse the output. :param ctx: The execution context. :param runner: The device command runner. :param remote_model: Remote path to the ONNX model on the device. :param input_data: Mapping of input names to NumPy arrays. :param remote_artifact_dir: Remote directory for artifacts. :param ort_path: Path to ``embedl-onnxruntime`` on the remote device. :param cli_args: Extra CLI arguments. :return: Mapping of output names to NumPy arrays. :raises InvokeError: If the remote command fails. 
""" local_artifact_dir = ctx.artifact_dir # Save and upload input data as .npz local_input_npz = _save_input_npz(input_data, local_artifact_dir) remote_input_dir = remote_artifact_dir / "input" runner.run(["mkdir", "-p", str(remote_input_dir)], hide=True) runner.put(LocalPath(local_input_npz), remote_input_dir / "data.npz") # Build remote output directory remote_output_dir = remote_artifact_dir / "output" runner.run(["mkdir", "-p", str(remote_output_dir)], hide=True) # Assemble the embedl-onnxruntime command run_args: list[str] = [ ort_path, "run-inference", "--model", str(remote_model), "--input-dir", str(remote_input_dir), "--output-dir", str(remote_output_dir), "--use-npz", ] # Append extra CLI arguments run_args.extend(cli_args) # Log parameters ctx.client.log_param("$runtime", "onnx") ctx.client.log_param("embedl_onnxruntime_path", ort_path) if cli_args: ctx.client.log_param("cli_args", " ".join(cli_args)) # Execute on the remote device result = run_remote_command( runner, run_args, local_artifact_dir, error_cls=InvokeError, error_message="ONNX Runtime invocation failed on the remote device.", ) # Download the output directory from the remote device local_output_dir = local_artifact_dir / "output" get_directory(runner, remote_output_dir, local_output_dir) # Parse and return the output data return _load_output_dir(local_output_dir)