# Copyright (C) 2026 Embedl AB
"""ONNX Runtime profiler component.
Supports profiling via:
- **``qai_hub``** devices: Qualcomm AI Hub cloud profiling.
- **``embedl_onnxruntime``** devices: Remote profiling using
``embedl-onnxruntime`` over SSH.
"""
from __future__ import annotations
import json
import logging
import re
import shlex
import warnings
from dataclasses import dataclass
from pathlib import Path
from tempfile import TemporaryDirectory
import qai_hub as hub
from packaging.version import InvalidVersion, Version
logger = logging.getLogger(__name__)  # module-level logger, PEP 282 style
#: Highest onnxruntime version whose native (``ONNXRUNTIME``/layerwise)
#: profiling output is still compatible with the profile-file parser.
#: Newer remote versions trigger a fallback to PYTHON wall-clock timing
#: (see ``_resolve_profiling_method``).
_MAX_ORT_VERSION_FOR_NATIVE_PROFILING = Version("1.20.1")
from embedl_hub._internal.core.compile.onnxruntime import (
ONNXRuntimeCompiledModel,
)
from embedl_hub._internal.core.component.abc import Component, NoProviderError
from embedl_hub._internal.core.component.output import ComponentOutput
from embedl_hub._internal.core.component.provider_type import ProviderType
from embedl_hub._internal.core.context import HubContext
from embedl_hub._internal.core.device.abc import CommandRunner
from embedl_hub._internal.core.device.commands import run_remote_command
from embedl_hub._internal.core.device.config.embedl_onnxruntime import (
EmbedlONNXRuntimeConfig,
)
from embedl_hub._internal.core.profile.logging import (
log_execution_detail,
log_model_summary,
)
from embedl_hub._internal.core.profile.parsers import (
parse_ort_profile,
parse_qai_hub_profile,
)
from embedl_hub._internal.core.profile.result import (
ProfileError,
ProfilingMethod,
)
from embedl_hub._internal.core.types import RemotePath
from embedl_hub._internal.core.utils.onnx_utils import (
maybe_package_onnx_folder_to_file,
)
from embedl_hub._internal.core.utils.qai_hub_utils import (
log_runtime_info,
resolve_qai_hub_device,
)
from embedl_hub._internal.tracking.rest_api import RunType
from embedl_hub._internal.tracking.run_log import LoggedArtifact, LoggedMetric
#: Backward-compatible alias for the old ``ONNXRuntimeProfilingMethod``.
ONNXRuntimeProfilingMethod = ProfilingMethod
#: Maps :class:`ProfilingMethod` values to the ``--profiling-method``
#: string expected by the ``embedl-onnxruntime`` CLI.
#: ``ProfilingMethod.MODEL`` is intentionally absent: it is resolved to
#: ``PYTHON`` by ``_resolve_profiling_method`` before this mapping is used.
_ORT_CLI_METHOD: dict[ProfilingMethod, str] = {
    ProfilingMethod.PYTHON: "python",
    ProfilingMethod.LAYERWISE: "onnxruntime",
}
def _parse_latency_from_stdout(stdout: str) -> float:
"""Extract the average inference time from ``embedl-onnxruntime`` output.
:param stdout: Standard output text from ``embedl-onnxruntime``.
:return: Average inference latency in milliseconds.
:raises ProfileError: If the latency cannot be parsed from output.
"""
match = re.search(
r"Average inference time:\s+([\d.]+)\s+ms",
stdout,
)
if match is None:
raise ProfileError(
"Could not determine average inference time from "
"embedl-onnxruntime output."
)
return float(match.group(1))
# NOTE: a stray "[docs]" Sphinx-viewcode artifact preceded this class; it
# was a bare expression that raised NameError at import time and has been
# removed.
@dataclass(frozen=True)
class ONNXRuntimeProfilingResult(ComponentOutput):
    """Output from the ONNXRuntimeProfiler component.

    Extends :class:`ComponentOutput` with profiling-specific fields.

    .. attribute:: latency

        The average inference latency in milliseconds.

    .. attribute:: fps

        The inferred frames per second.

    .. attribute:: output_file

        The artifact containing the JSON profile, if available.
    """

    # Average inference latency in milliseconds.
    latency: LoggedMetric
    # Frames per second, derived as 1000 / latency by the providers.
    fps: LoggedMetric
    # JSON profile artifact; None when no layerwise profile was produced.
    output_file: LoggedArtifact | None
