# Source code for embedl_hub._internal.core.profile.tflite

# Copyright (C) 2026 Embedl AB

"""TFLite profiler component.

Supports profiling via:
- **``qai_hub``** devices: Qualcomm AI Hub cloud profiling using the
  TensorFlow Lite runtime.
- **``aws``** devices: Embedl device cloud (AWS Device Farm) profiling
  for ``.tflite`` models.
"""

from __future__ import annotations

import json
from dataclasses import dataclass
from pathlib import Path

import qai_hub as hub

from embedl_hub._internal.core.compile.tflite import TFLiteCompiledModel
from embedl_hub._internal.core.component.abc import Component, NoProviderError
from embedl_hub._internal.core.component.output import ComponentOutput
from embedl_hub._internal.core.component.provider_type import ProviderType
from embedl_hub._internal.core.context import HubContext
from embedl_hub._internal.core.profile.logging import (
    log_execution_detail,
    log_model_summary,
)
from embedl_hub._internal.core.profile.parsers import parse_qai_hub_profile
from embedl_hub._internal.core.profile.result import ProfileError
from embedl_hub._internal.core.utils.qai_hub_utils import (
    log_runtime_info,
    resolve_qai_hub_device,
)
from embedl_hub._internal.tracking.device_cloud import TFLiteBenchmarkParams
from embedl_hub._internal.tracking.rest_api import RunType
from embedl_hub._internal.tracking.run_log import LoggedArtifact, LoggedMetric

# ---------------------------------------------------------------------------
# Output dataclass
# ---------------------------------------------------------------------------


@dataclass(frozen=True)
class TFLiteProfilingResult(ComponentOutput):
    """Result of profiling a compiled TFLite model.

    :param latency: The measured latency in milliseconds.
    :param fps: Frames per second derived from the latency.
    :param output_file: An optional logged artifact with detailed
        profile data.
    """

    latency: LoggedMetric
    fps: LoggedMetric
    output_file: LoggedArtifact | None
# ---------------------------------------------------------------------------
# Component class
# ---------------------------------------------------------------------------
class TFLiteProfiler(Component):
    """Profile a compiled TFLite model.

    Dispatches to a device-specific implementation based on the
    configured device type.
    """

    run_type = RunType.PROFILE

    def __init__(
        self,
        *,
        name: str | None = None,
        device: str | None = None,
        benchmark_params: TFLiteBenchmarkParams | None = None,
    ) -> None:
        """Create a new TFLite profiler.

        Values given here act as defaults for every invocation and may
        be overridden per call via :meth:`run`.

        :param name: Display name for this component instance.
        :param device: Name of the target device.
        :param benchmark_params: Optional TFLite benchmark parameters
            (only used by the AWS provider).
        """
        super().__init__(
            name=name, device=device, benchmark_params=benchmark_params
        )

    def run(
        self,
        ctx: HubContext,
        model: TFLiteCompiledModel,
        *,
        device: str | None = None,
        benchmark_params: TFLiteBenchmarkParams | None = None,
    ) -> TFLiteProfilingResult:
        """Profile a compiled TFLite model.

        Keyword arguments override the defaults set in the constructor;
        omitted arguments fall back to the constructor values.

        :param ctx: The execution context with device configuration.
        :param model: The compiled TFLite model (from
            :class:`TFLiteCompiler`).
        :param device: Name of the target device.
        :param benchmark_params: Optional TFLite benchmark parameters
            (only used by the AWS provider).
        :return: A :class:`TFLiteProfilingResult` with latency and FPS
            metrics.
        """
        # The component system swaps this stub for a registered provider.
        raise NoProviderError
# ===================================================================== #
# QAI Hub provider
# ===================================================================== #


@TFLiteProfiler.provider(ProviderType.QAI_HUB)
def _profile_tflite_qai_hub(
    ctx: HubContext,
    model: TFLiteCompiledModel,
    *,
    device: str | None = None,
    benchmark_params: TFLiteBenchmarkParams | None = None,
) -> TFLiteProfilingResult:
    """Profile a TFLite model using Qualcomm AI Hub.

    :param ctx: The execution context with device configuration.
    :param model: The compiled TFLite model.
    :param device: Name of the target device.
    :param benchmark_params: Ignored for QAI Hub.
    :return: A :class:`TFLiteProfilingResult` with latency and FPS metrics.
    """
    model_local_path = model.path.to_local(ctx).local()
    # Only the resolved hub.Device object is used here; the resolved
    # display name is not needed in this provider.
    hub_device, _ = resolve_qai_hub_device(ctx, device)

    ctx.client.log_artifact(
        model.path,
        name="input",
        file_name=f"input_{model.path.file_name}",
        ctx=ctx,
    )

    latency_ms, profile_dict = _profile_via_qai_hub(
        ctx, model_local_path, hub_device
    )
    fps = 1000.0 / latency_ms
    ctx.client.log_metric("$latency", latency_ms)
    ctx.client.log_metric("fps", fps)

    # -- log model summary and per-layer execution detail ----------------
    summary, execution_detail = parse_qai_hub_profile(profile_dict)
    log_model_summary(ctx.client, summary)
    log_execution_detail(ctx.client, execution_detail)

    profile_path = ctx.artifact_dir / "profile.json"
    with open(profile_path, "w", encoding="utf-8") as f:
        json.dump(profile_dict, f, indent=2)
    ctx.client.log_artifact(profile_path, name="output_file")

    return TFLiteProfilingResult.from_current_run(ctx)


# -- QAI Hub helpers -------------------------------------------------------


