# Source code for embedl_hub._internal.core.device.manager

# Copyright (C) 2026 Embedl AB

"""Device manager for creating and managing device instances."""

from __future__ import annotations

from typing import TYPE_CHECKING

from embedl_hub._internal.core.component.provider_type import ProviderType
from embedl_hub._internal.core.device.abc import Device
from embedl_hub._internal.core.device.config.embedl_onnxruntime import (
    EmbedlONNXRuntimeConfig,
)
from embedl_hub._internal.core.device.config.trtexec import TrtexecConfig
from embedl_hub._internal.core.device.spec import DeviceSpec
from embedl_hub._internal.core.device.ssh import SSHCommandRunner, SSHConfig

if TYPE_CHECKING:
    from embedl_hub._internal.core.component.abc import Component


class DeviceManager:
    """Manager for creating device instances.

    Every public factory is a classmethod that returns a new :class:`Device`.
    The SSH-based factories attach an :class:`SSHCommandRunner` built from the
    given :class:`SSHConfig` and an empty :class:`DeviceSpec`. The QAI Hub and
    AWS factories attach no runner; instead they record the target device
    name in the :class:`DeviceSpec`.
    """

    @staticmethod
    def _ssh_device(
        config: SSHConfig,
        *,
        name: str,
        provider_type: str,
        **device_kwargs: object,
    ) -> Device:
        """Build a Device that executes commands through an SSH runner.

        :param config: SSH connection configuration for the runner.
        :param name: Human-readable label for this device.
        :param provider_type: The provider type identifier.
        :param device_kwargs: Extra keyword arguments forwarded verbatim to
            :class:`Device` (e.g. ``provider_config``). Keys that are not
            given are not passed at all, so ``Device``'s own defaults apply.
        :return: A Device with an SSH runner and an empty DeviceSpec.
        """
        return Device(
            name=name,
            runner=SSHCommandRunner(config),
            spec=DeviceSpec(),
            provider_type=provider_type,
            **device_kwargs,
        )

    @staticmethod
    def _runnerless_device(
        device_name: str,
        *,
        name: str,
        provider_type: str,
    ) -> Device:
        """Build a Device with no command runner, identified by device name.

        :param device_name: Provider-specific target device name, stored in
            the DeviceSpec.
        :param name: Human-readable label for this device.
        :param provider_type: The provider type identifier.
        :return: A Device with ``runner=None``.
        """
        return Device(
            name=name,
            runner=None,
            spec=DeviceSpec(device_name=device_name),
            provider_type=provider_type,
        )

    @classmethod
    def get_ssh_device(
        cls,
        config: SSHConfig,
        *,
        name: str = "main",
        provider_type: str = ProviderType.EMBEDL_ONNXRUNTIME,
    ) -> Device:
        """Create a device with an SSH command runner.

        No ``provider_config`` or per-component overrides are passed to
        :class:`Device`.

        :param config: SSH connection configuration.
        :param name: Human-readable label for this device.
        :param provider_type: The provider type identifier
            (e.g. ``ProviderType.EMBEDL_ONNXRUNTIME``,
            ``ProviderType.TRTEXEC``).
        :return: A Device instance configured for SSH access.
        """
        return cls._ssh_device(config, name=name, provider_type=provider_type)

    @classmethod
    def get_qai_hub_device(
        cls,
        device_name: str,
        *,
        name: str = "main",
    ) -> Device:
        """Create a device for Qualcomm AI Hub.

        :param device_name: The QAI Hub device name
            (e.g., 'Samsung Galaxy S24').
        :param name: Human-readable label for this device.
        :return: A Device instance configured for QAI Hub.
        """
        return cls._runnerless_device(
            device_name, name=name, provider_type=ProviderType.QAI_HUB
        )

    @classmethod
    def get_embedl_onnxruntime_device(
        cls,
        config: SSHConfig,
        *,
        name: str = "main",
        provider_config: EmbedlONNXRuntimeConfig | None = None,
        overrides: dict[type[Component], EmbedlONNXRuntimeConfig]
        | None = None,
    ) -> Device:
        """Create an SSH device for ONNX Runtime compilation.

        :param config: SSH connection configuration.
        :param name: Human-readable label for this device.
        :param provider_config: Typed configuration for the
            ``embedl-onnxruntime`` CLI provider. When omitted (or ``None``),
            a new ``EmbedlONNXRuntimeConfig()`` is created for this call.
        :param overrides: Optional per-component config overrides keyed by
            concrete component class. The mapping is copied, so later
            changes to the caller's dict do not affect the device.
        :return: A Device instance configured for remote ORT compilation.
        """
        if provider_config is None:
            # Build the default per call: a default instance in the signature
            # would be evaluated once and shared by every device created.
            provider_config = EmbedlONNXRuntimeConfig()
        return cls._ssh_device(
            config,
            name=name,
            provider_type=ProviderType.EMBEDL_ONNXRUNTIME,
            provider_config=provider_config,
            # Copy the caller's mapping; ``None`` and ``{}`` both become a
            # fresh empty dict.
            provider_config_overrides=dict(overrides) if overrides else {},
        )

    @classmethod
    def get_tensorrt_device(
        cls,
        config: SSHConfig,
        *,
        name: str = "main",
        provider_config: TrtexecConfig | None = None,
        overrides: dict[type[Component], TrtexecConfig] | None = None,
    ) -> Device:
        """Create an SSH device for TensorRT compilation.

        :param config: SSH connection configuration.
        :param name: Human-readable label for this device.
        :param provider_config: Typed configuration for the ``trtexec`` CLI
            provider. When omitted (or ``None``), a new ``TrtexecConfig()``
            is created for this call.
        :param overrides: Optional per-component config overrides keyed by
            concrete component class. The mapping is copied, so later
            changes to the caller's dict do not affect the device.
        :return: A Device instance configured for remote TensorRT
            compilation.
        """
        if provider_config is None:
            # Build the default per call: a default instance in the signature
            # would be evaluated once and shared by every device created.
            provider_config = TrtexecConfig()
        return cls._ssh_device(
            config,
            name=name,
            provider_type=ProviderType.TRTEXEC,
            provider_config=provider_config,
            # Copy the caller's mapping; ``None`` and ``{}`` both become a
            # fresh empty dict.
            provider_config_overrides=dict(overrides) if overrides else {},
        )

    @classmethod
    def get_aws_device(
        cls,
        device_name: str,
        *,
        name: str = "main",
    ) -> Device:
        """Create a device for the Embedl device cloud (AWS Device Farm).

        :param device_name: The Embedl device cloud device name.
        :param name: Human-readable label for this device.
        :return: A Device instance configured for AWS Device Farm.
        """
        return cls._runnerless_device(
            device_name, name=name, provider_type=ProviderType.AWS
        )