# NOTE: stray "[docs]" Sphinx-viewcode artifacts preceded this class and its
# ``run`` method; they were bare expressions that raised NameError at import
# time and have been removed.
class ONNXRuntimeProfiler(Component):
    """Component that profiles ONNX models.

    Supports two device types:

    - ``qai_hub`` devices: Profile via Qualcomm AI Hub cloud service.
    - ``embedl-onnxruntime`` devices: Profile via
      ``embedl-onnxruntime measure-latency`` on a remote device
      over SSH.

    Device-specific parameters (``embedl_onnxruntime_path``,
    ``cli_args``) are configured via
    :class:`~embedl_hub._internal.core.device.config.embedl_onnxruntime.EmbedlONNXRuntimeConfig`
    on the device.
    """

    # Tags runs produced by this component as profiling runs.
    run_type = RunType.PROFILE

    def __init__(
        self,
        *,
        name: str | None = None,
        device: str | None = None,
        runs: int = 100,
        burn_ins: int = 10,
        cold_starts: int = 1,
        profiling_method: ProfilingMethod = ProfilingMethod.LAYERWISE,
    ) -> None:
        """Create a new ONNX Runtime profiler.

        Parameters passed here become the defaults for every call.

        :param name: Display name for this component instance.
        :param device: Name of the target device.
        :param runs: Number of inference iterations.
            *Only used on* ``embedl_onnxruntime`` *devices.*
        :param burn_ins: Number of warm-up iterations before
            measurement. *Only used on* ``embedl_onnxruntime``
            *devices.*
        :param cold_starts: Number of cold-start iterations.
            *Only used on* ``embedl_onnxruntime`` *devices.*
        :param profiling_method: Method specifying how to measure
            execution time. See :class:`ProfilingMethod`.
            *Only used on* ``embedl_onnxruntime`` *devices.*
        """
        super().__init__(
            name=name,
            device=device,
            runs=runs,
            burn_ins=burn_ins,
            cold_starts=cold_starts,
            profiling_method=profiling_method,
        )

    def run(
        self,
        ctx: HubContext,
        model: ONNXRuntimeCompiledModel,
        *,
        device: str | None = None,
        runs: int = 100,
        burn_ins: int = 10,
        cold_starts: int = 1,
        profiling_method: ProfilingMethod = ProfilingMethod.LAYERWISE,
    ) -> ONNXRuntimeProfilingResult:
        """Profile an ONNX model via ``embedl-onnxruntime measure-latency``.

        Keyword arguments override the defaults set in the constructor.
        If a keyword argument is not provided here, the value from the
        constructor is used.

        :param ctx: The execution context with device configuration.
        :param model: An :class:`ONNXRuntimeCompiledModel` whose ``path``
            artifact points to an ONNX model.
        :param device: Name of the target device.
        :param runs: Number of inference iterations.
            *Only used on* ``embedl_onnxruntime`` *devices.*
        :param burn_ins: Number of warm-up iterations before
            measurement. *Only used on* ``embedl_onnxruntime``
            *devices.*
        :param cold_starts: Number of cold-start iterations.
            *Only used on* ``embedl_onnxruntime`` *devices.*
        :param profiling_method: Method specifying how to measure
            execution time. See :class:`ProfilingMethod`.
            *Only used on* ``embedl_onnxruntime`` *devices.*
        :return: An :class:`ONNXRuntimeProfilingResult` with latency
            and FPS metrics.
        """
        raise NoProviderError  # Replaced by the component system.
