# Source code for embedl_hub._internal.core.compile.onnxruntime

# Copyright (C) 2026 Embedl AB

"""ONNX Runtime compiler component.

Supports compilation via:
- **``qai_hub``** devices: Qualcomm AI Hub cloud compilation.
- **``embedl_onnxruntime``** devices: Remote compilation using
  ``embedl-onnxruntime`` over SSH.
"""

from __future__ import annotations

import re
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Literal

import numpy as np
import qai_hub as hub
from qai_hub.client import CompileJob, QuantizeJob

from embedl_hub._internal.core.compile.result import CompileError
from embedl_hub._internal.core.component.abc import Component, NoProviderError
from embedl_hub._internal.core.component.output import CompiledModel
from embedl_hub._internal.core.component.provider_type import ProviderType
from embedl_hub._internal.core.context import HubContext
from embedl_hub._internal.core.device.abc import CommandResult, CommandRunner
from embedl_hub._internal.core.device.commands import run_remote_command
from embedl_hub._internal.core.device.config.embedl_onnxruntime import (
    EmbedlONNXRuntimeConfig,
)
from embedl_hub._internal.core.device.transfer import put_directory
from embedl_hub._internal.core.types import LocalPath, RemotePath
from embedl_hub._internal.core.utils.calibration_data import (
    generate_random_calibration_data,
    load_onnx_calibration_data,
)
from embedl_hub._internal.core.utils.onnx_utils import load_onnx_model
from embedl_hub._internal.core.utils.qai_hub_utils import (
    log_runtime_info,
    resolve_qai_hub_device,
    save_qai_hub_model,
)
from embedl_hub._internal.tracking.rest_api import RunType

# Public type aliases for the quantization-related parameters accepted by
# the ONNX Runtime compiler component and its providers below.

CalibrationMethod = Literal["minmax", "entropy"]
"""Supported calibration algorithms for static quantization."""

CalibrationData = Path | dict[str, np.ndarray]
"""Calibration data for post-training quantization.

Can be supplied in two forms:

* **Path** – a directory of ``.npy`` files.  For single-input models
  the files sit directly in the directory.  For multi-input models
  the directory must contain one subdirectory per model input, named
  after that input, each holding ``.npy`` files.
* **dict** – a mapping whose keys are the model input names and whose
  values are NumPy arrays with shape
  ``(num_samples, *input_shape)``.  The leading axis indexes the
  calibration samples; each ``array[i]`` must already have the shape
  the model expects for that input (including any batch dimension).
"""


def _save_calibration_dict(
    data: dict[str, np.ndarray],
    target_dir: Path,
) -> Path:
    """Persist a calibration-data dictionary as ``.npy`` files on disk.

    Each value array has shape ``(num_samples, *input_shape)``.  The
    leading axis indexes the calibration samples; each ``array[i]``
    is saved as its own ``.npy`` file.

    The resulting directory layout is compatible with
    :func:`~embedl_hub._internal.core.utils.calibration_data.load_calibration_dataset`:

    * **Single input** (one key) – the individual sample files
      (``0.npy``, ``1.npy``, …) are saved directly inside
      *target_dir*.
    * **Multiple inputs** – a subdirectory is created for each key
      (i.e. each model input name), and the sample files are written
      inside it.

    The dictionary keys **must** match the input names of the ONNX
    model so that the calibration samples are mapped to the correct
    model inputs when loaded back.

    :param data: Mapping of model input names to calibration arrays.
        Each array must have shape ``(num_samples, *input_shape)``.
    :param target_dir: Directory in which to write the files.  Will be
        created if it does not exist.
    :return: *target_dir* (for convenient chaining).
    """
    target_dir.mkdir(parents=True, exist_ok=True)
    if len(data) == 1:
        array = next(iter(data.values()))
        for idx, sample in enumerate(array):
            np.save(target_dir / f"{idx}.npy", sample)
    else:
        for name, array in data.items():
            input_dir = target_dir / name
            input_dir.mkdir(parents=True, exist_ok=True)
            for idx, sample in enumerate(array):
                np.save(input_dir / f"{idx}.npy", sample)
    return target_dir


