# Source code for embedl_hub._internal.core.profile.onnxruntime

# Copyright (C) 2026 Embedl AB

"""ONNX Runtime profiler component.

Supports profiling via:
- **``qai_hub``** devices: Qualcomm AI Hub cloud profiling.
- **``embedl_onnxruntime``** devices: Remote profiling using
  ``embedl-onnxruntime`` over SSH.
"""

from __future__ import annotations

import json
import logging
import re
import shlex
import warnings
from dataclasses import dataclass
from pathlib import Path
from tempfile import TemporaryDirectory

import qai_hub as hub
from packaging.version import InvalidVersion, Version

logger = logging.getLogger(__name__)

#: Highest onnxruntime version whose ``ONNXRUNTIME`` profiling output
#: is compatible with the profile-file parser.  Remote devices running a
#: newer onnxruntime fall back to the PYTHON profiling method (see
#: ``_resolve_profiling_method``).
_MAX_ORT_VERSION_FOR_NATIVE_PROFILING = Version("1.20.1")

from embedl_hub._internal.core.compile.onnxruntime import (
    ONNXRuntimeCompiledModel,
)
from embedl_hub._internal.core.component.abc import Component, NoProviderError
from embedl_hub._internal.core.component.output import ComponentOutput
from embedl_hub._internal.core.component.provider_type import ProviderType
from embedl_hub._internal.core.context import HubContext
from embedl_hub._internal.core.device.abc import CommandRunner
from embedl_hub._internal.core.device.commands import run_remote_command
from embedl_hub._internal.core.device.config.embedl_onnxruntime import (
    EmbedlONNXRuntimeConfig,
)
from embedl_hub._internal.core.profile.logging import (
    log_execution_detail,
    log_model_summary,
)
from embedl_hub._internal.core.profile.parsers import (
    parse_ort_profile,
    parse_qai_hub_profile,
)
from embedl_hub._internal.core.profile.result import (
    ProfileError,
    ProfilingMethod,
)
from embedl_hub._internal.core.types import RemotePath
from embedl_hub._internal.core.utils.onnx_utils import (
    maybe_package_onnx_folder_to_file,
)
from embedl_hub._internal.core.utils.qai_hub_utils import (
    log_runtime_info,
    resolve_qai_hub_device,
)
from embedl_hub._internal.tracking.rest_api import RunType
from embedl_hub._internal.tracking.run_log import LoggedArtifact, LoggedMetric

#: Backward-compatible alias for the old ``ONNXRuntimeProfilingMethod``.
ONNXRuntimeProfilingMethod = ProfilingMethod

#: Maps :class:`ProfilingMethod` values to the ``--profiling-method``
#: string expected by the ``embedl-onnxruntime`` CLI.
#: ``ProfilingMethod.MODEL`` is intentionally absent: it is rewritten to
#: ``PYTHON`` by ``_resolve_profiling_method`` before this mapping is used.
_ORT_CLI_METHOD: dict[ProfilingMethod, str] = {
    ProfilingMethod.PYTHON: "python",
    ProfilingMethod.LAYERWISE: "onnxruntime",
}


def _parse_latency_from_stdout(stdout: str) -> float:
    """Extract the average inference time from ``embedl-onnxruntime`` output.

    :param stdout: Standard output text from ``embedl-onnxruntime``.
    :return: Average inference latency in milliseconds.
    :raises ProfileError: If the latency cannot be parsed from output.
    """
    match = re.search(
        r"Average inference time:\s+([\d.]+)\s+ms",
        stdout,
    )
    if match is None:
        raise ProfileError(
            "Could not determine average inference time from "
            "embedl-onnxruntime output."
        )
    return float(match.group(1))


@dataclass(frozen=True)
class ONNXRuntimeProfilingResult(ComponentOutput):
    """Output from the ONNXRuntimeProfiler component.

    Extends :class:`ComponentOutput` with profiling-specific fields.

    .. attribute:: latency

        The average inference latency in milliseconds.

    .. attribute:: fps

        The inferred frames per second.

    .. attribute:: output_file

        The artifact containing the JSON profile, if available.
    """

    # Average inference latency in milliseconds.
    latency: LoggedMetric
    # Frames per second, derived as 1000 / latency_ms.
    fps: LoggedMetric
    # JSON profile artifact; None when no profile file was produced
    # (e.g. PYTHON profiling method).
    output_file: LoggedArtifact | None