# --------------------------------------------------------------------- #
# QAI Hub provider
# --------------------------------------------------------------------- #
@ONNXRuntimeProfiler.provider(ProviderType.QAI_HUB)
def _profile_ort_qai_hub(
    ctx: HubContext,
    model: ONNXRuntimeCompiledModel,
    *,
    device: str | None = None,
    runs: int = 100,
    burn_ins: int = 10,
    cold_starts: int = 1,
    profiling_method: ProfilingMethod = ProfilingMethod.LAYERWISE,
) -> ONNXRuntimeProfilingResult:
    """Profile an ONNX model using Qualcomm AI Hub.

    Submits a profile job to QAI Hub and downloads the profiling results.
    The ``runs``, ``burn_ins``, ``cold_starts``, ``profiling_method``,
    and ``cli_args`` parameters are ignored — QAI Hub controls the
    benchmark configuration.

    All keyword arguments are pre-resolved by the component system.

    :param ctx: The execution context with device configuration.
    :param model: An :class:`ONNXRuntimeCompiledModel`.
    :param device: Ignored for QAI Hub.
    :param runs: Ignored for QAI Hub.
    :param burn_ins: Ignored for QAI Hub.
    :param cold_starts: Ignored for QAI Hub.
    :param profiling_method: Ignored for QAI Hub.
    :return: An :class:`ONNXRuntimeProfilingResult`.
    """
    # -- resolve model path (auto-transfer to local) -------------------
    model_local_path = model.path.to_local(ctx).local()
    # -- validate device ------------------------------------------------
    # Only the resolved hub.Device is needed here; the resolver's display
    # name (previously bound to an unused local) is discarded.
    hub_device, _ = resolve_qai_hub_device(ctx, device)
    # -- log input artifact ---------------------------------------------
    ctx.client.log_artifact(
        model.path,
        name="input",
        file_name=f"input_{model.path.file_name}",
        ctx=ctx,
    )
    # -- profile via QAI Hub -------------------------------------------
    latency_ms, profile_dict = _profile_via_qai_hub(
        ctx, model_local_path, hub_device
    )
    # -- compute FPS ----------------------------------------------------
    fps = 1000.0 / latency_ms
    # -- log output metrics and artifacts --------------------------------
    ctx.client.log_metric("$latency", latency_ms)
    ctx.client.log_metric("fps", fps)
    # -- log model summary and per-layer execution detail ----------------
    summary, execution_detail = parse_qai_hub_profile(profile_dict)
    log_model_summary(ctx.client, summary)
    log_execution_detail(ctx.client, execution_detail)
    # Save the full profile as an artifact
    profile_path = ctx.artifact_dir / "profile.json"
    with open(profile_path, "w", encoding="utf-8") as f:
        json.dump(profile_dict, f, indent=2)
    ctx.client.log_artifact(profile_path, name="output_file")
    return ONNXRuntimeProfilingResult.from_current_run(ctx)