@dataclass(frozen=True)
class ONNXRuntimeCompiledModel(CompiledModel):
    """Output from the :class:`ONNXRuntimeCompiler` component.

    Extends
    :class:`~embedl_hub._internal.core.component.output.CompiledModel`
    with optional input/output tensor name mappings recording how the
    compiler renamed tensors.
    """

    # Maps original input tensor names to compiled names; None if no
    # input name changed (or no mapping was computed).
    input_name_mapping: dict[str, str] | None = None
    # Same, for output tensor names.
    output_name_mapping: dict[str, str] | None = None
class ONNXRuntimeCompiler(Component):
    """Component that compiles ONNX models to ONNX Runtime format.

    Two device types are supported:

    - **``qai_hub``** devices: compile via the Qualcomm AI Hub cloud
      service.
    - **``embedl_onnxruntime``** devices: compile via the
      ``embedl-onnxruntime`` CLI on a remote device over SSH.
    """

    run_type = RunType.COMPILE

    def __init__(
        self,
        *,
        name: str | None = None,
        device: str | None = None,
        input_shape: tuple[int, ...] | None = None,
        calibration_data: CalibrationData | None = None,
        calibration_method: CalibrationMethod | None = None,
        per_channel: bool = False,
        quantize_io: bool = False,
    ) -> None:
        """Create a new ONNX Runtime compiler.

        Values passed here become the defaults for every :meth:`run`
        call and may be overridden per call::

            compiler = ONNXRuntimeCompiler(quantize_io=True)
            compiler.run(ctx, model)                     # quantize_io=True
            compiler.run(ctx, model, quantize_io=False)  # overridden

        :param name: Display name for this compiler instance.  Defaults
            to the class name.
        :param device: Name of the device to use from the context.  When
            ``None``, auto-selects the first compatible device.
        :param input_shape: Input shape tuple, e.g. ``(1, 3, 224, 224)``.
            *Only used on* ``qai_hub`` *devices.*
        :param calibration_data: Calibration data for post-training
            quantization — a path to a directory of ``.npy`` files or a
            ``dict`` mapping model input names to NumPy arrays.
        :param calibration_method: Calibration algorithm (``'minmax'``
            or ``'entropy'``).  *Only used on* ``embedl_onnxruntime``
            *devices.*
        :param per_channel: If ``True``, quantize weights per channel.
            *Only used on* ``embedl_onnxruntime`` *devices.*
        :param quantize_io: If ``True``, quantize model I/O tensors to
            INT8.  *Only used on* ``qai_hub`` *devices.*
        """
        super().__init__(
            name=name,
            device=device,
            input_shape=input_shape,
            calibration_data=calibration_data,
            calibration_method=calibration_method,
            per_channel=per_channel,
            quantize_io=quantize_io,
        )

    def run(
        self,
        ctx: HubContext,
        onnx_path: Path,
        *,
        device: str | None = None,
        input_shape: tuple[int, ...] | None = None,
        calibration_data: CalibrationData | None = None,
        calibration_method: CalibrationMethod | None = None,
        per_channel: bool = False,
        quantize_io: bool = False,
    ) -> ONNXRuntimeCompiledModel:
        """Compile an ONNX model to ONNX Runtime format.

        Keyword arguments given here override the constructor defaults;
        omitted ones fall back to the constructor values.

        :param ctx: The execution context with device configuration.
        :param onnx_path: Path to the input ONNX model.
        :param device: Name of the target device (overrides the
            constructor default).
        :param input_shape: Input shape tuple, e.g. ``(1, 3, 224, 224)``.
            *Only used on* ``qai_hub`` *devices.*
        :param calibration_data: Calibration data for post-training
            quantization — a path to a directory of ``.npy`` files or a
            ``dict`` mapping model input names to NumPy arrays.
        :param calibration_method: Calibration algorithm (``'minmax'``
            or ``'entropy'``).  *Only used on* ``embedl_onnxruntime``
            *devices.*
        :param per_channel: If ``True``, quantize weights per channel.
            *Only used on* ``embedl_onnxruntime`` *devices.*
        :param quantize_io: If ``True``, quantize model I/O tensors to
            INT8.  *Only used on* ``qai_hub`` *devices.*
        :return: An :class:`ONNXRuntimeCompiledModel` with the path to
            the compiled model.
        """
        raise NoProviderError  # Replaced by the component system.
