# Copyright (C) 2026 Embedl AB
"""ONNX Runtime compiler component.
Supports compilation via:
- **``qai_hub``** devices: Qualcomm AI Hub cloud compilation.
- **``embedl_onnxruntime``** devices: Remote compilation using
``embedl-onnxruntime`` over SSH.
"""
from __future__ import annotations
import re
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Literal
import numpy as np
import qai_hub as hub
from qai_hub.client import CompileJob, QuantizeJob
from embedl_hub._internal.core.compile.result import CompileError
from embedl_hub._internal.core.component.abc import Component, NoProviderError
from embedl_hub._internal.core.component.output import CompiledModel
from embedl_hub._internal.core.component.provider_type import ProviderType
from embedl_hub._internal.core.context import HubContext
from embedl_hub._internal.core.device.abc import CommandResult, CommandRunner
from embedl_hub._internal.core.device.commands import run_remote_command
from embedl_hub._internal.core.device.config.embedl_onnxruntime import (
EmbedlONNXRuntimeConfig,
)
from embedl_hub._internal.core.device.transfer import put_directory
from embedl_hub._internal.core.types import LocalPath, RemotePath
from embedl_hub._internal.core.utils.calibration_data import (
generate_random_calibration_data,
load_onnx_calibration_data,
)
from embedl_hub._internal.core.utils.onnx_utils import load_onnx_model
from embedl_hub._internal.core.utils.qai_hub_utils import (
log_runtime_info,
resolve_qai_hub_device,
save_qai_hub_model,
)
from embedl_hub._internal.tracking.rest_api import RunType
CalibrationMethod = Literal["minmax", "entropy"]
"""Supported calibration algorithms for static quantization."""
CalibrationData = Path | dict[str, np.ndarray]
"""Calibration data for post-training quantization.
Can be supplied in two forms:
* **Path** – a directory of ``.npy`` files. For single-input models
the files sit directly in the directory. For multi-input models
the directory must contain one subdirectory per model input, named
after that input, each holding ``.npy`` files.
* **dict** – a mapping whose keys are the model input names and whose
values are NumPy arrays with shape
``(num_samples, *input_shape)``. The leading axis indexes the
calibration samples; each ``array[i]`` must already have the shape
the model expects for that input (including any batch dimension).
"""
def _save_calibration_dict(
data: dict[str, np.ndarray],
target_dir: Path,
) -> Path:
"""Persist a calibration-data dictionary as ``.npy`` files on disk.
Each value array has shape ``(num_samples, *input_shape)``. The
leading axis indexes the calibration samples; each ``array[i]``
is saved as its own ``.npy`` file.
The resulting directory layout is compatible with
:func:`~embedl_hub._internal.core.utils.calibration_data.load_calibration_dataset`:
* **Single input** (one key) – the individual sample files
(``0.npy``, ``1.npy``, …) are saved directly inside
*target_dir*.
* **Multiple inputs** – a subdirectory is created for each key
(i.e. each model input name), and the sample files are written
inside it.
The dictionary keys **must** match the input names of the ONNX
model so that the calibration samples are mapped to the correct
model inputs when loaded back.
:param data: Mapping of model input names to calibration arrays.
Each array must have shape ``(num_samples, *input_shape)``.
:param target_dir: Directory in which to write the files. Will be
created if it does not exist.
:return: *target_dir* (for convenient chaining).
"""
target_dir.mkdir(parents=True, exist_ok=True)
if len(data) == 1:
array = next(iter(data.values()))
for idx, sample in enumerate(array):
np.save(target_dir / f"{idx}.npy", sample)
else:
for name, array in data.items():
input_dir = target_dir / name
input_dir.mkdir(parents=True, exist_ok=True)
for idx, sample in enumerate(array):
np.save(input_dir / f"{idx}.npy", sample)
return target_dir
@dataclass(frozen=True)
class ONNXRuntimeCompiledModel(CompiledModel):
    """Output from the :class:`ONNXRuntimeCompiler` component.

    Extends :class:`~embedl_hub._internal.core.component.output.CompiledModel`
    with optional input/output tensor name mappings that record how
    the compiler renamed tensors.
    """

    # Maps original input tensor names to their post-compile names;
    # ``None`` when no input was renamed.
    input_name_mapping: dict[str, str] | None = None
    # Maps original output tensor names to their post-compile names;
    # ``None`` when no output was renamed.
    output_name_mapping: dict[str, str] | None = None