# -- QAI Hub helpers -------------------------------------------------------
def _profile_via_qai_hub(
    ctx: HubContext,
    model_path: Path,
    hub_device: hub.Device,
) -> tuple[float, dict]:
    """Submit a profile job to Qualcomm AI Hub and parse results.

    :param ctx: The execution context.
    :param model_path: Local path to the ONNX model.
    :param hub_device: Target QAI Hub device.
    :return: A ``(latency_ms, full_profile_dict)`` tuple.
    :raises ProfileError: If the profile job fails.
    """
    with TemporaryDirectory() as staging_dir:
        # Directory-based ONNX models must be packaged into a single
        # .onnx file before submission.
        model_file = maybe_package_onnx_folder_to_file(model_path, staging_dir)
        # Submit the profile job to the cloud service.
        try:
            profile_job: hub.ProfileJob = hub.submit_profile_job(
                model=model_file,
                device=hub_device,
            )
        except Exception as error:
            raise ProfileError(
                "Failed to submit profile job to Qualcomm AI Hub."
            ) from error
        ctx.client.log_param("$qai_hub_job_id", profile_job.job_id)
        ctx.client.log_param("qai_hub_profile_job_id", profile_job.job_id)
        # Fetch the profiling results once the job completes.
        try:
            profile_data = profile_job.download_profile()
        except Exception as error:
            raise ProfileError(
                "Failed to download profile from Qualcomm AI Hub."
            ) from error
        # Record runtime info reported by the job.
        log_runtime_info(ctx.client, profile_job.job_id, error_class=ProfileError)
        # Pull the estimated latency (microseconds) out of the profile.
        exec_summary = profile_data.get("execution_summary", {})
        inference_time_us = exec_summary.get("estimated_inference_time")
        if inference_time_us is None:
            raise ProfileError(
                "Profile result does not contain estimated_inference_time."
            )
        return float(inference_time_us) / 1000.0, profile_data
# --------------------------------------------------------------------- #
# embedl_onnxruntime provider
# --------------------------------------------------------------------- #
@ONNXRuntimeProfiler.provider(ProviderType.EMBEDL_ONNXRUNTIME)
def _profile_ort_cli(
    ctx: HubContext,
    model: ONNXRuntimeCompiledModel,
    *,
    device: str | None = None,
    runs: int = 100,
    burn_ins: int = 10,
    cold_starts: int = 1,
    profiling_method: ProfilingMethod = ProfilingMethod.LAYERWISE,
) -> ONNXRuntimeProfilingResult:
    """Profile an ONNX model via ``embedl-onnxruntime`` on a remote device.

    All keyword arguments are pre-resolved by the component system.

    :param ctx: The execution context with device configuration.
    :param model: An :class:`ONNXRuntimeCompiledModel` whose ``path``
        artifact points to an ONNX model.
    :param device: Name of the target device.
    :param runs: Number of inference iterations.
    :param burn_ins: Number of warm-up iterations before measurement.
    :param cold_starts: Number of cold-start iterations.
    :param profiling_method: Method specifying how to measure execution
        time. See :class:`ProfilingMethod`.
    :return: An :class:`ONNXRuntimeProfilingResult`.
    :raises ValueError: If the device has no command runner or no remote
        artifact directory.
    """
    # -- validate device ------------------------------------------------
    # The component system resolves the device name before dispatch, so a
    # None here is a programming error, not a user error. (A second,
    # duplicate assert further down has been removed.)
    assert device is not None
    dev = ctx.devices[device]
    runner = dev.runner
    if runner is None:
        raise ValueError(
            "embedl_onnxruntime devices require a device with a "
            "command runner."
        )
    remote_artifact_dir = dev.artifact_dir
    if remote_artifact_dir is None:
        raise ValueError("No remote artifact directory available.")
    # -- resolve model path (auto-transfer) ----------------------------
    remote_model = model.path.to_remote(ctx, device_name=device).remote()
    # -- resolve provider config ----------------------------------------
    cfg = dev.get_provider_config(EmbedlONNXRuntimeConfig, ONNXRuntimeProfiler)
    if cfg is None:
        cfg = EmbedlONNXRuntimeConfig()
    ort_path = str(cfg.embedl_onnxruntime_path)
    # Copy so later extensions never mutate the device config's list.
    effective_cli_args: list[str] = list(cfg.cli_args)
    # -- log input artifact ---------------------------------------------
    ctx.client.log_artifact(
        model.path,
        name="input",
        file_name=f"input_{model.path.file_name}",
        ctx=ctx,
    )
    # -- profile --------------------------------------------------------
    latency_ms, ort_profile = _run_remote_profiling(
        ctx,
        runner=runner,
        remote_model=remote_model,
        remote_artifact_dir=remote_artifact_dir,
        ort_path=ort_path,
        runs=runs,
        burn_ins=burn_ins,
        cold_starts=cold_starts,
        profiling_method=profiling_method,
        cli_args=effective_cli_args,
    )
    # -- compute FPS ----------------------------------------------------
    fps = 1000.0 / latency_ms
    # -- log output metrics and artifacts --------------------------------
    ctx.client.log_metric("$latency", latency_ms)
    ctx.client.log_metric("fps", fps)
    # -- log model summary and per-layer execution detail ----------------
    # A layerwise JSON profile is only available for the LAYERWISE method.
    if ort_profile is not None:
        summary, execution_detail = parse_ort_profile(ort_profile)
        log_model_summary(ctx.client, summary)
        log_execution_detail(ctx.client, execution_detail)
    profile_path = ctx.artifact_dir / "profile.json"
    if profile_path.exists():
        ctx.client.log_artifact(profile_path, name="output_file")
    return ONNXRuntimeProfilingResult.from_current_run(ctx)