# --------------------------------------------------------------------- #
# QAI Hub provider
# --------------------------------------------------------------------- #


@ONNXRuntimeCompiler.provider(ProviderType.QAI_HUB)
def _compile_ort_qai_hub(
    ctx: HubContext,
    onnx_path: Path,
    *,
    device: str | None = None,
    input_shape: tuple[int, ...] | None = None,
    calibration_data: CalibrationData | None = None,
    calibration_method: CalibrationMethod | None = None,
    per_channel: bool = False,
    quantize_io: bool = False,
) -> ONNXRuntimeCompiledModel:
    """Compile an ONNX model to ONNX Runtime using Qualcomm AI Hub.

    The model is first quantized to INT8 via ``hub.submit_quantize_job()``,
    then compiled for ONNX Runtime via ``hub.submit_compile_job()``.

    All keyword arguments are pre-resolved by the component system.

    :param ctx: The execution context with device configuration.
    :param onnx_path: Path to the input ONNX model.
    :param device: Name of the target device.
    :param input_shape: Input shape tuple.
    :param calibration_data: Calibration data for quantization.
    :param calibration_method: Calibration algorithm (ignored on qai_hub).
    :param per_channel: Per-channel quantization (ignored on qai_hub).
    :param quantize_io: If ``True``, quantize I/O tensors.
    :return: An :class:`ONNXRuntimeCompiledModel` with the path to the
        compiled model.
    """
    # Surface parameters that only apply to embedl_onnxruntime devices.
    if calibration_method is not None:
        warnings.warn(
            "'calibration_method' has no effect on qai_hub devices "
            "and will be ignored.",
            stacklevel=2,
        )
    if per_channel:
        warnings.warn(
            "'per_channel' has no effect on qai_hub devices "
            "and will be ignored.",
            stacklevel=2,
        )

    # Log the input model artifact
    ctx.client.log_artifact(
        onnx_path, name="input", file_name=f"input_{onnx_path.name}"
    )

    # Only the resolved hub.Device is needed here; the display name is unused.
    hub_device, _ = resolve_qai_hub_device(ctx, device)

    if input_shape:
        ctx.client.log_param("input_shape", "x".join(map(str, input_shape)))

    # Step 1: Quantize
    quantized_model_path = _quantize_via_qai_hub(
        ctx, onnx_path, calibration_data
    )

    # Step 2: Compile
    compiled_model_path = _compile_via_qai_hub(
        ctx,
        quantized_model_path,
        hub_device,
        input_shape=input_shape,
        quantize_io=quantize_io,
        output_name=f"compiled_{onnx_path.name}",
    )

    # Log the compiled model artifact
    ctx.client.log_artifact(compiled_model_path, name="path")

    # Build input/output name mappings by comparing the
    # pre-compile (quantized) and post-compile models positionally.
    input_name_mapping, output_name_mapping = _build_name_mappings(
        quantized_model_path, compiled_model_path
    )
    if input_name_mapping:
        ctx.client.log_param(
            "input_name_mapping",
            ", ".join(f"{k} -> {v}" for k, v in input_name_mapping.items()),
        )
    if output_name_mapping:
        ctx.client.log_param(
            "output_name_mapping",
            ", ".join(f"{k} -> {v}" for k, v in output_name_mapping.items()),
        )

    return ONNXRuntimeCompiledModel.from_current_run(
        ctx,
        input_name_mapping=input_name_mapping or None,
        output_name_mapping=output_name_mapping or None,
    )