class ONNXRuntimeCompiler(Component):
    """Component that compiles ONNX models to ONNX Runtime format.

    Supports two device types:

    - **``qai_hub``** devices: compile via the Qualcomm AI Hub cloud
      service.
    - **``embedl_onnxruntime``** devices: compile via the
      ``embedl-onnxruntime`` CLI on a remote device over SSH.
    """

    # Classifies runs created by this component in the tracking backend.
    run_type = RunType.COMPILE

    def __init__(
        self,
        *,
        name: str | None = None,
        device: str | None = None,
        input_shape: tuple[int, ...] | None = None,
        calibration_data: CalibrationData | None = None,
        calibration_method: CalibrationMethod | None = None,
        per_channel: bool = False,
        quantize_io: bool = False,
    ) -> None:
        """Create a new ONNX Runtime compiler.

        Parameters passed here become the defaults for every
        :meth:`run` call. They can be overridden per call by passing
        them to :meth:`run` directly::

            compiler = ONNXRuntimeCompiler(quantize_io=True)
            compiler.run(ctx, model)  # uses quantize_io=True
            compiler.run(ctx, model, quantize_io=False)  # overrides to False

        :param name: Display name for this compiler instance. Defaults
            to the class name.
        :param device: Name of the device to use from the context.
            When ``None``, auto-selects the first compatible device.
        :param input_shape: Input shape tuple, e.g.
            ``(1, 3, 224, 224)``. *Only used on* ``qai_hub``
            *devices.*
        :param calibration_data: Calibration data for post-training
            quantization — a path to a directory of ``.npy`` files
            or a ``dict`` mapping model input names to NumPy arrays.
        :param calibration_method: Calibration algorithm (``'minmax'``
            or ``'entropy'``). *Only used on*
            ``embedl_onnxruntime`` *devices.*
        :param per_channel: If ``True``, quantize weights per channel.
            *Only used on* ``embedl_onnxruntime`` *devices.*
        :param quantize_io: If ``True``, quantize model I/O tensors
            to INT8. *Only used on* ``qai_hub`` *devices.*
        """
        # All defaults are forwarded to the Component base class, which
        # stores them and resolves per-call overrides for run().
        super().__init__(
            name=name,
            device=device,
            input_shape=input_shape,
            calibration_data=calibration_data,
            calibration_method=calibration_method,
            per_channel=per_channel,
            quantize_io=quantize_io,
        )

    def run(
        self,
        ctx: HubContext,
        onnx_path: Path,
        *,
        device: str | None = None,
        input_shape: tuple[int, ...] | None = None,
        calibration_data: CalibrationData | None = None,
        calibration_method: CalibrationMethod | None = None,
        per_channel: bool = False,
        quantize_io: bool = False,
    ) -> ONNXRuntimeCompiledModel:
        """Compile an ONNX model to ONNX Runtime format.

        Keyword arguments override the defaults set in the constructor.
        If a keyword argument is not provided here, the constructor
        default is used.

        :param ctx: The execution context with device configuration.
        :param onnx_path: Path to the input ONNX model.
        :param device: Name of the target device (overrides the
            constructor default).
        :param input_shape: Input shape tuple, e.g.
            ``(1, 3, 224, 224)``. *Only used on* ``qai_hub``
            *devices.*
        :param calibration_data: Calibration data for post-training
            quantization — a path to a directory of ``.npy`` files
            or a ``dict`` mapping model input names to NumPy arrays.
        :param calibration_method: Calibration algorithm (``'minmax'``
            or ``'entropy'``). *Only used on*
            ``embedl_onnxruntime`` *devices.*
        :param per_channel: If ``True``, quantize weights per channel.
            *Only used on* ``embedl_onnxruntime`` *devices.*
        :param quantize_io: If ``True``, quantize model I/O tensors
            to INT8. *Only used on* ``qai_hub`` *devices.*
        :return: An :class:`ONNXRuntimeCompiledModel` with the path to
            the compiled model.
        """
        raise NoProviderError  # Replaced by the component system.
