# Source code for embedl_hub._internal.core.device.abc

# Copyright (C) 2026 Embedl AB

"""Core device abstractions.

This module defines the :class:`Device` dataclass and the protocols that
runners must implement (:class:`CommandRunner`, :class:`Connectable`).
"""

from __future__ import annotations

from collections.abc import Generator, Sequence
from contextlib import AbstractContextManager, ExitStack, contextmanager
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Protocol, Self, TypeVar, runtime_checkable

from embedl_hub._internal.core.device.artifact_manager import (
    RemoteArtifactManager,
)
from embedl_hub._internal.core.device.config.abc import ProviderConfig
from embedl_hub._internal.core.device.spec import DeviceSpec
from embedl_hub._internal.core.types import LocalPath, RemotePath

if TYPE_CHECKING:
    from embedl_hub._internal.core.component.abc import Component
    from embedl_hub._internal.tracking.run_log import DeviceLog

# Type variable for ``Device.get_provider_config``: restricted to
# ``ProviderConfig`` subclasses so the resolved config is returned as the
# concrete class the caller asked for.
_T = TypeVar("_T", bound=ProviderConfig)


class CommandResult(Protocol):
    """Structural type for whatever a :class:`CommandRunner` returns.

    Only the two captured output streams are part of the contract; either
    one may be ``None``.
    """

    # Text written by the command to standard output, if available.
    stdout: str | None
    # Text written by the command to standard error, if available.
    stderr: str | None
class CommandRunner(Protocol):
    """Structural interface for running commands and moving files on a device.

    Any object providing compatible ``run``, ``get`` and ``put`` methods
    satisfies this protocol for static type checkers; no inheritance is
    required. (It is not ``runtime_checkable``, so ``isinstance`` checks
    against it are not supported.)
    """

    def run(
        self,
        command: Sequence[str],
        cwd: RemotePath | None = None,
        *,
        hide: bool = False,
    ) -> CommandResult:
        """Execute *command* on the device.

        :param command: The command and its arguments as separate strings.
        :param cwd: Optional remote working directory for the command.
        :param hide: Keyword-only flag; presumably suppresses echoing of the
            command's output (confirm against concrete runners).
        :return: The captured result of the command.
        """

    def get(self, source: RemotePath, destination: LocalPath) -> None:
        """Transfer *source* from the device to the local *destination*."""

    def put(self, source: LocalPath, destination: RemotePath) -> None:
        """Transfer the local *source* to *destination* on the device."""
