# Copyright (C) 2026 Embedl AB
"""TFLite profiler component.
Supports profiling via:
- **``qai_hub``** devices: Qualcomm AI Hub cloud profiling using the
TensorFlow Lite runtime.
- **``aws``** devices: Embedl device cloud (AWS Device Farm) profiling
for ``.tflite`` models.
"""
from __future__ import annotations
import json
from dataclasses import dataclass
from pathlib import Path
import qai_hub as hub
from embedl_hub._internal.core.compile.tflite import TFLiteCompiledModel
from embedl_hub._internal.core.component.abc import Component, NoProviderError
from embedl_hub._internal.core.component.output import ComponentOutput
from embedl_hub._internal.core.component.provider_type import ProviderType
from embedl_hub._internal.core.context import HubContext
from embedl_hub._internal.core.profile.logging import (
log_execution_detail,
log_model_summary,
)
from embedl_hub._internal.core.profile.parsers import parse_qai_hub_profile
from embedl_hub._internal.core.profile.result import ProfileError
from embedl_hub._internal.core.utils.qai_hub_utils import (
log_runtime_info,
resolve_qai_hub_device,
)
from embedl_hub._internal.tracking.device_cloud import TFLiteBenchmarkParams
from embedl_hub._internal.tracking.rest_api import RunType
from embedl_hub._internal.tracking.run_log import LoggedArtifact, LoggedMetric
# ---------------------------------------------------------------------------
# Output dataclass
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class TFLiteProfilingResult(ComponentOutput):
    """Output of a TFLite profiling step.

    :param latency: The measured latency in milliseconds.
    :param fps: Frames per second derived from the latency.
    :param output_file: An optional logged artifact with detailed profile data.
    """

    # Mean inference latency in milliseconds.
    latency: LoggedMetric
    # Throughput metric derived from the latency (1000 / latency_ms).
    fps: LoggedMetric
    # Detailed profile data artifact; None when no profile file was logged.
    output_file: LoggedArtifact | None
# ---------------------------------------------------------------------------
# Component class
# ---------------------------------------------------------------------------
class TFLiteProfiler(Component):
    """Profile a compiled TFLite model.

    Dispatches to a device-specific implementation based on the
    configured device type.
    """

    # Tracked as a PROFILE run by the tracking backend.
    run_type = RunType.PROFILE

    def __init__(
        self,
        *,
        name: str | None = None,
        device: str | None = None,
        benchmark_params: TFLiteBenchmarkParams | None = None,
    ) -> None:
        """Create a new TFLite profiler.

        Parameters passed here become the defaults for every call.
        They can be overridden per call by passing them to
        :meth:`run` directly.

        :param name: Display name for this component instance.
        :param device: Name of the target device.
        :param benchmark_params: Optional TFLite benchmark parameters
            (only used by the AWS provider).
        """
        super().__init__(
            name=name, device=device, benchmark_params=benchmark_params
        )

    def run(
        self,
        ctx: HubContext,
        model: TFLiteCompiledModel,
        *,
        device: str | None = None,
        benchmark_params: TFLiteBenchmarkParams | None = None,
    ) -> TFLiteProfilingResult:
        """Profile a compiled TFLite model.

        Keyword arguments override the defaults set in the constructor.
        If a keyword argument is not provided here, the value from the
        constructor is used.

        :param ctx: The execution context with device configuration.
        :param model: The compiled TFLite model (from :class:`TFLiteCompiler`).
        :param device: Name of the target device.
        :param benchmark_params: Optional TFLite benchmark parameters
            (only used by the AWS provider).
        :return: A :class:`TFLiteProfilingResult` with latency and FPS metrics.
        """
        raise NoProviderError  # Replaced by the component system.