# --------------------------------------------------------------------- #
# QAI Hub provider
# --------------------------------------------------------------------- #
@ONNXRuntimeCompiler.provider(ProviderType.QAI_HUB)
def _compile_ort_qai_hub(
    ctx: HubContext,
    onnx_path: Path,
    *,
    device: str | None = None,
    input_shape: tuple[int, ...] | None = None,
    calibration_data: CalibrationData | None = None,
    calibration_method: CalibrationMethod | None = None,
    per_channel: bool = False,
    quantize_io: bool = False,
) -> ONNXRuntimeCompiledModel:
    """Compile an ONNX model to ONNX Runtime using Qualcomm AI Hub.

    The model is first quantized to INT8 via ``hub.submit_quantize_job()``,
    then compiled for ONNX Runtime via ``hub.submit_compile_job()``.
    All keyword arguments are pre-resolved by the component system.

    :param ctx: The execution context with device configuration.
    :param onnx_path: Path to the input ONNX model.
    :param device: Name of the target device in the context.
    :param input_shape: Input shape tuple.
    :param calibration_data: Calibration data for quantization.
    :param calibration_method: Calibration algorithm (ignored on qai_hub).
    :param per_channel: Per-channel quantization (ignored on qai_hub).
    :param quantize_io: If ``True``, quantize I/O tensors.
    :return: An :class:`ONNXRuntimeCompiledModel` with the path to the
        compiled model.
    """
    # Surface options that only have an effect on embedl_onnxruntime
    # devices so callers notice they are ignored here.
    for option_given, option_name in (
        (calibration_method is not None, "calibration_method"),
        (per_channel, "per_channel"),
    ):
        if option_given:
            warnings.warn(
                f"'{option_name}' has no effect on qai_hub devices "
                "and will be ignored.",
                stacklevel=2,
            )
    # Record the floating-point input model before doing any work.
    ctx.client.log_artifact(
        onnx_path, name="input", file_name=f"input_{onnx_path.name}"
    )
    hub_device, _ = resolve_qai_hub_device(ctx, device)
    if input_shape:
        ctx.client.log_param("input_shape", "x".join(map(str, input_shape)))
    # Step 1: INT8 quantization on the Hub.
    quantized_path = _quantize_via_qai_hub(ctx, onnx_path, calibration_data)
    # Step 2: ONNX Runtime compilation of the quantized model.
    compiled_path = _compile_via_qai_hub(
        ctx,
        quantized_path,
        hub_device,
        input_shape=input_shape,
        quantize_io=quantize_io,
        output_name=f"compiled_{onnx_path.name}",
    )
    ctx.client.log_artifact(compiled_path, name="path")
    # The compiler may rename tensors; recover the renames by comparing
    # the pre-compile (quantized) and post-compile models positionally.
    input_names, output_names = _build_name_mappings(
        quantized_path, compiled_path
    )
    for param_key, mapping in (
        ("input_name_mapping", input_names),
        ("output_name_mapping", output_names),
    ):
        if mapping:
            ctx.client.log_param(
                param_key,
                ", ".join(f"{old} -> {new}" for old, new in mapping.items()),
            )
    return ONNXRuntimeCompiledModel.from_current_run(
        ctx,
        input_name_mapping=input_names or None,
        output_name_mapping=output_names or None,
    )
# -- QAI Hub helpers -------------------------------------------------------
def _build_name_mappings(
    source_model_path: Path,
    compiled_model_path: Path,
) -> tuple[dict[str, str], dict[str, str]]:
    """Build input/output name mappings between two ONNX models.

    Compares the input and output tensor names of the source and compiled
    models **by position**. Only names that actually changed are included
    in the returned dictionaries.

    :param source_model_path: Path to the pre-compile ONNX model.
    :param compiled_model_path: Path to the post-compile ONNX model.
    :return: A ``(input_mapping, output_mapping)`` tuple where each dict
        maps original names to compiled names. Empty dicts if no names
        changed.
    :raises CompileError: If the number of inputs or outputs differs
        between the two models.
    """
    source_model = load_onnx_model(source_model_path)
    compiled_model = load_onnx_model(compiled_model_path)

    def _renamed(src_tensors, tgt_tensors, kind: str) -> dict[str, str]:
        # Pair tensors positionally and keep only those whose name changed.
        src = list(src_tensors)
        tgt = list(tgt_tensors)
        if len(src) != len(tgt):
            raise CompileError(
                f"{kind} count mismatch: source model has "
                f"{len(src)} {kind.lower()}(s) but compiled model has "
                f"{len(tgt)}. Cannot build name mapping."
            )
        return {
            s.name: t.name for s, t in zip(src, tgt) if s.name != t.name
        }

    return (
        _renamed(
            source_model.graph.input, compiled_model.graph.input, "Input"
        ),
        _renamed(
            source_model.graph.output, compiled_model.graph.output, "Output"
        ),
    )