class ONNXRuntimeProfiler(Component):
    """Component that profiles ONNX models.

    Supports two device types:

    - ``qai_hub`` devices: Profile via Qualcomm AI Hub cloud service.
    - ``embedl-onnxruntime`` devices: Profile via
      ``embedl-onnxruntime measure-latency`` on a remote device over SSH.

    Device-specific parameters (``embedl_onnxruntime_path``, ``cli_args``)
    are configured via
    :class:`~embedl_hub._internal.core.device.config.embedl_onnxruntime.EmbedlONNXRuntimeConfig`
    on the device.
    """

    run_type = RunType.PROFILE

    def __init__(
        self,
        *,
        name: str | None = None,
        device: str | None = None,
        runs: int = 100,
        burn_ins: int = 10,
        cold_starts: int = 1,
        profiling_method: ProfilingMethod = ProfilingMethod.LAYERWISE,
    ) -> None:
        """Create a new ONNX Runtime profiler.

        Parameters passed here become the defaults for every call.

        :param name: Display name for this component instance.
        :param device: Name of the target device.
        :param runs: Number of inference iterations.
            *Only used on* ``embedl_onnxruntime`` *devices.*
        :param burn_ins: Number of warm-up iterations before measurement.
            *Only used on* ``embedl_onnxruntime`` *devices.*
        :param cold_starts: Number of cold-start iterations.
            *Only used on* ``embedl_onnxruntime`` *devices.*
        :param profiling_method: Method specifying how to measure execution
            time. See :class:`ProfilingMethod`.
            *Only used on* ``embedl_onnxruntime`` *devices.*
        """
        # All defaults are forwarded to the Component base, which resolves
        # them against per-call overrides in run().
        super().__init__(
            name=name,
            device=device,
            runs=runs,
            burn_ins=burn_ins,
            cold_starts=cold_starts,
            profiling_method=profiling_method,
        )

    def run(
        self,
        ctx: HubContext,
        model: ONNXRuntimeCompiledModel,
        *,
        device: str | None = None,
        runs: int = 100,
        burn_ins: int = 10,
        cold_starts: int = 1,
        profiling_method: ProfilingMethod = ProfilingMethod.LAYERWISE,
    ) -> ONNXRuntimeProfilingResult:
        """Profile an ONNX model via ``embedl-onnxruntime measure-latency``.

        Keyword arguments override the defaults set in the constructor. If
        a keyword argument is not provided here, the value from the
        constructor is used.

        :param ctx: The execution context with device configuration.
        :param model: An :class:`ONNXRuntimeCompiledModel` whose ``path``
            artifact points to an ONNX model.
        :param device: Name of the target device.
        :param runs: Number of inference iterations.
            *Only used on* ``embedl_onnxruntime`` *devices.*
        :param burn_ins: Number of warm-up iterations before measurement.
            *Only used on* ``embedl_onnxruntime`` *devices.*
        :param cold_starts: Number of cold-start iterations.
            *Only used on* ``embedl_onnxruntime`` *devices.*
        :param profiling_method: Method specifying how to measure execution
            time. See :class:`ProfilingMethod`.
            *Only used on* ``embedl_onnxruntime`` *devices.*
        :return: An :class:`ONNXRuntimeProfilingResult` with latency and
            FPS metrics.
        """
        raise NoProviderError  # Replaced by the component system.
# --------------------------------------------------------------------- #
# QAI Hub provider
# --------------------------------------------------------------------- #