# --------------------------------------------------------------------- #
# internal helpers
# --------------------------------------------------------------------- #
def _get_remote_ort_version(
    runner: CommandRunner,
    ort_path: str,
) -> Version | None:
    """Query the onnxruntime version installed on the remote device.

    Derives the Python interpreter path from `ort_path` so that the
    correct virtual-environment Python is used (e.g.
    ``/root/.venv/bin/embedl-onnxruntime`` → ``/root/.venv/bin/python3``).

    :param runner: The device command runner.
    :param ort_path: Path to ``embedl-onnxruntime`` on the remote device.
    :return: The remote onnxruntime version, or ``None`` if unavailable.
    """
    try:
        ort_bin_dir = str(Path(ort_path).parent)
        # A bare command name (no directory) has parent "." — fall back
        # to whatever "python3" resolves to on the remote PATH.
        python_path = (
            f"{ort_bin_dir}/python3" if ort_bin_dir != "." else "python3"
        )
        snippet = "import onnxruntime; print(onnxruntime.__version__)"
        # NOTE(review): the snippet is shell-quoted even though it is a
        # single argv element — presumably the runner joins argv into a
        # shell command line without quoting; confirm it does not quote
        # again, or the -c payload would be double-quoted.
        result = runner.run(
            [python_path, "-c", shlex.quote(snippet)],
            hide=True,
        )
        raw = (result.stdout or "").strip()
        return Version(raw)
    except Exception:
        # Best-effort probe: any failure (transport error, missing
        # interpreter, unparsable output) means "version unknown".
        # InvalidVersion subclasses ValueError, so listing it alongside
        # Exception (as before) was redundant.
        return None
def _resolve_profiling_method(
    requested: ProfilingMethod,
    runner: CommandRunner,
    ort_path: str,
) -> ProfilingMethod:
    """Return the effective profiling method, falling back to
    :attr:`~ProfilingMethod.PYTHON` when the remote onnxruntime version
    is too new for the native profiling parser or when
    :attr:`~ProfilingMethod.MODEL` is requested (ONNX Runtime has no
    native model-level timing mode).

    :param requested: The profiling method requested by the caller.
    :param runner: The device command runner.
    :param ort_path: Path to ``embedl-onnxruntime`` on the remote device.
    :return: The effective :class:`ProfilingMethod` to use.
    """
    # MODEL has no ONNX Runtime equivalent — always downgrade to PYTHON.
    if requested == ProfilingMethod.MODEL:
        warnings.warn(
            "ONNX Runtime does not provide a native model-level timing "
            "mode distinct from wall-clock measurement. "
            "Falling back to PYTHON profiling method.",
            stacklevel=3,
        )
        return ProfilingMethod.PYTHON
    # Only LAYERWISE depends on the remote onnxruntime version; anything
    # else passes through untouched.
    if requested == ProfilingMethod.LAYERWISE:
        installed = _get_remote_ort_version(runner, ort_path)
        # An undetectable version keeps the caller's choice.
        if (
            installed is not None
            and installed > _MAX_ORT_VERSION_FOR_NATIVE_PROFILING
        ):
            warnings.warn(
                f"onnxruntime {installed} on the remote device is newer "
                f"than {_MAX_ORT_VERSION_FOR_NATIVE_PROFILING}; the LAYERWISE "
                "profiling method is not supported for this version. "
                "Falling back to PYTHON profiling method.",
                stacklevel=2,
            )
            return ProfilingMethod.PYTHON
    return requested
