# Copyright (C) 2026 Embedl AB
"""ONNX Runtime invoker component.
Supports model invocation via:
- **``qai_hub``** devices: Qualcomm AI Hub cloud inference.
- **``embedl_onnxruntime``** devices: Remote invocation using
``embedl-onnxruntime`` over SSH.
"""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from tempfile import TemporaryDirectory
import numpy as np
import qai_hub as hub
from embedl_hub._internal.core.compile.onnxruntime import (
ONNXRuntimeCompiledModel,
)
from embedl_hub._internal.core.component.abc import Component, NoProviderError
from embedl_hub._internal.core.component.output import ComponentOutput
from embedl_hub._internal.core.component.provider_type import ProviderType
from embedl_hub._internal.core.context import HubContext
from embedl_hub._internal.core.device.abc import CommandRunner
from embedl_hub._internal.core.device.commands import run_remote_command
from embedl_hub._internal.core.device.config.embedl_onnxruntime import (
EmbedlONNXRuntimeConfig,
)
from embedl_hub._internal.core.device.transfer import get_directory
from embedl_hub._internal.core.invoke import InvokeError
from embedl_hub._internal.core.types import LocalPath, RemotePath
from embedl_hub._internal.core.utils.onnx_utils import (
maybe_package_onnx_folder_to_file,
)
from embedl_hub._internal.core.utils.qai_hub_utils import (
log_runtime_info,
resolve_qai_hub_device,
)
from embedl_hub._internal.tracking.rest_api import RunType
from embedl_hub._internal.tracking.run_log import LoggedArtifact
def _save_input_npz(
input_data: dict[str, np.ndarray],
local_dir: LocalPath,
) -> LocalPath:
"""Save input data as a compressed ``.npz`` file.
:param input_data: Mapping of input names to NumPy arrays.
:param local_dir: Local directory to write the file into.
:return: Path to the saved ``.npz`` file.
"""
npz_path = local_dir / "input.npz"
np.savez_compressed(npz_path, **input_data)
return npz_path
def _load_output_dir(output_dir: Path) -> dict[str, np.ndarray]:
"""Load inference output from a directory.
Supports two formats produced by ``embedl-onnxruntime``:
* **``.npz`` mode** (``--use-npz``): one or more compressed
``.npz`` files, each containing named arrays keyed by output
tensor name. All arrays from every ``.npz`` file are merged
into a single dict.
* **``.npy`` mode** (default): one ``.npy`` file per output,
keyed by file stem.
:param output_dir: Directory containing the output files.
:return: Mapping of output names to NumPy arrays.
:raises InvokeError: If no output files are found.
"""
# Prefer .npz files (written when --use-npz is passed)
npz_files = sorted(output_dir.glob("*.npz"))
if npz_files:
result: dict[str, np.ndarray] = {}
for npz_file in npz_files:
with np.load(npz_file) as data:
result.update(data)
return result
# Fall back to individual .npy files
npy_files = sorted(output_dir.glob("*.npy"))
if npy_files:
return {f.stem: np.load(f) for f in npy_files}
raise InvokeError(
f"No output files (.npz or .npy) found in '{output_dir}'."
)
@dataclass(frozen=True)
class ONNXRuntimeInvocationResult(ComponentOutput):
    """Output from the ONNXRuntimeInvoker component.

    :param output: The output data from the model invocation, mapping
        output tensor names to NumPy arrays.
    :param output_file: The artifact containing the serialised output
        ``.npz`` file.
    """

    # Mapping of output tensor names to NumPy arrays.
    output: dict[str, np.ndarray]
    # Logged artifact wrapping the serialised output .npz file.
    output_file: LoggedArtifact
