This commit is contained in:
144
utils/ssh_helper.py
Normal file
144
utils/ssh_helper.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""
|
||||
Shared SSH command helpers for AutoHeal and AiderHeal.
|
||||
|
||||
The service layer owns allowlists and action semantics; this module only
|
||||
builds and runs the SSH command consistently.
|
||||
"""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, List, Optional, Sequence, Union
|
||||
|
||||
|
||||
RemoteCommand = Union[str, Sequence[Any]]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SshExecResult:
|
||||
returncode: int
|
||||
stdout: str
|
||||
stderr: str
|
||||
argv: List[str]
|
||||
|
||||
@property
|
||||
def success(self) -> bool:
|
||||
return self.returncode == 0
|
||||
|
||||
|
||||
def ensure_ssh_key_permissions(key_path: Optional[str], logger: Optional[Any] = None) -> None:
|
||||
if not key_path:
|
||||
return
|
||||
safe_key = os.path.expanduser(key_path)
|
||||
if not os.path.exists(safe_key):
|
||||
if logger:
|
||||
logger.warning("SSH key not found: %s", safe_key)
|
||||
return
|
||||
try:
|
||||
os.chmod(safe_key, 0o600)
|
||||
except Exception as exc:
|
||||
if logger:
|
||||
logger.warning("Failed to secure SSH key: %s", exc)
|
||||
|
||||
|
||||
def build_ssh_command(
|
||||
*,
|
||||
host: str,
|
||||
user: str,
|
||||
command: RemoteCommand,
|
||||
port: int = 22,
|
||||
key_path: Optional[str] = None,
|
||||
connect_timeout: int = 10,
|
||||
jump_host: Optional[str] = None,
|
||||
jump_user: Optional[str] = None,
|
||||
strict_host_key_checking: str = "no",
|
||||
batch_mode: bool = False,
|
||||
server_alive_interval: Optional[int] = None,
|
||||
server_alive_count_max: Optional[int] = None,
|
||||
) -> List[str]:
|
||||
argv = [
|
||||
"ssh",
|
||||
"-p",
|
||||
str(port),
|
||||
]
|
||||
if key_path:
|
||||
argv.extend(["-i", os.path.expanduser(key_path)])
|
||||
argv.extend(["-o", f"StrictHostKeyChecking={strict_host_key_checking}"])
|
||||
if batch_mode:
|
||||
argv.extend(["-o", "BatchMode=yes"])
|
||||
argv.extend(["-o", f"ConnectTimeout={connect_timeout}"])
|
||||
if server_alive_interval is not None:
|
||||
argv.extend(["-o", f"ServerAliveInterval={server_alive_interval}"])
|
||||
if server_alive_count_max is not None:
|
||||
argv.extend(["-o", f"ServerAliveCountMax={server_alive_count_max}"])
|
||||
if jump_host and jump_user:
|
||||
argv.extend(["-J", f"{jump_user}@{jump_host}"])
|
||||
argv.append(f"{user}@{host}")
|
||||
|
||||
if isinstance(command, str):
|
||||
argv.append(command)
|
||||
else:
|
||||
argv.append("--")
|
||||
argv.extend(str(part) for part in command)
|
||||
return argv
|
||||
|
||||
|
||||
def run_ssh_command(
|
||||
*,
|
||||
host: str,
|
||||
user: str,
|
||||
command: RemoteCommand,
|
||||
port: int = 22,
|
||||
key_path: Optional[str] = None,
|
||||
connect_timeout: int = 10,
|
||||
command_timeout: int = 60,
|
||||
jump_host: Optional[str] = None,
|
||||
jump_user: Optional[str] = None,
|
||||
strict_host_key_checking: str = "no",
|
||||
batch_mode: bool = False,
|
||||
server_alive_interval: Optional[int] = None,
|
||||
server_alive_count_max: Optional[int] = None,
|
||||
cwd: Optional[str] = None,
|
||||
logger: Optional[Any] = None,
|
||||
) -> SshExecResult:
|
||||
ensure_ssh_key_permissions(key_path, logger=logger)
|
||||
argv = build_ssh_command(
|
||||
host=host,
|
||||
user=user,
|
||||
command=command,
|
||||
port=port,
|
||||
key_path=key_path,
|
||||
connect_timeout=connect_timeout,
|
||||
jump_host=jump_host,
|
||||
jump_user=jump_user,
|
||||
strict_host_key_checking=strict_host_key_checking,
|
||||
batch_mode=batch_mode,
|
||||
server_alive_interval=server_alive_interval,
|
||||
server_alive_count_max=server_alive_count_max,
|
||||
)
|
||||
try:
|
||||
result = subprocess.run(
|
||||
argv,
|
||||
shell=False,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
cwd=cwd,
|
||||
timeout=command_timeout,
|
||||
)
|
||||
return SshExecResult(
|
||||
returncode=result.returncode,
|
||||
stdout=result.stdout.strip(),
|
||||
stderr=result.stderr.strip(),
|
||||
argv=argv,
|
||||
)
|
||||
except subprocess.TimeoutExpired:
|
||||
return SshExecResult(
|
||||
returncode=-1,
|
||||
stdout="",
|
||||
stderr=f"SSH timeout after {command_timeout}s",
|
||||
argv=argv,
|
||||
)
|
||||
except Exception as exc:
|
||||
if logger:
|
||||
logger.warning("SSH exec error: %s", exc)
|
||||
return SshExecResult(returncode=-1, stdout="", stderr=str(exc), argv=argv)
|
||||
Reference in New Issue
Block a user