@ONNXRuntimeProfiler.provider(ProviderType.QAI_HUB)
def _profile_ort_qai_hub(
    ctx: HubContext,
    model: ONNXRuntimeCompiledModel,
    *,
    device: str | None = None,
    runs: int = 100,
    burn_ins: int = 10,
    cold_starts: int = 1,
    profiling_method: ProfilingMethod = ProfilingMethod.LAYERWISE,
) -> ONNXRuntimeProfilingResult:
    """Profile an ONNX model using Qualcomm AI Hub.

    Submits a profile job to QAI Hub and downloads the profiling results.
    The ``runs``, ``burn_ins``, ``cold_starts``, ``profiling_method``, and
    ``cli_args`` parameters are ignored — QAI Hub controls the benchmark
    configuration.

    All keyword arguments are pre-resolved by the component system.

    :param ctx: The execution context with device configuration.
    :param model: An :class:`ONNXRuntimeCompiledModel`.
    :param device: Ignored for QAI Hub.
    :param runs: Ignored for QAI Hub.
    :param burn_ins: Ignored for QAI Hub.
    :param cold_starts: Ignored for QAI Hub.
    :param profiling_method: Ignored for QAI Hub.
    :return: An :class:`ONNXRuntimeProfilingResult`.
    """
    # -- resolve model path (auto-transfer to local) -------------------
    model_local_path = model.path.to_local(ctx).local()

    # -- validate device ------------------------------------------------
    # Only the hub.Device handle is needed here; the resolved device name
    # is unused, hence the throwaway binding.
    hub_device, _ = resolve_qai_hub_device(ctx, device)

    # -- log input artifact ---------------------------------------------
    ctx.client.log_artifact(
        model.path,
        name="input",
        file_name=f"input_{model.path.file_name}",
        ctx=ctx,
    )

    # -- profile via QAI Hub -------------------------------------------
    latency_ms, profile_dict = _profile_via_qai_hub(
        ctx, model_local_path, hub_device
    )

    # -- compute FPS ----------------------------------------------------
    fps = 1000.0 / latency_ms

    # -- log output metrics and artifacts --------------------------------
    ctx.client.log_metric("$latency", latency_ms)
    ctx.client.log_metric("fps", fps)

    # -- log model summary and per-layer execution detail ----------------
    summary, execution_detail = parse_qai_hub_profile(profile_dict)
    log_model_summary(ctx.client, summary)
    log_execution_detail(ctx.client, execution_detail)

    # Save the full profile as an artifact
    profile_path = ctx.artifact_dir / "profile.json"
    with open(profile_path, "w", encoding="utf-8") as f:
        json.dump(profile_dict, f, indent=2)
    ctx.client.log_artifact(profile_path, name="output_file")

    return ONNXRuntimeProfilingResult.from_current_run(ctx)


# -- QAI Hub helpers -------------------------------------------------------


def _profile_via_qai_hub(
    ctx: HubContext,
    model_path: Path,
    hub_device: hub.Device,
) -> tuple[float, dict]:
    """Submit a profile job to Qualcomm AI Hub and parse results.

    :param ctx: The execution context.
    :param model_path: Local path to the ONNX model.
    :param hub_device: Target QAI Hub device.
    :return: A ``(latency_ms, full_profile_dict)`` tuple.
    :raises ProfileError: If the profile job fails.
    """
    # Package directory models into a single .onnx file if needed.  The
    # submit call stays inside the TemporaryDirectory context so the
    # packaged file still exists while it is uploaded.
    with TemporaryDirectory() as tmpdir:
        packaged_model = maybe_package_onnx_folder_to_file(model_path, tmpdir)

        # Submit profile job
        try:
            job: hub.ProfileJob = hub.submit_profile_job(
                model=packaged_model,
                device=hub_device,
            )
        except Exception as error:
            raise ProfileError(
                "Failed to submit profile job to Qualcomm AI Hub."
            ) from error

    ctx.client.log_param("$qai_hub_job_id", job.job_id)
    ctx.client.log_param("qai_hub_profile_job_id", job.job_id)

    # Download profile results
    try:
        profile = job.download_profile()
    except Exception as error:
        raise ProfileError(
            "Failed to download profile from Qualcomm AI Hub."
        ) from error

    # Parse runtime info
    log_runtime_info(ctx.client, job.job_id, error_class=ProfileError)

    # Parse latency from profile (QAI Hub reports microseconds).
    summary = profile.get("execution_summary", {})
    estimated_time_us = summary.get("estimated_inference_time")
    if estimated_time_us is None:
        raise ProfileError(
            "Profile result does not contain estimated_inference_time."
        )
    latency_ms = float(estimated_time_us) / 1000.0

    return latency_ms, profile


# --------------------------------------------------------------------- #
# embedl_onnxruntime provider
# --------------------------------------------------------------------- #