# -- QAI Hub helpers -------------------------------------------------------


def _build_name_mappings(
    source_model_path: Path,
    compiled_model_path: Path,
) -> tuple[dict[str, str], dict[str, str]]:
    """Build input/output name mappings between two ONNX models.

    Compares the input and output tensor names of the source and
    compiled models **by position**.  Only names that actually changed
    are included in the returned dictionaries.

    :param source_model_path: Path to the pre-compile ONNX model.
    :param compiled_model_path: Path to the post-compile ONNX model.
    :return: A ``(input_mapping, output_mapping)`` tuple where each dict
        maps original names to compiled names.  Empty dicts if no names
        changed.
    :raises CompileError: If the number of inputs or outputs differs
        between the two models.
    """
    source_model = load_onnx_model(source_model_path)
    compiled_model = load_onnx_model(compiled_model_path)

    src_inputs = list(source_model.graph.input)
    tgt_inputs = list(compiled_model.graph.input)
    if len(src_inputs) != len(tgt_inputs):
        raise CompileError(
            f"Input count mismatch: source model has "
            f"{len(src_inputs)} input(s) but compiled model has "
            f"{len(tgt_inputs)}. Cannot build name mapping."
        )

    src_outputs = list(source_model.graph.output)
    tgt_outputs = list(compiled_model.graph.output)
    if len(src_outputs) != len(tgt_outputs):
        raise CompileError(
            f"Output count mismatch: source model has "
            f"{len(src_outputs)} output(s) but compiled model has "
            f"{len(tgt_outputs)}. Cannot build name mapping."
        )

    # Positional comparison: record only the names that changed.
    input_mapping: dict[str, str] = {
        src.name: tgt.name
        for src, tgt in zip(src_inputs, tgt_inputs)
        if src.name != tgt.name
    }
    output_mapping: dict[str, str] = {
        src.name: tgt.name
        for src, tgt in zip(src_outputs, tgt_outputs)
        if src.name != tgt.name
    }
    return input_mapping, output_mapping


def _quantize_via_qai_hub(
    ctx: HubContext,
    onnx_path: Path,
    calibration_data: CalibrationData | None,
) -> Path:
    """Submit a quantization job to Qualcomm AI Hub.

    :param ctx: The execution context.
    :param onnx_path: Path to the floating-point ONNX model.
    :param calibration_data: Optional calibration data.
    :return: Local path to the quantized ONNX model.
    :raises CompileError: If the quantization job fails.
    """
    # Prepare calibration dataset: in-memory dict, on-disk directory, or
    # a random fallback sample when nothing was provided.
    if isinstance(calibration_data, dict):
        calib_dataset: dict[str, list[np.ndarray]] = {
            name: list(array) for name, array in calibration_data.items()
        }
        ctx.client.log_param(
            "num_calibration_samples",
            str(min(len(v) for v in calib_dataset.values())),
        )
    elif calibration_data is not None:
        calib_dataset = load_onnx_calibration_data(
            model_path=onnx_path,
            data_path=calibration_data,
        )
        ctx.client.log_param(
            "num_calibration_samples",
            str(min(len(v) for v in calib_dataset.values())),
        )
    else:
        calib_dataset = generate_random_calibration_data(onnx_path)
        ctx.client.log_param("num_calibration_samples", "1")

    # Submit quantization job
    try:
        quantize_job: QuantizeJob = hub.submit_quantize_job(
            model=onnx_path.as_posix(),
            calibration_data=calib_dataset,
            weights_dtype=hub.QuantizeDtype.INT8,
            activations_dtype=hub.QuantizeDtype.INT8,
        )
    except Exception as error:
        raise CompileError("Failed to submit quantization job.") from error

    ctx.client.log_param("qai_hub_quantize_job_id", quantize_job.job_id)

    # Download quantized model
    try:
        quantized_model = quantize_job.get_target_model()
    except Exception as error:
        raise CompileError(
            "Failed to download quantized model from Qualcomm AI Hub."
        ) from error
    if quantized_model is None:
        raise CompileError("Quantization job did not produce a target model.")

    quantized_path = save_qai_hub_model(
        quantized_model,
        ctx.artifact_dir / f"quantized_{onnx_path.name}",
    )
    ctx.client.log_artifact(
        quantized_path,
        name="quantized",
        file_name=f"quantized_{onnx_path.name}",
    )
    return quantized_path


