# Copyright (C) 2026 Embedl AB
"""SSH-based command runner for remote device execution.
This module provides an SSH command runner implementation using asyncssh
for executing commands and transferring files on remote devices.
"""
import asyncio
import codecs
import io
import os
import warnings
from collections.abc import Generator, Sequence
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import cast

import asyncssh
from rich.console import Console
from rich.text import Text
from typing_extensions import Self

from embedl_hub._internal.core.device.abc import (
    CommandResult,
    CommandRunner,
    Connectable,
)
from embedl_hub._internal.core.types import LocalPath, RemotePath
def _ssh_agent_available() -> bool:
    """Report whether ``SSH_AUTH_SOCK`` holds a non-empty value.

    A non-empty value is taken as evidence that an SSH agent socket is
    available; the socket itself is not probed.
    """
    return os.environ.get("SSH_AUTH_SOCK", "") != ""
[docs]
class SSHConnectionError(Exception):
    """Raised when an SSH connection is missing or cannot be established.

    This covers two situations:

    * an operation that needs a live connection is attempted outside of
      ``SSHCommandRunner.connect()``;
    * establishing the connection itself fails (missing private key file,
      unreachable host, unloadable SSH keys, or rejected authentication).
    """
[docs]
@dataclass
class SSHConfig:
    """Configuration for SSH connections.

    :param host: The hostname or IP address of the remote device.
    :param port: The SSH port (default: 22).
    :param username: The username for authentication.
    :param password: Optional password for authentication. Using password
        authentication is discouraged for security reasons and will raise a
        warning. Prefer key-based authentication instead. The password is
        excluded from ``repr()`` so it does not end up in logs or tracebacks.
    :param private_key_path: Path to a specific private key file. Overrides
        all other authentication methods — the SSH agent and default key
        discovery are both skipped. Only needed in unusual setups where a
        non-standard key path must be forced.
    :param known_hosts: Path to a known_hosts file, or None (the default) to
        disable host key checking. With checking disabled the remote host's
        identity is not verified, so set this whenever the device's host key
        is known.
    """

    host: str
    username: str
    port: int = 22
    # repr=False: the dataclass-generated repr would otherwise print the
    # plaintext secret wherever the config is logged or shown in a traceback.
    password: str | None = field(default=None, repr=False)
    private_key_path: LocalPath | None = None
    known_hosts: LocalPath | None = None
_console = Console(stderr=True)
"""Shared Rich console for SSH output (writes to stderr so it doesn't
mix with real program stdout when piped)."""