@ONNXRuntimeProfiler.provider(ProviderType.EMBEDL_ONNXRUNTIME)
def _profile_ort_cli(
    ctx: HubContext,
    model: ONNXRuntimeCompiledModel,
    *,
    device: str | None = None,
    runs: int = 100,
    burn_ins: int = 10,
    cold_starts: int = 1,
    profiling_method: ProfilingMethod = ProfilingMethod.LAYERWISE,
) -> ONNXRuntimeProfilingResult:
    """Profile an ONNX model via ``embedl-onnxruntime`` on a remote device.

    All keyword arguments are pre-resolved by the component system.

    :param ctx: The execution context with device configuration.
    :param model: An :class:`ONNXRuntimeCompiledModel`.
    :param device: Name of the target device (must be set by the
        component system).
    :param runs: Number of inference iterations.
    :param burn_ins: Number of warm-up iterations before measurement.
    :param cold_starts: Number of cold-start iterations.
    :param profiling_method: Method specifying how to measure execution
        time.
    :return: An :class:`ONNXRuntimeProfilingResult`.
    :raises ValueError: If the device has no command runner or no remote
        artifact directory.
    """
    # -- validate device ------------------------------------------------
    # The component system guarantees a resolved device name here.
    assert device is not None
    dev = ctx.devices[device]
    runner = dev.runner
    if runner is None:
        raise ValueError(
            "embedl_onnxruntime devices require a device with a "
            "command runner."
        )
    remote_artifact_dir = dev.artifact_dir
    if remote_artifact_dir is None:
        raise ValueError("No remote artifact directory available.")

    # -- resolve model path (auto-transfer) ----------------------------
    remote_model = model.path.to_remote(ctx, device_name=device).remote()

    # -- resolve provider config ----------------------------------------
    cfg = dev.get_provider_config(EmbedlONNXRuntimeConfig, ONNXRuntimeProfiler)
    if cfg is None:
        cfg = EmbedlONNXRuntimeConfig()
    ort_path = str(cfg.embedl_onnxruntime_path)
    effective_cli_args: list[str] = list(cfg.cli_args)

    # -- log input artifact ---------------------------------------------
    ctx.client.log_artifact(
        model.path,
        name="input",
        file_name=f"input_{model.path.file_name}",
        ctx=ctx,
    )

    # -- profile --------------------------------------------------------
    latency_ms, ort_profile = _run_remote_profiling(
        ctx,
        runner=runner,
        remote_model=remote_model,
        remote_artifact_dir=remote_artifact_dir,
        ort_path=ort_path,
        runs=runs,
        burn_ins=burn_ins,
        cold_starts=cold_starts,
        profiling_method=profiling_method,
        cli_args=effective_cli_args,
    )

    # -- compute FPS ----------------------------------------------------
    fps = 1000.0 / latency_ms

    # -- log output metrics and artifacts --------------------------------
    ctx.client.log_metric("$latency", latency_ms)
    ctx.client.log_metric("fps", fps)

    # -- log model summary and per-layer execution detail ----------------
    # ort_profile is None when the effective method is not LAYERWISE.
    if ort_profile is not None:
        summary, execution_detail = parse_ort_profile(ort_profile)
        log_model_summary(ctx.client, summary)
        log_execution_detail(ctx.client, execution_detail)

    profile_path = ctx.artifact_dir / "profile.json"
    if profile_path.exists():
        ctx.client.log_artifact(profile_path, name="output_file")

    return ONNXRuntimeProfilingResult.from_current_run(ctx)


# --------------------------------------------------------------------- #
# internal helpers
# --------------------------------------------------------------------- #


def _get_remote_ort_version(
    runner: CommandRunner,
    ort_path: str,
) -> Version | None:
    """Query the onnxruntime version installed on the remote device.

    Derives the Python interpreter path from `ort_path` so that the
    correct virtual-environment Python is used (e.g.
    ``/root/.venv/bin/embedl-onnxruntime`` → ``/root/.venv/bin/python3``).

    :param runner: The device command runner.
    :param ort_path: Path to ``embedl-onnxruntime`` on the remote device.
    :return: The remote onnxruntime version, or ``None`` if unavailable.
    """
    try:
        ort_bin_dir = str(Path(ort_path).parent)
        python_path = (
            f"{ort_bin_dir}/python3" if ort_bin_dir != "." else "python3"
        )
        snippet = "import onnxruntime; print(onnxruntime.__version__)"
        # NOTE(review): quoting a single argv element assumes the runner
        # joins the list into a shell command line — confirm against
        # CommandRunner.run semantics.
        result = runner.run(
            [python_path, "-c", shlex.quote(snippet)],
            hide=True,
        )
        raw = (result.stdout or "").strip()
        return Version(raw)
    except Exception:
        # Best-effort probe: any failure (SSH error, missing interpreter,
        # unparsable version — InvalidVersion is an Exception subclass, so
        # listing it separately was redundant) means "unknown".
        return None