def _run_remote_profiling(
    ctx: HubContext,
    *,
    runner: CommandRunner,
    remote_model: RemotePath,
    remote_artifact_dir: RemotePath,
    ort_path: str,
    runs: int,
    burn_ins: int,
    cold_starts: int,
    profiling_method: ProfilingMethod,
    cli_args: list[str],
) -> tuple[float, dict | None]:
    """Run ``embedl-onnxruntime measure-latency`` on a remote device and
    parse results.

    :param ctx: The execution context.
    :param runner: The device command runner.
    :param remote_model: Remote path to the ONNX model on the device.
    :param remote_artifact_dir: Remote directory for artifacts.
    :param ort_path: Path to ``embedl-onnxruntime`` on the remote device.
    :param runs: Number of inference iterations.
    :param burn_ins: Number of warm-up iterations.
    :param cold_starts: Number of cold-start iterations.
    :param profiling_method: Profiling method to use.
    :param cli_args: Extra CLI arguments.
    :return: Tuple of ``(latency_ms, ort_profile_dict_or_None)``.
    :raises ProfileError: If the remote command fails.
    """
    # Downgrade the method if the remote onnxruntime cannot support it.
    profiling_method = _resolve_profiling_method(
        profiling_method, runner, ort_path
    )
    # Map the unified ProfilingMethod to the CLI string expected by
    # ``embedl-onnxruntime``.
    cli_method = _ORT_CLI_METHOD[profiling_method]
    # Build the embedl-onnxruntime command
    run_args: list[str] = [
        ort_path,
        "measure-latency",
        "--model",
        str(remote_model),
        "--runs",
        str(runs),
        "--burn-ins",
        str(burn_ins),
        "--cold-starts",
        str(cold_starts),
        "--profiling-method",
        cli_method,
    ]
    # Compute the remote profile path once; it is needed both when
    # requesting the JSON profile and when downloading it afterwards
    # (previously the path was constructed twice).
    remote_profile = remote_artifact_dir / "profile.json"
    # If using LAYERWISE method, request a JSON profile output
    if profiling_method == ProfilingMethod.LAYERWISE:
        run_args.extend(["--output-json-file", str(remote_profile)])
    # Append extra CLI arguments
    run_args.extend(cli_args)
    # Log parameters
    ctx.client.log_param("$runtime", "onnx")
    ctx.client.log_param("embedl_onnxruntime_path", ort_path)
    ctx.client.log_param("runs", str(runs))
    ctx.client.log_param("burn_ins", str(burn_ins))
    ctx.client.log_param("cold_starts", str(cold_starts))
    ctx.client.log_param("profiling_method", profiling_method.value)
    if cli_args:
        ctx.client.log_param("cli_args", " ".join(cli_args))
    # Execute on the remote device
    local_artifact_dir = ctx.artifact_dir
    result = run_remote_command(
        runner,
        run_args,
        local_artifact_dir,
        error_cls=ProfileError,
        error_message="ONNX Runtime profiling failed on the remote device.",
    )
    stdout = result.stdout or ""
    # Download the JSON profile if using LAYERWISE method
    ort_profile: dict | None = None
    if profiling_method == ProfilingMethod.LAYERWISE:
        local_profile = local_artifact_dir / "profile.json"
        runner.get(remote_profile, local_profile)
        with open(local_profile, encoding="utf-8") as f:
            ort_profile = json.load(f)
    # Parse the latency from stdout
    return _parse_latency_from_stdout(stdout), ort_profile