# ===================================================================== #
# QAI Hub provider
# ===================================================================== #
@TFLiteProfiler.provider(ProviderType.QAI_HUB)
def _profile_tflite_qai_hub(
    ctx: HubContext,
    model: TFLiteCompiledModel,
    *,
    device: str | None = None,
    benchmark_params: TFLiteBenchmarkParams | None = None,
) -> TFLiteProfilingResult:
    """Profile a TFLite model using Qualcomm AI Hub.

    :param ctx: The execution context with device configuration.
    :param model: The compiled TFLite model.
    :param device: Name of the target device.
    :param benchmark_params: Ignored for QAI Hub.
    :return: A :class:`TFLiteProfilingResult` with latency and FPS metrics.
    """
    # Resolve the model path (auto-transfer to local if remote).
    model_local_path = model.path.to_local(ctx).local()
    # Only the resolved hub.Device object is used here; the human-readable
    # device name returned alongside it is not needed.
    hub_device, _ = resolve_qai_hub_device(ctx, device)
    # Log the model being profiled as the run's input artifact.
    ctx.client.log_artifact(
        model.path,
        name="input",
        file_name=f"input_{model.path.file_name}",
        ctx=ctx,
    )
    latency_ms, profile_dict = _profile_via_qai_hub(
        ctx, model_local_path, hub_device
    )
    fps = 1000.0 / latency_ms
    ctx.client.log_metric("$latency", latency_ms)
    ctx.client.log_metric("fps", fps)
    # -- log model summary and per-layer execution detail ----------------
    summary, execution_detail = parse_qai_hub_profile(profile_dict)
    log_model_summary(ctx.client, summary)
    log_execution_detail(ctx.client, execution_detail)
    # Persist the raw profile dict so it can be inspected later.
    profile_path = ctx.artifact_dir / "profile.json"
    with open(profile_path, "w", encoding="utf-8") as f:
        json.dump(profile_dict, f, indent=2)
    ctx.client.log_artifact(profile_path, name="output_file")
    return TFLiteProfilingResult.from_current_run(ctx)
# -- QAI Hub helpers -------------------------------------------------------
def _profile_via_qai_hub(
    ctx: HubContext,
    model_path: Path,
    hub_device: hub.Device,
) -> tuple[float, dict]:
    """Submit a profile job to Qualcomm AI Hub and parse results.

    :param ctx: The execution context.
    :param model_path: Local path to the ``.tflite`` model.
    :param hub_device: Target QAI Hub device.
    :return: Tuple of ``(latency_ms, profile_dict)``.
    :raises ProfileError: If the profile job fails.
    """
    try:
        submitted_job: hub.ProfileJob = hub.submit_profile_job(
            model=model_path.as_posix(),
            device=hub_device,
        )
    except Exception as error:
        raise ProfileError(
            "Failed to submit profile job to Qualcomm AI Hub."
        ) from error
    # Record the job id under both the generic and the profile-specific key.
    for param_key in ("$qai_hub_job_id", "qai_hub_profile_job_id"):
        ctx.client.log_param(param_key, submitted_job.job_id)
    try:
        profile_data = submitted_job.download_profile()
    except Exception as error:
        raise ProfileError(
            "Failed to download profile from Qualcomm AI Hub."
        ) from error
    log_runtime_info(ctx.client, submitted_job.job_id, error_class=ProfileError)
    # The profile reports inference time in microseconds; convert to ms.
    inference_time_us = profile_data.get("execution_summary", {}).get(
        "estimated_inference_time"
    )
    if inference_time_us is None:
        raise ProfileError(
            "Profile result does not contain estimated_inference_time."
        )
    return float(inference_time_us) / 1000.0, profile_data