def _resolve_profiling_method(
    requested: ProfilingMethod,
    runner: CommandRunner,
    ort_path: str,
) -> ProfilingMethod:
    """Return the effective profiling method, falling back to
    :attr:`~ProfilingMethod.PYTHON` when the remote onnxruntime version is
    too new for the native profiling parser or when
    :attr:`~ProfilingMethod.MODEL` is requested (ONNX Runtime has no
    native model-level timing mode).

    :param requested: The profiling method requested by the caller.
    :param runner: The device command runner.
    :param ort_path: Path to ``embedl-onnxruntime`` on the remote device.
    :return: The effective :class:`ProfilingMethod` to use.
    """
    if requested == ProfilingMethod.MODEL:
        warnings.warn(
            "ONNX Runtime does not provide a native model-level timing "
            "mode distinct from wall-clock measurement. "
            "Falling back to PYTHON profiling method.",
            stacklevel=3,
        )
        return ProfilingMethod.PYTHON

    if requested != ProfilingMethod.LAYERWISE:
        return requested

    remote_version = _get_remote_ort_version(runner, ort_path)
    if remote_version is None:
        # Cannot determine — keep the user's choice.
        return requested

    if remote_version > _MAX_ORT_VERSION_FOR_NATIVE_PROFILING:
        warnings.warn(
            f"onnxruntime {remote_version} on the remote device is newer "
            f"than {_MAX_ORT_VERSION_FOR_NATIVE_PROFILING}; the LAYERWISE "
            "profiling method is not supported for this version. "
            "Falling back to PYTHON profiling method.",
            stacklevel=2,
        )
        return ProfilingMethod.PYTHON

    return requested


def _run_remote_profiling(
    ctx: HubContext,
    *,
    runner: CommandRunner,
    remote_model: RemotePath,
    remote_artifact_dir: RemotePath,
    ort_path: str,
    runs: int,
    burn_ins: int,
    cold_starts: int,
    profiling_method: ProfilingMethod,
    cli_args: list[str],
) -> tuple[float, dict | None]:
    """Run ``embedl-onnxruntime measure-latency`` on a remote device and
    parse results.

    :param ctx: The execution context.
    :param runner: The device command runner.
    :param remote_model: Remote path to the ONNX model on the device.
    :param remote_artifact_dir: Remote directory for artifacts.
    :param ort_path: Path to ``embedl-onnxruntime`` on the remote device.
    :param runs: Number of inference iterations.
    :param burn_ins: Number of warm-up iterations.
    :param cold_starts: Number of cold-start iterations.
    :param profiling_method: Profiling method to use.
    :param cli_args: Extra CLI arguments.
    :return: Tuple of ``(latency_ms, ort_profile_dict_or_None)``.
    :raises ProfileError: If the remote command fails.
    """
    profiling_method = _resolve_profiling_method(
        profiling_method, runner, ort_path
    )

    # Map the unified ProfilingMethod to the CLI string expected by
    # ``embedl-onnxruntime``.
    cli_method = _ORT_CLI_METHOD[profiling_method]

    # Build the embedl-onnxruntime command
    run_args: list[str] = [
        ort_path,
        "measure-latency",
        "--model",
        str(remote_model),
        "--runs",
        str(runs),
        "--burn-ins",
        str(burn_ins),
        "--cold-starts",
        str(cold_starts),
        "--profiling-method",
        cli_method,
    ]

    # If using LAYERWISE method, request a JSON profile output
    if profiling_method == ProfilingMethod.LAYERWISE:
        remote_profile = remote_artifact_dir / "profile.json"
        run_args.extend(["--output-json-file", str(remote_profile)])

    # Append extra CLI arguments
    run_args.extend(cli_args)

    # Log parameters
    ctx.client.log_param("$runtime", "onnx")
    ctx.client.log_param("embedl_onnxruntime_path", ort_path)
    ctx.client.log_param("runs", str(runs))
    ctx.client.log_param("burn_ins", str(burn_ins))
    ctx.client.log_param("cold_starts", str(cold_starts))
    ctx.client.log_param("profiling_method", profiling_method.value)
    if cli_args:
        ctx.client.log_param("cli_args", " ".join(cli_args))

    # Execute on the remote device
    local_artifact_dir = ctx.artifact_dir
    result = run_remote_command(
        runner,
        run_args,
        local_artifact_dir,
        error_cls=ProfileError,
        error_message="ONNX Runtime profiling failed on the remote device.",
    )
    stdout = result.stdout or ""

    # Download the JSON profile if using LAYERWISE method
    ort_profile: dict | None = None
    if profiling_method == ProfilingMethod.LAYERWISE:
        local_profile = local_artifact_dir / "profile.json"
        runner.get(
            remote_artifact_dir / "profile.json",
            local_profile,
        )
        with open(local_profile, encoding="utf-8") as f:
            ort_profile = json.load(f)

    # Parse the latency from stdout
    return _parse_latency_from_stdout(stdout), ort_profile