# Rich style applied to every box-drawing glyph (top/bottom edges and the
# per-line prefix).
_BORDER_STYLE = "dim cyan"
# Leading indentation placed before every box glyph.
_INDENT = " "
# Prefix printed before each line inside a box. It reuses _BORDER_STYLE so
# the left edge always matches the top and bottom edges.
_PREFIX = Text(f"{_INDENT}│ ", style=_BORDER_STYLE)
def _box_top(console: Console, title: str | None = None) -> None:
    """Draw the opening edge of an output box, optionally with a centred title.

    Renders ``<indent>┌───…`` across the console width, or
    ``<indent>┌─── title ───…`` when *title* is non-empty. At least one
    dash is kept on each side of a title even on very narrow consoles.
    """
    # Columns left after the indent and the corner glyph.
    avail = console.width - len(_INDENT) - 1
    if not title:
        edge = "─" * avail
    else:
        caption = f" {title} "
        lhs = max((avail - len(caption)) // 2, 1)
        rhs = max(avail - lhs - len(caption), 1)
        edge = "─" * lhs + caption + "─" * rhs
    console.print(Text(f"{_INDENT}┌{edge}", style=_BORDER_STYLE), highlight=False)
def _box_bottom(console: Console) -> None:
    """Draw the closing edge of an output box: ``<indent>└───…``."""
    rule = "─" * (console.width - len(_INDENT) - 1)
    closing = Text(_INDENT + "└" + rule, style=_BORDER_STYLE)
    console.print(closing, highlight=False)
class _TeeWriter(io.RawIOBase):
    """Captures bytes *and* prints each line to the console with a prefix.

    asyncssh closes file objects passed as stdout/stderr, so this wrapper
    overrides :meth:`close` to do nothing (there is no underlying stream to
    close). It also accumulates every written byte so the full output can be
    retrieved afterwards via :attr:`captured`.

    Console echo uses an incremental UTF-8 decoder. A multi-byte character
    split across two ``write`` calls is therefore reassembled rather than
    shown as replacement characters.
    """

    def __init__(self, console: Console, style: str = "") -> None:
        """
        :param console: Rich console that complete lines are printed to.
        :param style: Rich style applied to the text of every echoed line.
        """
        super().__init__()
        self._console = console
        self._style = style
        self._parts: list[bytes] = []
        self._line_buf = ""
        # Keeps an incomplete multi-byte sequence between writes, so chunk
        # boundaries chosen by the SSH channel never corrupt echoed text.
        self._decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")

    def writable(self) -> bool:
        """Report that this stream accepts writes (RawIOBase defaults to False)."""
        return True

    def _print_line(self, line: str) -> None:
        """Echo one line to the console behind the box prefix."""
        self._console.print(
            _PREFIX + Text(line, style=self._style),
            highlight=False,
        )

    def write(self, b: bytes | bytearray) -> int:  # type: ignore[override]
        """Record *b* and echo every line it completes.

        :return: The number of bytes consumed (always all of them).
        """
        data = bytes(b)
        self._parts.append(data)
        text = self._line_buf + self._decoder.decode(data)
        # The last element is the incomplete trailing line (possibly empty);
        # keep it buffered until its newline (or flush_remaining) arrives.
        *complete, self._line_buf = text.split("\n")
        for line in complete:
            self._print_line(line)
        return len(data)

    def flush_remaining(self) -> None:
        """Print any remaining partial line in the buffer."""
        # final=True turns a truncated trailing multi-byte sequence into
        # replacement characters instead of silently dropping it.
        self._line_buf += self._decoder.decode(b"", final=True)
        if self._line_buf:
            self._print_line(self._line_buf)
            self._line_buf = ""

    @property
    def captured(self) -> str:
        """Return all captured output decoded as UTF-8."""
        return b"".join(self._parts).decode("utf-8", errors="replace")

    def close(self) -> None:
        # Intentionally a no-op: asyncssh calls close() on redirect targets,
        # and there is no underlying stream that needs releasing.
        pass
[docs]
class SSHCommandRunner(CommandRunner, Connectable):
    """SSH-based command runner for remote device execution.

    Implements the :class:`CommandRunner` and :class:`Connectable` protocols
    using asyncssh for SSH connections. Call :meth:`connect` to obtain a
    context manager that establishes the connection.

    Every :meth:`connect` context owns a private asyncio event loop. The
    synchronous :meth:`run`, :meth:`get` and :meth:`put` methods drive their
    coroutines to completion on that loop.

    Example::

        runner = SSHCommandRunner(config)
        with runner.connect():
            runner.run(["ls", "-la"])
    """

    device_name: str | None
    """Human-readable device name, set by the owning
    :class:`~embedl_hub._internal.core.device.abc.Device`."""

    # Shared by _ensure_connected and _ensure_loop so both report a missing
    # connection with the same message.
    _NOT_CONNECTED_MSG = (
        "No SSH connection established. "
        "Call runner.connect() first: `with runner.connect(): ...`"
    )

    def __init__(self, config: SSHConfig) -> None:
        """Initialize the SSH command runner.

        Emits a :class:`UserWarning` when *config* carries a password.

        :param config: SSH connection configuration.
        """
        if config.password is not None:
            warnings.warn(
                "Using password authentication is discouraged for security reasons. "
                "Consider using key-based authentication instead.",
                UserWarning,
                stacklevel=2,
            )
        self._config = config
        self._connection: asyncssh.SSHClientConnection | None = None
        self._loop: asyncio.AbstractEventLoop | None = None
        self.device_name: str | None = None

    @property
    def is_active(self) -> bool:
        """Whether an SSH connection is currently established."""
        return self._connection is not None

    def _private_key_file(self) -> str | None:
        """Return the configured private key path with ``~`` expanded, or None.

        The same expanded path is used for the existence check in
        :meth:`connect` and for the key actually handed to asyncssh, so the
        two can never disagree.
        """
        if self._config.private_key_path is None:
            return None
        return os.path.expanduser(str(self._config.private_key_path))

    def _target_description(self) -> str:
        """Describe the connection target for error messages."""
        endpoint = f"{self._config.username}@{self._config.host}:{self._config.port}"
        if self.device_name:
            return f"device '{self.device_name}' ({endpoint})"
        return endpoint

    async def _connect(self) -> asyncssh.SSHClientConnection:
        """Establish an SSH connection to the remote device.

        :return: An active SSH client connection.
        """
        connect_kwargs: dict[str, object] = {
            "host": self._config.host,
            "port": self._config.port,
            "username": self._config.username,
        }
        if self._config.password is not None:
            connect_kwargs["password"] = self._config.password
        key_file = self._private_key_file()
        if key_file is not None:
            # An explicit key overrides the agent and default key discovery
            # (see SSHConfig.private_key_path).
            connect_kwargs["client_keys"] = [key_file]
        elif _ssh_agent_available():
            # Authenticate through the agent only. The empty client_keys list
            # presumably keeps asyncssh from loading default key files
            # (relevant for hardware-backed keys).
            # NOTE(review): confirm asyncssh still consults the agent when
            # client_keys is an empty list.
            connect_kwargs["agent_path"] = os.environ.get("SSH_AUTH_SOCK")
            connect_kwargs["client_keys"] = []
        # known_hosts=None disables host key verification in asyncssh.
        connect_kwargs["known_hosts"] = (
            None
            if self._config.known_hosts is None
            else str(self._config.known_hosts)
        )
        return await asyncssh.connect(**connect_kwargs)

    def _ensure_connected(self) -> asyncssh.SSHClientConnection:
        """Ensure a connection is established and return it.

        :return: The active SSH connection.
        :raises SSHConnectionError: If no connection is established.
        """
        if self._connection is None:
            raise SSHConnectionError(self._NOT_CONNECTED_MSG)
        return self._connection

    def _ensure_loop(self) -> asyncio.AbstractEventLoop:
        """Return the active event loop.

        :return: The event loop created during :meth:`connect`.
        :raises SSHConnectionError: If no connection (and thus no loop) exists.
        """
        if self._loop is None:
            raise SSHConnectionError(self._NOT_CONNECTED_MSG)
        return self._loop

    def _close_loop(self, loop: asyncio.AbstractEventLoop) -> None:
        """Finalize async generators, close *loop* and forget it.

        The loop is closed even if finalizing async generators raises.
        """
        try:
            loop.run_until_complete(loop.shutdown_asyncgens())
        finally:
            loop.close()
            self._loop = None

    @contextmanager
    def connect(self) -> Generator[Self, None, None]:
        """Establish an SSH connection for the duration of the context.

        A fresh event loop is created for the connection. It is closed again
        when the context exits, and also when the connection attempt fails
        for any reason.

        :yields: This runner instance with an active connection.
        :raises RuntimeError: If a connection is already active.
        :raises SSHConnectionError: If the private key file is missing, the
            host cannot be reached, the SSH keys cannot be loaded, or
            authentication is rejected.

        Other exceptions raised while connecting (for example asyncssh host
        key verification errors) propagate unchanged.
        """
        if self._connection is not None:
            raise RuntimeError(
                "SSHCommandRunner already has an active connection. "
                "Cannot open a new connection while one is active."
            )
        # Validate the private key path before allocating any resources.
        # isfile (not exists) so a directory is rejected here with a clear
        # message instead of surfacing later as an "unreachable device".
        key_file = self._private_key_file()
        if key_file is not None and not os.path.isfile(key_file):
            raise SSHConnectionError(
                f"Private key file not found: '{self._config.private_key_path}'. "
                "Check the path passed to SSHConfig(private_key_path=...)."
            )
        loop = asyncio.new_event_loop()
        self._loop = loop
        try:
            self._connection = loop.run_until_complete(self._connect())
        except OSError as exc:
            self._close_loop(loop)
            raise SSHConnectionError(
                f"Could not connect to {self._target_description()}. "
                f"Check that the device is reachable and the SSH server is running."
            ) from exc
        except asyncssh.KeyImportError as exc:
            self._close_loop(loop)
            raise SSHConnectionError(
                "Could not load SSH private keys. "
                "If you use a hardware security key (YubiKey or similar), "
                "make sure your SSH agent is running and SSH_AUTH_SOCK is set "
                "in the environment where you run this script."
            ) from exc
        except asyncssh.PermissionDenied as exc:
            self._close_loop(loop)
            raise SSHConnectionError(
                f"SSH authentication failed for {self._target_description()}. "
                "Check that the username is correct and the key is authorised "
                "on the device (see ~/.ssh/authorized_keys)."
            ) from exc
        except BaseException:
            # Any other failure (host key mismatch, Ctrl-C, cancellation, ...)
            # must not leak the loop or leave the runner half-initialized.
            self._close_loop(loop)
            raise
        try:
            yield self
        finally:
            conn = self._connection
            self._connection = None
            try:
                if conn is not None:
                    conn.close()
                    loop.run_until_complete(conn.wait_closed())
            finally:
                # Always release the loop, even if the connection failed to
                # shut down cleanly.
                self._close_loop(loop)

    def _print_command_header(self, cmd_str: str, *, boxed: bool) -> None:
        """Announce *cmd_str* on the console.

        :param cmd_str: The full command line that will be executed.
        :param boxed: If True, open a titled box whose first line is the
            command; otherwise print a compact one-line ``└ $ cmd`` form.
        """
        name = self.device_name
        addr = f"{self._config.username}@{self._config.host}"
        _console.print()
        header = (
            f"Running command on '{name}' ({addr}):"
            if name is not None
            else f"Running command on {addr}:"
        )
        _console.print(header, highlight=False)
        if boxed:
            _box_top(_console, title=name or addr)
            lead = _PREFIX
        else:
            lead = Text(f"{_INDENT}└ ", style=_BORDER_STYLE)
        _console.print(
            lead + Text("$ ", style="bold green") + Text(cmd_str, style="bold"),
            highlight=False,
        )

    async def _run(
        self,
        command: Sequence[str],
        cwd: RemotePath | None = None,
        *,
        hide: bool = False,
    ) -> CommandResult:
        """Coroutine behind :meth:`run`; see it for the parameter contract."""
        conn = self._ensure_connected()
        # NOTE(review): neither cwd nor the command arguments are
        # shell-quoted, so the remote shell interprets them (whitespace splits
        # words; $, ~ and globs expand). Quoting here would break callers that
        # use shell syntax, so callers must quote values that contain spaces
        # or come from untrusted input.
        cmd_str = " ".join(map(str, command))
        if cwd is not None:
            cmd_str = f"cd {cwd} && {cmd_str}"

        if hide:
            # Compact one-liner: no box, no streaming output.
            self._print_command_header(cmd_str, boxed=False)
            result = await conn.run(cmd_str, check=True)
            return cast(CommandResult, result)

        # Box with streamed output.
        self._print_command_header(cmd_str, boxed=True)
        stdout_tee = _TeeWriter(_console)
        stderr_tee = _TeeWriter(_console, style="red")
        try:
            result = await conn.run(
                cmd_str,
                check=True,
                stdout=stdout_tee,
                stderr=stderr_tee,
            )
        except asyncssh.ProcessError as exc:
            # asyncssh.ProcessError.stdout/stderr may be empty when output
            # was streamed to our TeeWriter. Patch the captured text back
            # onto the exception so callers (e.g. run_remote_command) can
            # surface it in their error messages.
            exc.stdout = stdout_tee.captured
            exc.stderr = stderr_tee.captured
            raise
        finally:
            # Close the box on every exit path, including cancellation.
            stdout_tee.flush_remaining()
            stderr_tee.flush_remaining()
            _box_bottom(_console)
        result.stdout = stdout_tee.captured
        result.stderr = stderr_tee.captured
        return cast(CommandResult, result)

    async def _get(self, source: RemotePath, destination: LocalPath) -> None:
        """Coroutine behind :meth:`get`: copy *source* down via SCP."""
        conn = self._ensure_connected()
        await asyncssh.scp((conn, str(source)), str(destination))

    async def _put(self, source: LocalPath, destination: RemotePath) -> None:
        """Coroutine behind :meth:`put`: copy *source* up via SCP."""
        conn = self._ensure_connected()
        await asyncssh.scp(str(source), (conn, str(destination)))

    def run(
        self,
        command: Sequence[str],
        cwd: RemotePath | None = None,
        *,
        hide: bool = False,
    ) -> CommandResult:
        """Execute a command on the remote device.

        :param command: The command and arguments to execute. They are joined
            with spaces and passed to the remote shell without quoting.
        :param cwd: Optional working directory for command execution.
        :param hide: If ``True``, print only a compact one-liner instead
            of the full bordered box with streamed output. Use this for
            short utility commands (e.g. ``mkdir``, ``rm``) that don't
            produce meaningful output.
        :return: The completed process result.
        :raises SSHConnectionError: If no connection is established.
        :raises asyncssh.ProcessError: If the command fails.
        """
        return self._ensure_loop().run_until_complete(
            self._run(command, cwd, hide=hide)
        )

    def get(self, source: RemotePath, destination: LocalPath) -> None:
        """Download a file from the remote device.

        :param source: The remote file path to download.
        :param destination: The local destination path.
        :raises SSHConnectionError: If no connection is established.
        """
        self._ensure_loop().run_until_complete(self._get(source, destination))

    def put(self, source: LocalPath, destination: RemotePath) -> None:
        """Upload a file to the remote device.

        :param source: The local file path to upload.
        :param destination: The remote destination path.
        :raises SSHConnectionError: If no connection is established.
        """
        self._ensure_loop().run_until_complete(self._put(source, destination))