@runtime_checkable class Connectable(Protocol): """Protocol for runners that require an explicit connection step. Runners implementing this protocol provide a :meth:`connect` context manager that establishes and tears down the underlying connection. The :attr:`is_active` property indicates whether a connection is currently open. """ @property def is_active(self) -> bool: """Whether a connection is currently established.""" ... def connect(self) -> AbstractContextManager[Self]: ...
@dataclass
class Device:
    """Represents a target device for component execution.

    Directory and artifact management is delegated to an internal
    :class:`RemoteArtifactManager`. The most commonly used properties
    (``artifact_base_dir``, ``artifact_dir``, ``last_download_dir``) and the
    ``download_artifact_dir`` method are exposed as public delegates for
    convenience.

    :param name: A human-readable label for this device (e.g. ``'main'``,
        ``'jetson-orin'``). Used as the dictionary key in
        :attr:`HubContext.devices` and shown in console output.
    :param runner: Optional command runner for executing commands on the
        device.
    :param spec: The device specification (platform and environment).
    :param provider_type: The provider type identifier for this device.
    :param provider_config: Default :class:`ProviderConfig` used by all
        components unless overridden per-component.
    :param provider_config_overrides: Optional per-component overrides keyed
        by the concrete :class:`Component` subclass.
    """

    name: str
    runner: CommandRunner | None
    spec: DeviceSpec
    provider_type: str
    provider_config: ProviderConfig | None = None
    provider_config_overrides: dict[type[Component], ProviderConfig] = field(
        default_factory=dict
    )
    # Internal state: excluded from ``__init__`` and ``repr``.
    _artifacts: RemoteArtifactManager = field(init=False, repr=False)
    _in_context: bool = field(default=False, init=False, repr=False)

    def __post_init__(self) -> None:
        """Create the artifact manager and share the device name with the runner."""
        self._artifacts = RemoteArtifactManager(runner=self.runner)
        # Propagate the device name to the runner for display purposes. Only
        # runners that already define a ``device_name`` attribute are updated.
        if self.runner is not None and hasattr(self.runner, "device_name"):
            self.runner.device_name = self.name

    @property
    def is_active(self) -> bool:
        """Whether this device is inside a :meth:`connect` context."""
        return self._in_context

    # -- delegated properties ------------------------------------------------

    @property
    def artifact_base_dir(self) -> RemotePath | None:
        """Get the remote base directory for artifacts."""
        return self._artifacts.base_dir

    @artifact_base_dir.setter
    def artifact_base_dir(self, value: RemotePath | None) -> None:
        """Set the remote base directory for artifacts.

        :raises RuntimeError: If called while inside a device context.
        """
        # NOTE(review): no check happens here; the RuntimeError documented
        # above presumably comes from ``RemoteArtifactManager.base_dir``'s
        # setter -- confirm.
        self._artifacts.base_dir = value

    @property
    def artifact_dir(self) -> RemotePath | None:
        """Get the artifact directory for the current run."""
        return self._artifacts.artifact_dir

    @property
    def last_download_dir(self) -> LocalPath | None:
        """Get the local path of the most recent artifact download."""
        return self._artifacts.last_download_dir

    # -- delegated methods ---------------------------------------------------

    def download_artifact_dir(self, destination: LocalPath) -> LocalPath:
        """Download the current remote artifact directory to a local path.

        :param destination: The local directory to download into.
        :return: The local path where the artifacts were downloaded.
        :raises RuntimeError: If not inside a device context.
        :raises RuntimeError: If there is no command runner.
        :raises RuntimeError: If there is no remote artifact directory.
        """
        return self._artifacts.download_artifact_dir(destination)

    def _set_artifact_subdir_name(self, subdir_name: str) -> None:
        """Set the artifact sub-directory name for this device.

        **Protected** — called exclusively by the component system. End
        users should not call this method directly.

        Combined with :attr:`artifact_base_dir`, this determines the
        :attr:`artifact_dir` (``<base_dir>/<subdir_name>``).
        """
        self._artifacts.set_subdir_name(subdir_name)

    def _clear_artifact_subdir_name(self) -> None:
        """Clear the artifact sub-directory name, resetting :attr:`artifact_dir`.

        **Protected** — called exclusively by the component system. End
        users should not call this method directly.
        """
        self._artifacts.clear_subdir_name()

    # -- provider config ------------------------------------------------------

    def get_provider_config(
        self,
        config_type: type[_T],
        component_cls: type[Component],
    ) -> _T | None:
        """Resolve the provider config for a specific component class.

        Walks the *component_cls* MRO looking for a matching override in
        :attr:`provider_config_overrides`. If none is found, falls back to
        :attr:`provider_config`. In both cases the config must be an
        instance of *config_type*; if it is not, ``None`` is returned.
        Note that the first override found in the MRO is decisive: if it has
        the wrong type, ``None`` is returned without consulting
        :attr:`provider_config`.

        This method is generic on *config_type*: the return type is
        ``T | None`` where ``T`` is the concrete config class passed in,
        giving both runtime validation and static type narrowing::

            cfg = dev.get_provider_config(TrtexecConfig, TensorRTCompiler)
            # type: TrtexecConfig | None

        :param config_type: The expected concrete :class:`ProviderConfig`
            subclass.
        :param component_cls: The component class requesting the config.
        :return: The resolved config, or ``None`` if no matching config is
            found.
        """
        for cls in component_cls.__mro__:
            if cls in self.provider_config_overrides:
                cfg = self.provider_config_overrides[cls]
                if isinstance(cfg, config_type):
                    return cfg
                return None
        if isinstance(self.provider_config, config_type):
            return self.provider_config
        return None

    # -- device-specific logic -----------------------------------------------

    def create_device_log(self, name: str) -> DeviceLog:
        """Create a :class:`~embedl_hub._internal.tracking.run_log.DeviceLog`
        snapshot of this device's current state.

        The snapshot records *name*, :attr:`spec`, :attr:`provider_type` and
        the most recent local download directory (as
        ``downloaded_artifact_dir``). ``artifact_dir`` is only recorded while
        the device is active (inside :meth:`connect`); otherwise it is
        ``None``. ``provider_metadata`` is currently always an empty dict,
        and :attr:`provider_config` is not recorded.

        :param name: Logical name for this device in the run record.
        :return: A :class:`~embedl_hub._internal.tracking.run_log.DeviceLog`
            capturing the values listed above.
        """
        # Imported lazily: the module-level import is TYPE_CHECKING-only,
        # presumably to avoid an import cycle.
        from embedl_hub._internal.tracking.run_log import DeviceLog

        return DeviceLog(
            name=name,
            artifact_dir=self.artifact_dir if self.is_active else None,
            spec=self.spec,
            provider_type=self.provider_type,
            provider_metadata={},
            downloaded_artifact_dir=self._artifacts.last_download_dir,
        )

    @contextmanager
    def connect(self) -> Generator[Self, None, None]:
        """Connect the device, establishing a runner connection if supported.

        If :attr:`runner` implements :class:`Connectable`, its own
        ``connect()`` context is entered first and exited last. While the
        returned context is open, :attr:`is_active` is ``True`` and the
        artifact manager is activated.

        An exception raised inside the ``with`` body is forwarded to the
        runner's context manager on exit, so it can react to the failure.
        The exception then propagates to the caller unless that context
        manager suppresses it.

        :yields: This device instance with an active connection.
        """
        with ExitStack() as stack:
            # ExitStack hands any in-flight exception to the runner's
            # ``__exit__`` instead of always reporting a clean exit.
            if self.runner is not None and isinstance(self.runner, Connectable):
                stack.enter_context(self.runner.connect())
            self._in_context = True
            try:
                with self._artifacts.activate():
                    yield self
            finally:
                self._in_context = False