def _quantize_via_qai_hub(
    ctx: HubContext,
    onnx_path: Path,
    calibration_data: CalibrationData | None,
) -> Path:
    """Submit a quantization job to Qualcomm AI Hub.

    :param ctx: The execution context.
    :param onnx_path: Path to the floating-point ONNX model.
    :param calibration_data: Optional calibration data.
    :return: Local path to the quantized ONNX model.
    :raises CompileError: If the quantization job fails.
    """
    # Build the calibration dataset in the dict-of-sample-lists layout
    # the Hub API consumes.
    if calibration_data is None:
        # No user-supplied data: fall back to generated random data.
        calib_dataset = generate_random_calibration_data(onnx_path)
        ctx.client.log_param("num_calibration_samples", "1")
    else:
        if isinstance(calibration_data, dict):
            # Split each array along the leading (sample) axis.
            calib_dataset = {
                name: list(samples)
                for name, samples in calibration_data.items()
            }
        else:
            calib_dataset = load_onnx_calibration_data(
                model_path=onnx_path,
                data_path=calibration_data,
            )
        num_samples = min(len(v) for v in calib_dataset.values())
        ctx.client.log_param("num_calibration_samples", str(num_samples))
    # Submit the quantization job.
    try:
        quantize_job: QuantizeJob = hub.submit_quantize_job(
            model=onnx_path.as_posix(),
            calibration_data=calib_dataset,
            weights_dtype=hub.QuantizeDtype.INT8,
            activations_dtype=hub.QuantizeDtype.INT8,
        )
    except Exception as error:
        raise CompileError("Failed to submit quantization job.") from error
    ctx.client.log_param("qai_hub_quantize_job_id", quantize_job.job_id)
    # Fetch the quantized model back from the Hub.
    try:
        quantized_model = quantize_job.get_target_model()
    except Exception as error:
        raise CompileError(
            "Failed to download quantized model from Qualcomm AI Hub."
        ) from error
    if quantized_model is None:
        raise CompileError("Quantization job did not produce a target model.")
    local_quantized_path = save_qai_hub_model(
        quantized_model,
        ctx.artifact_dir / f"quantized_{onnx_path.name}",
    )
    ctx.client.log_artifact(
        local_quantized_path,
        name="quantized",
        file_name=f"quantized_{onnx_path.name}",
    )
    return local_quantized_path
def _compile_via_qai_hub(
    ctx: HubContext,
    model_path: Path,
    hub_device: hub.Device,
    *,
    input_shape: tuple[int, ...] | None,
    quantize_io: bool,
    output_name: str = "compiled_model.onnx",
) -> Path:
    """Submit a compile job to Qualcomm AI Hub.

    :param ctx: The execution context.
    :param model_path: Path to the (quantized) ONNX model.
    :param hub_device: Target QAI Hub device.
    :param input_shape: Optional input shape override.
    :param quantize_io: Whether to quantize I/O tensors.
    :param output_name: Filename for the compiled model artifact.
    :return: Local path to the compiled ONNX model.
    :raises CompileError: If the compile job fails.
    """
    # Assemble the space-separated compile options string.
    option_parts = ["--target_runtime onnx"]
    if quantize_io:
        option_parts.append("--quantize_io")
    # NOTE(review): the input name "image" is hard-coded — this assumes a
    # single-input model whose input is named "image"; confirm behavior
    # for multi-input or differently named models.
    specs: dict[str, tuple[int, ...]] | None = None
    if input_shape:
        specs = {"image": input_shape}
    # Submit the compile job.
    try:
        compile_job: CompileJob = hub.submit_compile_job(
            model=model_path.as_posix(),
            device=hub_device,
            options=" ".join(option_parts),
            input_specs=specs,
        )
    except Exception as error:
        raise CompileError("Failed to submit compile job.") from error
    # Record the job id under both the generic and compile-specific keys.
    for param_key in ("$qai_hub_job_id", "qai_hub_compile_job_id"):
        ctx.client.log_param(param_key, compile_job.job_id)
    # Fetch the compiled model back from the Hub.
    try:
        target_model = compile_job.get_target_model()
    except Exception as error:
        raise CompileError(
            "Failed to download compiled model from Qualcomm AI Hub."
        ) from error
    if target_model is None:
        raise CompileError("Compile job did not produce a target model.")
    local_compiled_path = save_qai_hub_model(
        target_model, ctx.artifact_dir / output_name
    )
    # Parse and log runtime info extracted from the finished job.
    log_runtime_info(ctx.client, compile_job.job_id, error_class=CompileError)
    return local_compiled_path