def _profile_via_qai_hub(
    ctx: HubContext,
    model_path: Path,
    hub_device: hub.Device,
) -> tuple[float, dict]:
    """Submit a profile job to Qualcomm AI Hub and parse results.

    :param ctx: The execution context.
    :param model_path: Local path to the ``.tflite`` model.
    :param hub_device: Target QAI Hub device.
    :return: Tuple of ``(latency_ms, profile_dict)``.
    :raises ProfileError: If the profile job fails.
    """
    try:
        job: hub.ProfileJob = hub.submit_profile_job(
            model=model_path.as_posix(),
            device=hub_device,
        )
    except Exception as error:
        raise ProfileError(
            "Failed to submit profile job to Qualcomm AI Hub."
        ) from error
    ctx.client.log_param("$qai_hub_job_id", job.job_id)
    ctx.client.log_param("qai_hub_profile_job_id", job.job_id)

    try:
        profile = job.download_profile()
    except Exception as error:
        raise ProfileError(
            "Failed to download profile from Qualcomm AI Hub."
        ) from error

    log_runtime_info(ctx.client, job.job_id, error_class=ProfileError)

    summary = profile.get("execution_summary", {})
    estimated_time_us = summary.get("estimated_inference_time")
    if estimated_time_us is None:
        raise ProfileError(
            "Profile result does not contain estimated_inference_time."
        )
    # QAI Hub reports the estimate in microseconds.
    latency_ms = float(estimated_time_us) / 1000.0
    return latency_ms, profile


# ===================================================================== #
# AWS (Embedl device cloud) provider
# ===================================================================== #

_AWS_RUNTIME = "TensorFlow Lite"


@TFLiteProfiler.provider(ProviderType.AWS)
def _profile_tflite_aws(
    ctx: HubContext,
    model: TFLiteCompiledModel,
    *,
    device: str | None = None,
    benchmark_params: TFLiteBenchmarkParams | None = None,
) -> TFLiteProfilingResult:
    """Profile a TFLite model on the Embedl device cloud (AWS Device Farm).

    :param ctx: The execution context with device configuration.
    :param model: The compiled TFLite model.
    :param device: Name of the target device.
    :param benchmark_params: Optional TFLite benchmark parameters.
    :return: A :class:`TFLiteProfilingResult` with latency and FPS metrics.
    :raises ValueError: If no device is given or the device spec has no
        device name.
    """
    # -- resolve model path (auto-transfer to local) -------------------
    model_local_path = model.path.to_local(ctx).local()

    # -- resolve device name -------------------------------------------
    # Explicit raise instead of `assert`: asserts are stripped under -O,
    # and a missing device is a caller error we must always surface.
    if device is None:
        raise ValueError("A device must be specified for AWS profiling.")
    dev = ctx.devices[device]
    device_name = dev.spec.device_name
    if device_name is None:
        raise ValueError("Device spec must have a device_name.")

    # -- log input artifact ---------------------------------------------
    ctx.client.log_artifact(
        model.path,
        name="input",
        file_name=f"input_{model.path.file_name}",
        ctx=ctx,
    )
    ctx.client.log_param("$device", device_name)
    ctx.client.log_param("$runtime", _AWS_RUNTIME)
    if benchmark_params is not None:
        for key, value in benchmark_params.to_dict().items():
            ctx.client.log_param(key, value)

    # -- submit benchmark job to Embedl device cloud -------------------
    latency_ms, summary, execution_detail, artifacts = _profile_via_aws(
        ctx, model_local_path, device_name, benchmark_params=benchmark_params
    )
    fps = 1000.0 / latency_ms
    ctx.client.log_metric("$latency", latency_ms)
    ctx.client.log_metric("fps", fps)

    # -- log model summary and per-layer execution detail ----------------
    log_model_summary(ctx.client, summary)
    log_execution_detail(ctx.client, execution_detail)

    # -- log summary as profile artifact --------------------------------
    profile_path = ctx.artifact_dir / "profile.json"
    with open(profile_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)
    ctx.client.log_artifact(profile_path, name="output_file")

    # -- log individual benchmark artifacts ----------------------------
    for artifact_path in artifacts:
        ctx.client.log_artifact(artifact_path)

    return TFLiteProfilingResult.from_current_run(ctx)


# -- AWS helpers -----------------------------------------------------------


def _profile_via_aws(
    ctx: HubContext,
    model_path: Path,
    device_name: str,
    benchmark_params: TFLiteBenchmarkParams | None = None,
) -> tuple[float, dict, list[dict], list[Path]]:
    """Submit a benchmark job to the Embedl device cloud and parse results.

    :param ctx: The execution context.
    :param model_path: Local path to the ``.tflite`` model.
    :param device_name: Device name for the Embedl device cloud.
    :param benchmark_params: Optional TFLite benchmark parameters.
    :return: Tuple of ``(latency_ms, summary_dict, execution_detail,
        artifact_paths)``.
    :raises ProfileError: If the benchmark job fails.
    """
    try:
        job = ctx.client.submit_benchmark_job(
            model_path=model_path,
            device=device_name,
            benchmark_params=benchmark_params,
        )
    except Exception as error:
        raise ProfileError(
            "Failed to submit benchmark job to the Embedl device cloud."
        ) from error

    try:
        result = job.download_results(artifacts_dir=ctx.artifact_dir)
    except Exception as error:
        raise ProfileError(
            "Failed to download benchmark results from the Embedl device cloud."
        ) from error

    summary = result.summary
    execution_detail = result.execution_detail
    latency_ms = summary.get("mean_ms")
    if latency_ms is None:
        raise ProfileError(
            "Benchmark result does not contain 'mean_ms' latency."
        )
    latency_ms = float(latency_ms)
    artifacts = result.artifacts or []
    return latency_ms, summary, execution_detail, artifacts