# ===================================================================== #
# AWS (Embedl device cloud) provider
# ===================================================================== #
# Runtime label logged for device-cloud (AWS) profiling runs.
_AWS_RUNTIME = "TensorFlow Lite"
@TFLiteProfiler.provider(ProviderType.AWS)
def _profile_tflite_aws(
    ctx: HubContext,
    model: TFLiteCompiledModel,
    *,
    device: str | None = None,
    benchmark_params: TFLiteBenchmarkParams | None = None,
) -> TFLiteProfilingResult:
    """Profile a TFLite model on the Embedl device cloud (AWS Device Farm).

    :param ctx: The execution context with device configuration.
    :param model: The compiled TFLite model.
    :param device: Name of the target device.
    :param benchmark_params: Optional TFLite benchmark parameters.
    :return: A :class:`TFLiteProfilingResult` with latency and FPS metrics.
    :raises ValueError: If no device is given or the device spec has no name.
    """
    # -- resolve model path (auto-transfer to local) -------------------
    model_local_path = model.path.to_local(ctx).local()
    # -- resolve device name -------------------------------------------
    # Validate explicitly rather than with `assert`, which is stripped
    # when Python runs with optimizations enabled (-O).
    if device is None:
        raise ValueError("A device name is required for AWS profiling.")
    dev = ctx.devices[device]
    device_name = dev.spec.device_name
    if device_name is None:
        raise ValueError("Device spec must have a device_name.")
    # -- log input artifact ---------------------------------------------
    ctx.client.log_artifact(
        model.path,
        name="input",
        file_name=f"input_{model.path.file_name}",
        ctx=ctx,
    )
    ctx.client.log_param("$device", device_name)
    ctx.client.log_param("$runtime", _AWS_RUNTIME)
    if benchmark_params is not None:
        for key, value in benchmark_params.to_dict().items():
            ctx.client.log_param(key, value)
    # -- submit benchmark job to Embedl device cloud -------------------
    latency_ms, summary, execution_detail, artifacts = _profile_via_aws(
        ctx, model_local_path, device_name, benchmark_params=benchmark_params
    )
    fps = 1000.0 / latency_ms
    ctx.client.log_metric("$latency", latency_ms)
    ctx.client.log_metric("fps", fps)
    # -- log model summary and per-layer execution detail ----------------
    log_model_summary(ctx.client, summary)
    log_execution_detail(ctx.client, execution_detail)
    # -- log summary as profile artifact --------------------------------
    profile_path = ctx.artifact_dir / "profile.json"
    with open(profile_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)
    ctx.client.log_artifact(profile_path, name="output_file")
    # -- log individual benchmark artifacts ----------------------------
    for artifact_path in artifacts:
        ctx.client.log_artifact(artifact_path)
    return TFLiteProfilingResult.from_current_run(ctx)
# -- AWS helpers -----------------------------------------------------------
def _profile_via_aws(
    ctx: HubContext,
    model_path: Path,
    device_name: str,
    benchmark_params: TFLiteBenchmarkParams | None = None,
) -> tuple[float, dict, list[dict], list[Path]]:
    """Submit a benchmark job to the Embedl device cloud and parse results.

    :param ctx: The execution context.
    :param model_path: Local path to the ``.tflite`` model.
    :param device_name: Device name for the Embedl device cloud.
    :param benchmark_params: Optional TFLite benchmark parameters.
    :return: Tuple of ``(latency_ms, summary_dict, execution_detail,
        artifact_paths)``.
    :raises ProfileError: If the benchmark job fails.
    """
    try:
        benchmark_job = ctx.client.submit_benchmark_job(
            model_path=model_path,
            device=device_name,
            benchmark_params=benchmark_params,
        )
    except Exception as error:
        raise ProfileError(
            "Failed to submit benchmark job to the Embedl device cloud."
        ) from error
    try:
        benchmark_result = benchmark_job.download_results(
            artifacts_dir=ctx.artifact_dir
        )
    except Exception as error:
        raise ProfileError(
            "Failed to download benchmark results from the Embedl device cloud."
        ) from error
    summary_dict = benchmark_result.summary
    raw_latency = summary_dict.get("mean_ms")
    if raw_latency is None:
        raise ProfileError(
            "Benchmark result does not contain 'mean_ms' latency."
        )
    return (
        float(raw_latency),
        summary_dict,
        benchmark_result.execution_detail,
        benchmark_result.artifacts or [],
    )