# --------------------------------------------------------------------- #
# embedl_onnxruntime provider
# --------------------------------------------------------------------- #
@ONNXRuntimeCompiler.provider(ProviderType.EMBEDL_ONNXRUNTIME)
def _compile_ort_cli(
    ctx: HubContext,
    onnx_path: Path,
    *,
    device: str | None = None,
    input_shape: tuple[int, ...] | None = None,
    calibration_data: CalibrationData | None = None,
    calibration_method: CalibrationMethod | None = None,
    per_channel: bool = False,
    quantize_io: bool = False,
) -> ONNXRuntimeCompiledModel:
    """Compile an ONNX model via ``embedl-onnxruntime`` on a remote device.

    All keyword arguments are pre-resolved by the component system.

    :param ctx: The execution context with device configuration.
    :param onnx_path: Local path to the input ONNX model.
    :param device: Name of the target device in the context.
    :param input_shape: Input shape tuple.
    :param calibration_data: Calibration data for quantization.
    :param calibration_method: Calibration algorithm.
    :param per_channel: Per-channel quantization flag.
    :param quantize_io: I/O quantization flag (ignored on
        ``embedl-onnxruntime`` devices).
    :return: :class:`ONNXRuntimeCompiledModel` with the compiled model path.
    """
    # In-memory calibration dicts are persisted to disk first, so the
    # remainder of the flow only deals with a directory path.
    if isinstance(calibration_data, dict):
        calibration_data = _save_calibration_dict(
            calibration_data, ctx.artifact_dir / "calibration_data"
        )
    if calibration_data is None and calibration_method is not None:
        warnings.warn(
            "'calibration_method' has no effect without "
            "'calibration_data'. Random calibration will be used.",
            stacklevel=2,
        )
    if quantize_io:
        warnings.warn(
            "'quantize_io' has no effect on embedl-onnxruntime devices "
            "and will be ignored.",
            stacklevel=2,
        )
    assert device is not None
    target = ctx.devices[device]
    # Guard clauses: this provider requires an SSH command runner and a
    # remote artifact directory on the selected device.
    if target.runner is None:
        raise ValueError(
            "embedl_onnxruntime devices require a device with a "
            "command runner."
        )
    if target.artifact_dir is None:
        raise ValueError("No remote artifact directory available.")
    # Record the input model and the target runtime.
    ctx.client.log_artifact(
        onnx_path, name="input", file_name=f"input_{onnx_path.name}"
    )
    ctx.client.log_param("$runtime", "onnx")
    # Step 1: run static quantization on the remote device.
    quantize_result = _run_remote_quantization(
        ctx,
        runner=target.runner,
        onnx_path=onnx_path,
        remote_artifact_dir=target.artifact_dir,
        calibration_data=calibration_data,
        calibration_method=calibration_method,
        per_channel=per_channel,
    )
    # Step 2: pull the quantized model back to the local artifact dir.
    compiled_path = _retrieve_remote_model(
        ctx,
        runner=target.runner,
        result=quantize_result,
        remote_artifact_dir=target.artifact_dir,
        local_artifact_dir=ctx.artifact_dir,
        output_name=f"compiled_{onnx_path.name}",
    )
    ctx.client.log_artifact(compiled_path, name="path")
    return ONNXRuntimeCompiledModel.from_current_run(ctx)
