# Source code for embedl_hub._internal.core.context

# Copyright (C) 2026 Embedl AB

"""Context for component execution."""

from __future__ import annotations

import shutil
import tempfile
from collections.abc import Sequence
from contextlib import ExitStack
from dataclasses import dataclass
from types import TracebackType
from typing import TYPE_CHECKING, Self

from embedl_hub._internal.core.device.abc import Device
from embedl_hub._internal.core.types import LocalPath

if TYPE_CHECKING:
    from embedl_hub._internal.tracking.client import Client


[docs] @dataclass class HubContext: """Context for component execution.""" devices: dict[str, Device] client: Client project_name: str artifact_base_dir: LocalPath _artifact_dir: LocalPath | None download_artifacts: bool auto_connect: bool _active: bool _tags: dict[str, str] def __init__( self, project_name: str, artifact_base_dir: LocalPath | None = None, devices: Sequence[Device] | None = None, download_artifacts: bool = False, auto_connect: bool = True, log_remote_artifacts: bool = True, ) -> None: """Initialize the context. :param project_name: The project name used for tracking on the Embedl Hub. :param artifact_base_dir: Local base directory for storing run artifacts. If not provided, a temporary directory is created when entering the context manager and deleted on exit. :param devices: Devices to register. Stored internally as a dict keyed by each device's name. :param download_artifacts: Whether to download remote artifact directories after each component run. .. warning:: Enabling this option may significantly increase the total execution time of each component run, as all remote artifacts are transferred over the network before the run completes. :param auto_connect: Whether to automatically connect all devices when entering the context manager. Defaults to ``True``. :param log_remote_artifacts: Whether to fetch remote artifacts from devices and upload them to the Hub. Defaults to ``True``. Set to ``False`` to speed up execution. :raises ValueError: If `devices` contains duplicate names. """ self.project_name = project_name self._provided_artifact_base_dir = artifact_base_dir self.artifact_base_dir = artifact_base_dir or LocalPath() self._artifact_dir = None self.download_artifacts = download_artifacts self.auto_connect = auto_connect self._active = False self._owns_artifact_base_dir = False self._exit_stack: ExitStack | None = None self._tags: dict[str, str] = {} # Convert the device list into a dict keyed by Device.name. 
if devices is None: self.devices = {} else: self.devices = {} for dev in devices: if dev.name in self.devices: raise ValueError( f"Duplicate device name '{dev.name}'. " f"Each device must have a unique name." ) self.devices[dev.name] = dev # Import here to avoid circular import from embedl_hub._internal.tracking.client import Client self.client = Client(log_remote_artifacts=log_remote_artifacts) self.client.set_project(project_name) @property def is_active(self) -> bool: """Whether this context has been entered and is currently active.""" return self._active @property def artifact_dir(self) -> LocalPath: """Get the local artifact directory for the current component run. :raises RuntimeError: If no artifact directory has been set. This typically means the property is accessed outside of a component dispatch. """ if self._artifact_dir is None: raise RuntimeError( "No local artifact directory available. " "'artifact_dir' is only available during a component run." ) return self._artifact_dir @property def project_dir(self) -> LocalPath: """Get the project-scoped artifact directory. Returns ``artifact_base_dir / project_name``. Run subdirectories are created beneath this path. """ return self.artifact_base_dir / self.project_name @property def tags(self) -> dict[str, str]: """Get the current tags set on this context (read-only copy).""" return dict(self._tags)
[docs] def set_tags(self, **tags: str) -> None: """Set tags that will be applied to all subsequent runs. Replaces any previously set tags. Each keyword argument becomes a tag with the argument name as the tag name and its value as the tag value. Example:: ctx.set_tags(model="resnet50", dataset="imagenet") :param tags: Arbitrary keyword arguments where each key is a tag name and each value is a tag value (both ``str``). :raises TypeError: If any value is not a string. """ for key, value in tags.items(): if not isinstance(value, str): raise TypeError( f"Tag value for '{key}' must be a str, " f"got {type(value).__name__}" ) self._tags = dict(tags)
def __enter__(self) -> Self: """Enter the context, optionally connecting all devices.""" if self._provided_artifact_base_dir is None: tmp = LocalPath(tempfile.mkdtemp(prefix="embedl-")) self.artifact_base_dir = tmp self._owns_artifact_base_dir = True self.project_dir.mkdir(parents=True, exist_ok=True) if self.auto_connect: stack = ExitStack() for dev in self.devices.values(): stack.enter_context(dev.connect()) self._exit_stack = stack self._active = True return self def __exit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> None: """Exit the context, disconnecting all devices.""" self._active = False if self._exit_stack is not None: self._exit_stack.__exit__(exc_type, exc_val, exc_tb) self._exit_stack = None if self._owns_artifact_base_dir: shutil.rmtree(self.artifact_base_dir, ignore_errors=True) self._owns_artifact_base_dir = False