def _compile_via_qai_hub(
    ctx: HubContext,
    model_path: Path,
    hub_device: hub.Device,
    *,
    input_shape: tuple[int, ...] | None,
    quantize_io: bool,
    output_name: str = "compiled_model.onnx",
) -> Path:
    """Submit a compile job to Qualcomm AI Hub.

    :param ctx: The execution context.
    :param model_path: Path to the (quantized) ONNX model.
    :param hub_device: Target QAI Hub device.
    :param input_shape: Optional input shape override.
    :param quantize_io: Whether to quantize I/O tensors.
    :param output_name: Filename for the compiled model artifact.
    :return: Local path to the compiled ONNX model.
    :raises CompileError: If the compile job fails.
    """
    # Build compile options
    opts = "--target_runtime onnx"
    if quantize_io:
        opts += " --quantize_io"

    # NOTE(review): the input spec is keyed on the literal name "image";
    # models whose first input is named differently will not pick up the
    # shape override — confirm against QAI Hub usage.
    input_specs = {"image": input_shape} if input_shape else None

    # Submit compile job
    try:
        compile_job: CompileJob = hub.submit_compile_job(
            model=model_path.as_posix(),
            device=hub_device,
            options=opts,
            input_specs=input_specs,
        )
    except Exception as error:
        raise CompileError("Failed to submit compile job.") from error

    # The "$"-prefixed key mirrors the tracking convention used elsewhere
    # (cf. "$runtime"); the plain key identifies this specific job type.
    ctx.client.log_param("$qai_hub_job_id", compile_job.job_id)
    ctx.client.log_param("qai_hub_compile_job_id", compile_job.job_id)

    # Download compiled model
    try:
        compiled_model = compile_job.get_target_model()
    except Exception as error:
        raise CompileError(
            "Failed to download compiled model from Qualcomm AI Hub."
        ) from error
    if compiled_model is None:
        raise CompileError("Compile job did not produce a target model.")

    compiled_model_path = save_qai_hub_model(
        compiled_model, ctx.artifact_dir / output_name
    )

    # Parse and log runtime info
    log_runtime_info(ctx.client, compile_job.job_id, error_class=CompileError)

    return compiled_model_path


# --------------------------------------------------------------------- #
# embedl_onnxruntime provider
# --------------------------------------------------------------------- #