# -- CLI helpers -----------------------------------------------------------
def _run_remote_quantization(
    ctx: HubContext,
    *,
    runner: CommandRunner,
    onnx_path: Path,
    remote_artifact_dir: RemotePath,
    calibration_data: Path | None,
    calibration_method: CalibrationMethod | None,
    per_channel: bool,
) -> CommandResult:
    """Upload the model and run ``embedl-onnxruntime`` on the device.

    Uploads the ONNX model (and calibration data, if any) to
    *remote_artifact_dir*, builds a ``quantize-static`` command line from
    the device's provider config, and executes it on the remote device.

    :param ctx: The execution context.
    :param runner: The device command runner.
    :param onnx_path: Local path to the input ONNX model.
    :param remote_artifact_dir: Remote directory for artifacts.
    :param calibration_data: Optional local path to calibration data.
    :param calibration_method: Calibration algorithm to use.
    :param per_channel: Whether to use per-channel quantization.
    :return: The process result from the remote command.
    :raises CompileError: If the remote command fails.
    """
    # Upload the ONNX model to the remote device
    remote_onnx = remote_artifact_dir / onnx_path.name
    runner.put(LocalPath(onnx_path), remote_onnx)
    # Ensure model path points to an .onnx file.
    # If the uploaded artifact has no .onnx suffix, it is presumably a
    # directory bundle containing "model.onnx" — TODO confirm this layout.
    model_arg = remote_onnx
    if model_arg.suffix != ".onnx":
        model_arg = model_arg / "model.onnx"
    # Resolve tool path and extra CLI args from device config.
    # NOTE(review): this grabs the *first* device in the context rather
    # than the device selected by the caller (cf. ``ctx.devices[device]``
    # in ``_compile_ort_cli``) — with multiple configured devices the
    # wrong provider config could be picked up; verify.
    dev = next(iter(ctx.devices.values()))
    cfg = dev.get_provider_config(EmbedlONNXRuntimeConfig, ONNXRuntimeCompiler)
    if cfg is None:
        # No per-device override: fall back to the config defaults.
        cfg = EmbedlONNXRuntimeConfig()
    ort_path = str(cfg.embedl_onnxruntime_path)
    cli_args = list(cfg.cli_args)
    # Determine calibration method: user-supplied data uses the requested
    # algorithm (defaulting to minmax); without data, "random" is passed.
    calib_method: str
    if calibration_data is not None:
        calib_method = calibration_method or "minmax"
    else:
        calib_method = "random"
    # Build the embedl-onnxruntime command
    run_args: list[str] = [
        ort_path,
        "quantize-static",
        "--model",
        str(model_arg),
        "--calib-method",
        calib_method,
    ]
    # Upload calibration data to the remote device if provided
    remote_calib_dir: RemotePath | None = None
    if calibration_data is not None:
        remote_calib_dir = remote_artifact_dir / "calibration_data"
        put_directory(runner, LocalPath(calibration_data), remote_calib_dir)
        run_args.extend(["--calib-dir", str(remote_calib_dir)])
    if per_channel:
        run_args.append("--per-channel")
    # Append any extra CLI arguments from device metadata
    run_args.extend(cli_args)
    # Log parameters for run tracking
    ctx.client.log_param("embedl_onnxruntime_path", ort_path)
    ctx.client.log_param("calib_method", calib_method)
    ctx.client.log_param("per_channel", str(per_channel))
    if cli_args:
        ctx.client.log_param("cli_args", " ".join(cli_args))
    if remote_calib_dir is not None:
        ctx.client.log_param("calib_dir", str(remote_calib_dir))
    # Execute compilation on the remote device
    local_artifact_dir = ctx.artifact_dir
    result = run_remote_command(
        runner,
        run_args,
        local_artifact_dir,
        error_cls=CompileError,
        error_message="Model compilation failed on the remote device.",
    )
    return result
def _retrieve_remote_model(
    ctx: HubContext,
    *,
    runner: CommandRunner,
    result: CommandResult,
    remote_artifact_dir: RemotePath,
    local_artifact_dir: Path,
    output_name: str = "compiled_model.onnx",
) -> Path:
    """Parse the quantized model path and download it locally.

    :param ctx: The execution context.
    :param runner: The device command runner.
    :param result: The process result containing stdout.
    :param remote_artifact_dir: Remote directory for artifacts.
    :param local_artifact_dir: Local directory for artifacts.
    :param output_name: Filename for the downloaded model.
    :return: Local path to the downloaded quantized model.
    :raises CompileError: If the model path cannot be parsed.
    """
    # The CLI reports where it wrote the quantized model; recover that
    # path from stdout.
    stdout_text = (result.stdout or "").rstrip()
    match = re.search(
        r"Quantized model was saved to (?P<path>.*)\.", stdout_text
    )
    if not match:
        raise CompileError(
            "Could not determine quantized model path from "
            "embedl-onnxruntime output."
        )
    reported_path = RemotePath(match["path"])
    # Stage the model inside the remote artifact directory unless the CLI
    # already wrote it there.
    staged_path = remote_artifact_dir / reported_path.name
    if str(reported_path) != str(staged_path):
        runner.run(
            ["cp", str(reported_path), str(staged_path)],
            hide=True,
        )
    # Download the staged model into the local artifact directory.
    destination = local_artifact_dir / output_name
    runner.get(staged_path, destination)
    return destination