class ONNXRuntimeInvoker(Component):
    """Component that runs inference on ONNX models using ``embedl-onnxruntime``.

    Runs ``embedl-onnxruntime run-inference`` on a remote device over SSH
    to execute inference on an ONNX model and exports the output tensors.

    Device-specific parameters (``embedl_onnxruntime_path``,
    ``cli_args``) are configured via
    :class:`~embedl_hub._internal.core.device.config.embedl_onnxruntime.EmbedlONNXRuntimeConfig`
    on the device.
    """

    # Runs of this component are tracked as inference runs.
    run_type = RunType.INFERENCE

    def __init__(
        self,
        *,
        name: str | None = None,
        device: str | None = None,
        use_compiled_names: bool = False,
    ) -> None:
        """Create a new ONNX Runtime invoker.

        Parameters passed here become the defaults for every call.
        They can be overridden per call by passing them to
        :meth:`run` directly.

        :param name: Display name for this component instance.
        :param device: Name of the device to use from the
            context. When ``None``, auto-selects the first
            compatible device.
        :param use_compiled_names: If ``False`` (default) and the
            compiler renamed input/output tensors, both original
            and compiled names are accepted for inputs, and outputs
            are returned with original names. Set to ``True`` to
            skip all remapping and only accept compiled names.
            *Only used on* ``qai_hub`` *devices.*
        """
        super().__init__(
            name=name,
            device=device,
            use_compiled_names=use_compiled_names,
        )

    def run(
        self,
        ctx: HubContext,
        model: ONNXRuntimeCompiledModel,
        input_data: dict[str, np.ndarray],
        *,
        device: str | None = None,
        use_compiled_names: bool = False,
    ) -> ONNXRuntimeInvocationResult:
        """Run inference on an ONNX model via ``embedl-onnxruntime``.

        Keyword arguments override the defaults set in the constructor.
        If a keyword argument is not provided here, the value from the
        constructor is used.

        :param ctx: The execution context with device configuration.
        :param model: An :class:`ONNXRuntimeCompiledModel` whose ``path``
            artifact points to an ONNX model.
        :param input_data: Dictionary mapping input tensor names to
            NumPy arrays. Uploaded as a compressed ``.npz`` file.
        :param use_compiled_names: If ``False`` (default) and the
            compiler renamed input/output tensors, both original
            and compiled names are accepted for inputs, and outputs
            are returned with original names. Set to ``True`` to
            skip all remapping and only accept compiled names.
            *Only used on* ``qai_hub`` *devices.*
        :return: An :class:`ONNXRuntimeInvocationResult` with the output
            artifact.
        """
        raise NoProviderError  # Replaced by the component system.
# --------------------------------------------------------------------- #
# embedl_onnxruntime provider
# --------------------------------------------------------------------- #


@ONNXRuntimeInvoker.provider(ProviderType.EMBEDL_ONNXRUNTIME)
def _invoke_ort_cli(
    ctx: HubContext,
    model: ONNXRuntimeCompiledModel,
    input_data: dict[str, np.ndarray],
    *,
    device: str | None = None,
    use_compiled_names: bool = False,
) -> ONNXRuntimeInvocationResult:
    """Run inference on an ONNX model via ``embedl-onnxruntime``.

    All keyword arguments are pre-resolved by the component system.

    :param ctx: The execution context with device configuration.
    :param model: An :class:`ONNXRuntimeCompiledModel`.
    :param input_data: Mapping of input names to NumPy arrays.
    :param device: Name of the target device.
    :param use_compiled_names: Accepted for API consistency but has
        no effect — the ``embedl-onnxruntime`` provider does not
        rename tensors.
    :return: An :class:`ONNXRuntimeInvocationResult`.
    :raises ValueError: If no device was resolved, the device has no
        command runner, or no remote artifact directory is available.
    """
    # -- validate device ------------------------------------------------
    # Use an explicit raise instead of ``assert`` so the check survives
    # ``python -O`` and matches the sibling validations below.
    if device is None:
        raise ValueError("No device resolved for embedl_onnxruntime invocation.")
    dev = ctx.devices[device]
    runner = dev.runner
    if runner is None:
        raise ValueError(
            "embedl_onnxruntime devices require a device with a "
            "command runner."
        )
    remote_artifact_dir = dev.artifact_dir
    if remote_artifact_dir is None:
        raise ValueError("No remote artifact directory available.")

    # -- resolve model path (auto-transfer) ------------------------------
    remote_model = model.path.to_remote(ctx, device_name=device).remote()

    # -- resolve provider config -----------------------------------------
    cfg = dev.get_provider_config(EmbedlONNXRuntimeConfig, ONNXRuntimeInvoker)
    if cfg is None:
        cfg = EmbedlONNXRuntimeConfig()
    ort_path = str(cfg.embedl_onnxruntime_path)
    effective_cli_args: list[str] = list(cfg.cli_args)

    # -- log input artifact -----------------------------------------------
    ctx.client.log_artifact(
        model.path,
        name="input",
        file_name=f"input_{model.path.file_name}",
        ctx=ctx,
    )

    # -- invoke -----------------------------------------------------------
    output_data = _run_remote_invocation(
        ctx,
        runner=runner,
        remote_model=remote_model,
        input_data=input_data,
        remote_artifact_dir=remote_artifact_dir,
        ort_path=ort_path,
        cli_args=effective_cli_args,
    )

    # -- save and log output artifact ---------------------------------------
    output_npz_path = ctx.artifact_dir / "output.npz"
    np.savez_compressed(output_npz_path, **output_data)
    ctx.client.log_artifact(output_npz_path, name="output_file")

    return ONNXRuntimeInvocationResult.from_current_run(
        ctx, output=output_data
    )