@ONNXRuntimeCompiler.provider(ProviderType.EMBEDL_ONNXRUNTIME)
def _compile_ort_cli(
    ctx: HubContext,
    onnx_path: Path,
    *,
    device: str | None = None,
    input_shape: tuple[int, ...] | None = None,
    calibration_data: CalibrationData | None = None,
    calibration_method: CalibrationMethod | None = None,
    per_channel: bool = False,
    quantize_io: bool = False,
) -> ONNXRuntimeCompiledModel:
    """Compile an ONNX model via ``embedl-onnxruntime`` on a remote device.

    All keyword arguments are pre-resolved by the component system.

    :param ctx: The execution context with device configuration.
    :param onnx_path: Local path to the input ONNX model.
    :param device: Name of the target device.
    :param input_shape: Input shape tuple (currently unused on
        ``embedl-onnxruntime`` devices).
    :param calibration_data: Calibration data for quantization.
    :param calibration_method: Calibration algorithm.
    :param per_channel: Per-channel quantization flag.
    :param quantize_io: I/O quantization flag (ignored on
        ``embedl-onnxruntime`` devices).
    :return: :class:`ONNXRuntimeCompiledModel` with the compiled model
        path.
    """
    # Convert dict calibration data to disk
    if isinstance(calibration_data, dict):
        calibration_data = _save_calibration_dict(
            calibration_data, ctx.artifact_dir / "calibration_data"
        )

    if calibration_method is not None and calibration_data is None:
        warnings.warn(
            "'calibration_method' has no effect without "
            "'calibration_data'. Random calibration will be used.",
            stacklevel=2,
        )
    if quantize_io:
        warnings.warn(
            "'quantize_io' has no effect on embedl-onnxruntime devices "
            "and will be ignored.",
            stacklevel=2,
        )

    # The component system resolves the device name before dispatch.
    assert device is not None
    dev = ctx.devices[device]
    runner = dev.runner
    if runner is None:
        raise ValueError(
            "embedl_onnxruntime devices require a device with a "
            "command runner."
        )
    remote_artifact_dir = dev.artifact_dir
    if remote_artifact_dir is None:
        raise ValueError("No remote artifact directory available.")

    # Log the input model artifact
    ctx.client.log_artifact(
        onnx_path, name="input", file_name=f"input_{onnx_path.name}"
    )
    # Log runtime parameter
    ctx.client.log_param("$runtime", "onnx")

    # Resolve the provider config for the *selected* device so the tool
    # path and CLI args match the device we actually run on (previously
    # the helper picked the first device in the context).
    cfg = dev.get_provider_config(EmbedlONNXRuntimeConfig, ONNXRuntimeCompiler)

    # Step 1: Run quantization on the remote device
    result = _run_remote_quantization(
        ctx,
        runner=runner,
        onnx_path=onnx_path,
        remote_artifact_dir=remote_artifact_dir,
        calibration_data=calibration_data,
        calibration_method=calibration_method,
        per_channel=per_channel,
        config=cfg,
    )

    # Step 2: Retrieve the quantized model
    local_model_path = _retrieve_remote_model(
        ctx,
        runner=runner,
        result=result,
        remote_artifact_dir=remote_artifact_dir,
        local_artifact_dir=ctx.artifact_dir,
        output_name=f"compiled_{onnx_path.name}",
    )

    # Log the compiled model artifact
    ctx.client.log_artifact(local_model_path, name="path")

    return ONNXRuntimeCompiledModel.from_current_run(ctx)


# -- CLI helpers -----------------------------------------------------------


