# Copyright (C) 2026 Embedl AB
"""Lightweight function-level run tracking via :func:`hub_run`."""
from __future__ import annotations
import functools
import inspect
import warnings
from collections.abc import Callable
from datetime import datetime
from typing import TYPE_CHECKING
from embedl_hub._internal.core.component.dispatch import (
_cleanup_artifact_dirs,
_download_remote_artifacts,
_find_parent_run_id,
_log_context_tags,
_populate_run_log,
_save_run_yaml,
_setup_artifact_dirs,
_update_downloaded_artifact_dirs,
)
from embedl_hub._internal.core.component.output import ComponentOutput
from embedl_hub._internal.tracking.rest_api import RunType, coerce_run_type
from embedl_hub._internal.tracking.run_log import RunLog
if TYPE_CHECKING:
from embedl_hub._internal.core.context import HubContext
[docs]
def hub_run(run_type: RunType | str, *, name: str | None = None) -> Callable:
    """Decorator that wraps a plain function as a tracked Hub run.

    Provides the same artifact-directory setup, tracking lifecycle,
    remote-artifact download, and ``run.yaml`` saving as a full
    :class:`~embedl_hub._internal.core.component.abc.Component`, without
    requiring a class definition or provider registration.  Device
    narrowing is not performed — the function receives the full
    :attr:`~embedl_hub._internal.core.context.HubContext.devices` mapping
    as supplied by the caller.

    The decorated function must accept ``ctx: HubContext`` as its first
    positional parameter and must be called inside an active
    ``with ctx:`` block.

    Usage::

        @hub_run("analyse")
        def analyse(ctx: HubContext, data: pd.DataFrame, *, threshold=0.5):
            ctx.client.log_param("threshold", str(threshold))
            result = run_model(data, threshold)
            ctx.client.log_metric("accuracy", result.accuracy)
            return result

        with HubContext(project_name="my-project") as ctx:
            output = analyse(ctx, data, threshold=0.8)

    Parent run linking is automatic when a
    :class:`~embedl_hub._internal.core.component.output.ComponentOutput`
    carrying a ``run_log`` is passed as an argument.

    When ``ctx.client`` is ``None`` the function is still executed with
    the artifact directories set up and cleaned up around it, but no
    tracked run is started and no ``run.yaml`` is written.

    In the tracked path a warning is emitted when the function returns
    something other than a ``RunLog`` or ``ComponentOutput``.

    :param run_type: The run type.  Pass a :class:`RunType` member or a
        plain string label (e.g. ``"analyse"``); strings are normalised
        with ``coerce_run_type`` once, when the decorator is applied.
    :param name: Display name logged to the Hub.  Defaults to the
        function's ``__name__`` (or its type name for callables that
        have no ``__name__``, such as ``functools.partial`` objects).
    :returns: A decorator that wraps the target function.
    :raises TypeError: At decoration time, if the decorated function
        does not have ``ctx`` as its first parameter, or if that
        parameter cannot receive a positional argument (keyword-only
        or ``**ctx``).
    :raises RuntimeError: When the wrapped function is called outside an
        active ``HubContext`` (i.e. before ``with ctx:`` is entered).
    """

    def decorator(fn: Callable) -> Callable:
        # Not every callable has ``__name__`` (e.g. ``functools.partial``).
        # Resolve a label once so that error messages, the default run
        # name and the artifact sub-directory never raise AttributeError.
        fn_name = getattr(fn, "__name__", type(fn).__name__)

        params = list(inspect.signature(fn).parameters.values())
        if not params or params[0].name != "ctx":
            raise TypeError(
                f"'{fn_name}' must declare 'ctx: HubContext' as its "
                f"first parameter to be used with @hub_run. "
                f"Rename the first parameter to 'ctx'."
            )
        # The wrapper always forwards the context positionally
        # (``fn(ctx, *args, **kwargs)``).  A keyword-only ``ctx`` or a
        # ``**ctx`` catch-all can never bind that argument, so fail here
        # instead of on every later call.
        if params[0].kind in (
            inspect.Parameter.KEYWORD_ONLY,
            inspect.Parameter.VAR_KEYWORD,
        ):
            raise TypeError(
                f"'{fn_name}' must accept 'ctx' as a positional "
                f"parameter to be used with @hub_run, because the "
                f"context is always passed positionally."
            )

        # Names of the parameters after ``ctx``; handed to
        # ``_find_parent_run_id`` together with the call's args/kwargs.
        param_names = [p.name for p in params[1:]]
        run_name = name or fn_name

        # Normalize string run types at decoration time so that e.g.
        # "compile" → RunType.COMPILE and "analyse" stays as "analyse".
        _run_type: RunType | str = (
            coerce_run_type(run_type)
            if isinstance(run_type, str)
            else run_type
        )
        # Plain-string form of the run type, as written to ``run.yaml``.
        _run_type_str = (
            _run_type.value if isinstance(_run_type, RunType) else _run_type
        )

        @functools.wraps(fn)
        def wrapper(
            ctx: HubContext, *args: object, **kwargs: object
        ) -> object:
            if not ctx.is_active:
                raise RuntimeError(
                    f"'{fn_name}' was called outside an active "
                    f"HubContext. Enter the context first with 'with ctx:'."
                )

            # Per-call sub-directory: function label plus a local-time
            # timestamp with one-second resolution.
            # NOTE(review): two calls of the same function within one
            # second produce the same name — confirm that
            # ``_setup_artifact_dirs`` tolerates this.
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            subdir_name = f"{fn_name}_{timestamp}"
            _setup_artifact_dirs(ctx, subdir_name)
            try:
                if ctx.client is not None:
                    parent_run_id = _find_parent_run_id(
                        param_names, args, kwargs
                    )
                    with ctx.client.start_run(
                        type=_run_type,
                        name=run_name,
                        parent_run_id=parent_run_id,
                    ):
                        _populate_run_log(ctx)
                        _log_context_tags(ctx)
                        try:
                            result = fn(ctx, *args, **kwargs)
                            if not isinstance(
                                result, (RunLog, ComponentOutput)
                            ):
                                # stacklevel=2 attributes the warning to
                                # the code that called the wrapper.
                                warnings.warn(
                                    f"'{fn_name}' did not return a RunLog or "
                                    f"ComponentOutput. Downstream @hub_run calls "
                                    f"will not be linked as child runs. "
                                    f"Return 'ctx.run_log' to enable lineage.",
                                    stacklevel=2,
                                )
                        except Exception:
                            # Record the failure locally, then let the
                            # exception propagate unchanged.
                            _save_run_yaml(
                                ctx,
                                subdir_name,
                                run_type=_run_type_str,
                                component_type=fn_name,
                                status="FAILED",
                                parent_run_id=parent_run_id,
                            )
                            raise
                        _download_remote_artifacts(
                            ctx, subdir_name, ctx.client
                        )
                        _update_downloaded_artifact_dirs(ctx)
                        _save_run_yaml(
                            ctx,
                            subdir_name,
                            run_type=_run_type_str,
                            component_type=fn_name,
                            status="FINISHED",
                            parent_run_id=parent_run_id,
                        )
                        return result
                else:
                    # No tracking client: run the function directly; no
                    # run is started and no run.yaml is saved.
                    result = fn(ctx, *args, **kwargs)
                    _download_remote_artifacts(ctx, subdir_name)
                    return result
            finally:
                # Runs on success and on failure alike.
                _cleanup_artifact_dirs(ctx)

        return wrapper

    return decorator