# --------------------------------------------------------------------- #
# QAI Hub provider
# --------------------------------------------------------------------- #


@ONNXRuntimeInvoker.provider(ProviderType.QAI_HUB)
def _invoke_ort_qai_hub(
    ctx: HubContext,
    model: ONNXRuntimeCompiledModel,
    input_data: dict[str, np.ndarray],
    *,
    device: str | None = None,
    use_compiled_names: bool = False,
) -> ONNXRuntimeInvocationResult:
    """Run inference on an ONNX model using Qualcomm AI Hub.

    Submits an inference job to QAI Hub and downloads the output data.
    The ``cli_args`` parameter is ignored — QAI Hub controls the
    inference configuration.

    All keyword arguments are pre-resolved by the component system.

    :param ctx: The execution context with device configuration.
    :param model: An :class:`ONNXRuntimeCompiledModel`.
    :param input_data: Mapping of input names to NumPy arrays.
    :param device: Name of the target device.
    :param use_compiled_names: If ``False`` (default), input names are
        remapped to the compiled model's names before submission and
        output names are mapped back to the originals.
    :return: An :class:`ONNXRuntimeInvocationResult`.
    """
    # -- resolve model path (auto-transfer to local) ----------------------
    model_local_path = model.path.to_local(ctx).local()
    input_name_mapping = model.input_name_mapping
    output_name_mapping = model.output_name_mapping

    # -- validate device ----------------------------------------------------
    # Only the hub.Device is needed here; the resolved name is unused.
    hub_device, _ = resolve_qai_hub_device(ctx, device)

    # -- remap input names if needed ------------------------------------------
    if not use_compiled_names and input_name_mapping:
        input_data = _remap_input_names(input_data, input_name_mapping)

    # -- log input artifact ------------------------------------------------
    ctx.client.log_artifact(
        model.path,
        name="input",
        file_name=f"input_{model.path.file_name}",
        ctx=ctx,
    )

    # -- invoke via QAI Hub --------------------------------------------------
    output_data = _invoke_via_qai_hub(
        ctx, model_local_path, hub_device, input_data
    )

    # -- remap output names back to originals -----------------------------------
    if not use_compiled_names and output_name_mapping:
        reverse_output = {v: k for k, v in output_name_mapping.items()}
        output_data = {
            reverse_output.get(name, name): array
            for name, array in output_data.items()
        }

    # -- save and log output artifact ---------------------------------------------
    output_npz_path = ctx.artifact_dir / "output.npz"
    np.savez_compressed(output_npz_path, **output_data)
    ctx.client.log_artifact(output_npz_path, name="output_file")

    return ONNXRuntimeInvocationResult.from_current_run(
        ctx, output=output_data
    )
