# Source code for embedl_hub._internal.core.context
# Copyright (C) 2026 Embedl AB
"""Context for component execution."""
from __future__ import annotations
import shutil
import tempfile
from collections.abc import Sequence
from contextlib import ExitStack
from dataclasses import dataclass
from types import TracebackType
from typing import TYPE_CHECKING, Self
from embedl_hub._internal.core.device.abc import Device
from embedl_hub._internal.core.types import LocalPath
if TYPE_CHECKING:
from embedl_hub._internal.tracking.client import Client
[docs]
@dataclass
class HubContext:
"""Context for component execution."""
devices: dict[str, Device]
client: Client
project_name: str
artifact_base_dir: LocalPath
_artifact_dir: LocalPath | None
download_artifacts: bool
auto_connect: bool
_active: bool
_tags: dict[str, str]
def __init__(
self,
project_name: str,
artifact_base_dir: LocalPath | None = None,
devices: Sequence[Device] | None = None,
download_artifacts: bool = False,
auto_connect: bool = True,
log_remote_artifacts: bool = True,
) -> None:
"""Initialize the context.
:param project_name: The project name used for tracking on the
Embedl Hub.
:param artifact_base_dir: Local base directory for storing run
artifacts. If not provided, a temporary directory is
created when entering the context manager and deleted on
exit.
:param devices: Devices to register. Stored internally as a
dict keyed by each device's name.
:param download_artifacts: Whether to download remote artifact
directories after each component run.
.. warning::
Enabling this option may significantly increase the
total execution time of each component run, as all
remote artifacts are transferred over the network
before the run completes.
:param auto_connect: Whether to automatically connect all
devices when entering the context manager. Defaults to
``True``.
:param log_remote_artifacts: Whether to fetch remote artifacts
from devices and upload them to the Hub. Defaults to
``True``. Set to ``False`` to speed up execution.
:raises ValueError: If `devices` contains duplicate names.
"""
self.project_name = project_name
self._provided_artifact_base_dir = artifact_base_dir
self.artifact_base_dir = artifact_base_dir or LocalPath()
self._artifact_dir = None
self.download_artifacts = download_artifacts
self.auto_connect = auto_connect
self._active = False
self._owns_artifact_base_dir = False
self._exit_stack: ExitStack | None = None
self._tags: dict[str, str] = {}
# Convert the device list into a dict keyed by Device.name.
if devices is None:
self.devices = {}
else:
self.devices = {}
for dev in devices:
if dev.name in self.devices:
raise ValueError(
f"Duplicate device name '{dev.name}'. "
f"Each device must have a unique name."
)
self.devices[dev.name] = dev
# Import here to avoid circular import
from embedl_hub._internal.tracking.client import Client
self.client = Client(log_remote_artifacts=log_remote_artifacts)
self.client.set_project(project_name)
@property
def is_active(self) -> bool:
"""Whether this context has been entered and is currently active."""
return self._active
@property
def artifact_dir(self) -> LocalPath:
"""Get the local artifact directory for the current component run.
:raises RuntimeError: If no artifact directory has been set.
This typically means the property is accessed outside of a
component dispatch.
"""
if self._artifact_dir is None:
raise RuntimeError(
"No local artifact directory available. "
"'artifact_dir' is only available during a component run."
)
return self._artifact_dir
@property
def project_dir(self) -> LocalPath:
"""Get the project-scoped artifact directory.
Returns ``artifact_base_dir / project_name``. Run
subdirectories are created beneath this path.
"""
return self.artifact_base_dir / self.project_name
@property
def tags(self) -> dict[str, str]:
"""Get the current tags set on this context (read-only copy)."""
return dict(self._tags)
[docs]
def set_tags(self, **tags: str) -> None:
"""Set tags that will be applied to all subsequent runs.
Replaces any previously set tags. Each keyword argument becomes
a tag with the argument name as the tag name and its value as the
tag value.
Example::
ctx.set_tags(model="resnet50", dataset="imagenet")
:param tags: Arbitrary keyword arguments where each key is a tag
name and each value is a tag value (both ``str``).
:raises TypeError: If any value is not a string.
"""
for key, value in tags.items():
if not isinstance(value, str):
raise TypeError(
f"Tag value for '{key}' must be a str, "
f"got {type(value).__name__}"
)
self._tags = dict(tags)
def __enter__(self) -> Self:
"""Enter the context, optionally connecting all devices."""
if self._provided_artifact_base_dir is None:
tmp = LocalPath(tempfile.mkdtemp(prefix="embedl-"))
self.artifact_base_dir = tmp
self._owns_artifact_base_dir = True
self.project_dir.mkdir(parents=True, exist_ok=True)
if self.auto_connect:
stack = ExitStack()
for dev in self.devices.values():
stack.enter_context(dev.connect())
self._exit_stack = stack
self._active = True
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
"""Exit the context, disconnecting all devices."""
self._active = False
if self._exit_stack is not None:
self._exit_stack.__exit__(exc_type, exc_val, exc_tb)
self._exit_stack = None
if self._owns_artifact_base_dir:
shutil.rmtree(self.artifact_base_dir, ignore_errors=True)
self._owns_artifact_base_dir = False