def _run_remote_quantization(
    ctx: HubContext,
    *,
    runner: CommandRunner,
    onnx_path: Path,
    remote_artifact_dir: RemotePath,
    calibration_data: Path | None,
    calibration_method: CalibrationMethod | None,
    per_channel: bool,
    config: EmbedlONNXRuntimeConfig | None = None,
) -> CommandResult:
    """Upload the model and run ``embedl-onnxruntime`` on the device.

    :param ctx: The execution context.
    :param runner: The device command runner.
    :param onnx_path: Local path to the input ONNX model.
    :param remote_artifact_dir: Remote directory for artifacts.
    :param calibration_data: Optional local path to calibration data.
    :param calibration_method: Calibration algorithm to use.
    :param per_channel: Whether to use per-channel quantization.
    :param config: Provider configuration for the target device.  When
        ``None`` (backward-compatible fallback), the configuration is
        resolved from the first device in the context — callers should
        pass the config of the device they actually target.
    :return: The process result from the remote command.
    :raises CompileError: If the remote command fails.
    """
    # Upload the ONNX model to the remote device
    remote_onnx = remote_artifact_dir / onnx_path.name
    runner.put(LocalPath(onnx_path), remote_onnx)

    # Ensure model path points to an .onnx file
    model_arg = remote_onnx
    if model_arg.suffix != ".onnx":
        model_arg = model_arg / "model.onnx"

    # Resolve tool path and extra CLI args from device config
    if config is None:
        # Legacy behavior: may pick the wrong device when several are
        # configured; kept only for callers that do not pass ``config``.
        dev = next(iter(ctx.devices.values()))
        config = dev.get_provider_config(
            EmbedlONNXRuntimeConfig, ONNXRuntimeCompiler
        )
    if config is None:
        config = EmbedlONNXRuntimeConfig()
    ort_path = str(config.embedl_onnxruntime_path)
    cli_args = list(config.cli_args)

    # Determine calibration method: "random" means the tool generates
    # its own calibration samples when no data was supplied.
    calib_method: str
    if calibration_data is not None:
        calib_method = calibration_method or "minmax"
    else:
        calib_method = "random"

    # Build the embedl-onnxruntime command
    run_args: list[str] = [
        ort_path,
        "quantize-static",
        "--model",
        str(model_arg),
        "--calib-method",
        calib_method,
    ]

    # Upload calibration data to the remote device if provided
    remote_calib_dir: RemotePath | None = None
    if calibration_data is not None:
        remote_calib_dir = remote_artifact_dir / "calibration_data"
        put_directory(runner, LocalPath(calibration_data), remote_calib_dir)
        run_args.extend(["--calib-dir", str(remote_calib_dir)])

    if per_channel:
        run_args.append("--per-channel")

    # Append any extra CLI arguments from device metadata
    run_args.extend(cli_args)

    # Log parameters
    ctx.client.log_param("embedl_onnxruntime_path", ort_path)
    ctx.client.log_param("calib_method", calib_method)
    ctx.client.log_param("per_channel", str(per_channel))
    if cli_args:
        ctx.client.log_param("cli_args", " ".join(cli_args))
    if remote_calib_dir is not None:
        ctx.client.log_param("calib_dir", str(remote_calib_dir))

    # Execute compilation on the remote device
    local_artifact_dir = ctx.artifact_dir
    result = run_remote_command(
        runner,
        run_args,
        local_artifact_dir,
        error_cls=CompileError,
        error_message="Model compilation failed on the remote device.",
    )
    return result


def _retrieve_remote_model(
    ctx: HubContext,
    *,
    runner: CommandRunner,
    result: CommandResult,
    remote_artifact_dir: RemotePath,
    local_artifact_dir: Path,
    output_name: str = "compiled_model.onnx",
) -> Path:
    """Parse the quantized model path and download it locally.

    :param ctx: The execution context.
    :param runner: The device command runner.
    :param result: The process result containing stdout.
    :param remote_artifact_dir: Remote directory for artifacts.
    :param local_artifact_dir: Local directory for artifacts.
    :param output_name: Filename for the downloaded model.
    :return: Local path to the downloaded quantized model.
    :raises CompileError: If the model path cannot be parsed.
    """
    # The tool prints "Quantized model was saved to <path>."; capture
    # <path> (the trailing period is matched outside the group).
    match = re.search(
        r"Quantized model was saved to (?P<path>.*)\.",
        (result.stdout or "").rstrip(),
    )
    if match is None:
        raise CompileError(
            "Could not determine quantized model path from "
            "embedl-onnxruntime output."
        )
    quantized_model_path = RemotePath(match["path"])

    # Copy quantized model into the remote artifact directory
    remote_model_path = remote_artifact_dir / quantized_model_path.name
    if str(quantized_model_path) != str(remote_model_path):
        runner.run(
            ["cp", str(quantized_model_path), str(remote_model_path)],
            hide=True,
        )

    # Download the compiled model to the local artifact directory
    local_model_path = local_artifact_dir / output_name
    runner.get(remote_model_path, local_model_path)
    return local_model_path