# -- QAI Hub helpers -------------------------------------------------------
def _remap_input_names(
input_data: dict[str, np.ndarray],
name_mapping: dict[str, str],
) -> dict[str, np.ndarray]:
"""Remap input tensor names from original to compiled names.
:param input_data: Original input data keyed by pre-compile names.
:param name_mapping: Mapping from original to compiled names.
:return: New dict with keys replaced according to `name_mapping`.
:raises InvokeError: If an input name is not found in the mapping
and does not match any compiled name either.
"""
compiled_names = set(name_mapping.values())
remapped: dict[str, np.ndarray] = {}
for name, array in input_data.items():
if name in name_mapping:
remapped[name_mapping[name]] = array
elif name in compiled_names:
# Already using compiled names — pass through.
remapped[name] = array
else:
raise InvokeError(
f"Input name '{name}' not found in the name mapping "
f"and is not a compiled model input name. Known "
f"original names: {sorted(name_mapping.keys())}. "
f"Compiled names: {sorted(compiled_names)}."
)
return remapped
def _invoke_via_qai_hub(
    ctx: HubContext,
    model_path: Path,
    hub_device: hub.Device,
    input_data: dict[str, np.ndarray],
) -> dict[str, np.ndarray]:
    """Submit an inference job to Qualcomm AI Hub and download results.

    :param ctx: The execution context.
    :param model_path: Local path to the ONNX model.
    :param hub_device: Target QAI Hub device.
    :param input_data: Mapping of input names to NumPy arrays.
    :return: Mapping of output names to NumPy arrays.
    :raises InvokeError: If the inference job fails.
    """
    # QAI Hub expects inputs as dict[str, list[np.ndarray]]
    hub_inputs: dict[str, list[np.ndarray]] = {
        key: [value] for key, value in input_data.items()
    }

    # Package directory models into a single .onnx file if needed
    with TemporaryDirectory() as tmpdir:
        packaged_model = maybe_package_onnx_folder_to_file(model_path, tmpdir)

        # Submit inference job
        try:
            job: hub.InferenceJob = hub.submit_inference_job(
                model=packaged_model,
                device=hub_device,
                inputs=hub_inputs,
            )
        except Exception as error:
            raise InvokeError(
                f"Failed to submit inference job to Qualcomm AI Hub: {error}"
            ) from error

        ctx.client.log_param("$qai_hub_job_id", job.job_id)
        ctx.client.log_param("qai_hub_inference_job_id", job.job_id)

        # Download inference output (blocks until the job finishes)
        try:
            output_dataset = job.download_output_data()
        except Exception as error:
            raise InvokeError(
                f"Failed to download inference output from Qualcomm AI Hub: {error}"
            ) from error

        if output_dataset is None:
            raise InvokeError(
                f"Inference job returned no output data. Job ID: {job.job_id}"
            )

        # Parse and log runtime info (safe now — the job has finished)
        log_runtime_info(ctx.client, job.job_id, error_class=InvokeError)

        # download_output_data() returns Mapping[str, list[np.ndarray]].
        # Flatten to dict[str, np.ndarray] by taking the first (and only)
        # sample from each output, matching the single-sample input we sent.
        return {
            key: value[0]
            if isinstance(value, list) and len(value) > 0
            else np.asarray(value)
            for key, value in output_dataset.items()
        }
# --------------------------------------------------------------------- #
# internal helpers
# --------------------------------------------------------------------- #


def _run_remote_invocation(
    ctx: HubContext,
    *,
    runner: CommandRunner,
    remote_model: RemotePath,
    input_data: dict[str, np.ndarray],
    remote_artifact_dir: RemotePath,
    ort_path: str,
    cli_args: list[str],
) -> dict[str, np.ndarray]:
    """Run ``embedl-onnxruntime run-inference`` on a remote device and
    download + parse the output.

    :param ctx: The execution context.
    :param runner: The device command runner.
    :param remote_model: Remote path to the ONNX model on the device.
    :param input_data: Mapping of input names to NumPy arrays.
    :param remote_artifact_dir: Remote directory for artifacts.
    :param ort_path: Path to ``embedl-onnxruntime`` on the remote device.
    :param cli_args: Extra CLI arguments.
    :return: Mapping of output names to NumPy arrays.
    :raises InvokeError: If the remote command fails.
    """
    local_artifact_dir = ctx.artifact_dir

    # Save and upload input data as .npz
    local_input_npz = _save_input_npz(input_data, local_artifact_dir)
    remote_input_dir = remote_artifact_dir / "input"
    runner.run(["mkdir", "-p", str(remote_input_dir)], hide=True)
    runner.put(LocalPath(local_input_npz), remote_input_dir / "data.npz")

    # Build remote output directory
    remote_output_dir = remote_artifact_dir / "output"
    runner.run(["mkdir", "-p", str(remote_output_dir)], hide=True)

    # Assemble the embedl-onnxruntime command
    run_args: list[str] = [
        ort_path,
        "run-inference",
        "--model",
        str(remote_model),
        "--input-dir",
        str(remote_input_dir),
        "--output-dir",
        str(remote_output_dir),
        "--use-npz",
    ]
    # Append extra CLI arguments
    run_args.extend(cli_args)

    # Log parameters
    ctx.client.log_param("$runtime", "onnx")
    ctx.client.log_param("embedl_onnxruntime_path", ort_path)
    if cli_args:
        ctx.client.log_param("cli_args", " ".join(cli_args))

    # Execute on the remote device. The return value was previously bound
    # to an unused local; run_remote_command raises InvokeError on failure,
    # which is the only outcome we need here.
    run_remote_command(
        runner,
        run_args,
        local_artifact_dir,
        error_cls=InvokeError,
        error_message="ONNX Runtime invocation failed on the remote device.",
    )

    # Download the output directory from the remote device
    local_output_dir = local_artifact_dir / "output"
    get_directory(runner, remote_output_dir, local_output_dir)

    # Parse and return the output data
    return _load_output_dir(local_output_dir)