- Command intake (reboot/shutdown) on infoscreen/{uuid}/commands with ack lifecycle
- MQTT_USER/MQTT_PASSWORD_BROKER split from identity vars; configure_mqtt_security() updated
- infoscreen-simclient.service: Type=notify, WatchdogSec=60, Restart=on-failure
- infoscreen-notify-failure@.service + script: retained MQTT alert when systemd gives up (Gap 3)
- _sd_notify() watchdog keepalive in simclient main loop (Gap 1)
- broker_connection block in health payload: reconnect_count, last_disconnect_at (Gap 2)
- COMMAND_MOCK_REBOOT_IMMEDIATE_COMPLETE canary flag with safety guard
- SERVER_TEAM_ACTIONS.md: server-side integration action items
- Docs: README, CHANGELOG, src/README, copilot-instructions updated
- 43 tests passing
288 lines
10 KiB
Python
288 lines
10 KiB
Python
"""
Unit tests for reboot/shutdown command intake primitives.

Run from project root (venv activated):
    python -m pytest tests/test_command_intake.py -v
"""
|
|
|
|
import os
|
|
import sys
|
|
import json
|
|
import tempfile
|
|
import unittest
|
|
from datetime import datetime, timezone, timedelta
|
|
from unittest.mock import patch
|
|
|
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))
|
|
|
|
from simclient import ( # noqa: E402
|
|
NIL_COMMAND_ID,
|
|
command_requires_recovery_completion,
|
|
command_mock_reboot_immediate_complete_enabled,
|
|
configure_mqtt_security,
|
|
mqtt,
|
|
validate_command_payload,
|
|
publish_command_ack,
|
|
_prune_processed_commands,
|
|
load_processed_commands,
|
|
persist_processed_commands,
|
|
)
|
|
|
|
|
|
class FakePublishResult:
    """Minimal stand-in for the object returned by an MQTT client's ``publish``.

    It exposes only the ``rc`` attribute, which is what the fake clients in
    this file hand back from their ``publish`` methods.
    """

    def __init__(self, rc):
        # Return code reported for the simulated publish call.
        self.rc = rc

    def __repr__(self):
        # Readable representation so failing assertions show the return code.
        return f"{type(self).__name__}(rc={self.rc!r})"
|
|
|
|
|
|
class FakeMqttClient:
    """Fake MQTT client that records ``publish`` calls and returns a fixed code.

    Attributes:
        rc: Return code carried by the result of every ``publish`` call.
        calls: One dict per ``publish`` call, in call order, with the keys
            ``topic``, ``payload``, ``qos`` and ``retain``.
    """

    def __init__(self, rc=0):
        self.rc = rc
        self.calls = []

    def publish(self, topic, payload, qos=0, retain=False):
        """Record the call arguments and return a result carrying ``self.rc``."""
        self.calls.append({
            "topic": topic,
            "payload": payload,
            "qos": qos,
            "retain": retain,
        })
        return FakePublishResult(self.rc)
|
|
|
|
|
|
class SequencedMqttClient:
    """Fake MQTT client whose ``publish`` return codes follow a given sequence.

    Each ``publish`` call consumes the next code from ``rc_sequence``; once the
    sequence is exhausted every further call reports ``0``.

    Attributes:
        calls: One dict per ``publish`` call, in call order, with the keys
            ``topic``, ``payload``, ``qos``, ``retain`` and the ``rc`` that
            was returned for that call.
    """

    def __init__(self, rc_sequence):
        # Snapshot the codes so later changes to the caller's sequence have no
        # effect, then consume them one by one through an iterator.
        self._rc_iter = iter(list(rc_sequence))
        self.calls = []

    def publish(self, topic, payload, qos=0, retain=False):
        """Record the call and return a result with the next queued code."""
        rc = next(self._rc_iter, 0)  # fall back to 0 when the sequence ran out
        self.calls.append({
            "topic": topic,
            "payload": payload,
            "qos": qos,
            "retain": retain,
            "rc": rc,
        })
        return FakePublishResult(rc)
|
|
|
|
|
|
class FakeSecurityClient:
    """Fake client that captures the credential and TLS settings applied to it.

    Every attribute starts as ``None`` and is overwritten by the matching
    setter, so a test can tell which setters were invoked and with what.
    """

    def __init__(self):
        self.username = None
        self.password = None
        self.tls_kwargs = None    # keyword arguments of the last tls_set() call
        self.tls_insecure = None  # argument of the last tls_insecure_set() call

    def username_pw_set(self, username, password=None):
        """Store the given username and password."""
        self.username = username
        self.password = password

    def tls_set(self, **kwargs):
        """Store the keyword arguments exactly as received."""
        self.tls_kwargs = kwargs

    def tls_insecure_set(self, enabled):
        """Store the requested insecure-TLS flag."""
        self.tls_insecure = enabled
|
|
|
|
|
|
def _valid_payload(seconds_valid=240):
|
|
now = datetime.now(timezone.utc)
|
|
exp = now + timedelta(seconds=seconds_valid)
|
|
return {
|
|
"schema_version": "1.0",
|
|
"command_id": "5d1f8b4b-7e85-44fb-8f38-3f5d5da5e2e4",
|
|
"client_uuid": "9b8d1856-ff34-4864-a726-12de072d0f77",
|
|
"action": "reboot_host",
|
|
"issued_at": now.strftime("%Y-%m-%dT%H:%M:%SZ"),
|
|
"expires_at": exp.strftime("%Y-%m-%dT%H:%M:%SZ"),
|
|
"requested_by": 1,
|
|
"reason": "operator_request",
|
|
}
|
|
|
|
|
|
class TestValidateCommandPayload(unittest.TestCase):
    """Exercise ``validate_command_payload`` with one valid and several invalid payloads.

    Every test unpacks the result as ``(ok, normalized, code, msg)``.
    """

    def test_accepts_valid_payload(self):
        """A well-formed payload is accepted without an error code or message."""
        payload = _valid_payload()
        ok, normalized, code, msg = validate_command_payload(payload, payload["client_uuid"])
        self.assertTrue(ok)
        self.assertIsNone(code)
        self.assertIsNone(msg)
        self.assertEqual(normalized["action"], "reboot_host")

    def test_rejects_extra_fields(self):
        """An unknown top-level key is reported as ``invalid_schema``."""
        payload = _valid_payload()
        payload["extra"] = "x"
        ok, _, code, msg = validate_command_payload(payload, payload["client_uuid"])
        self.assertFalse(ok)
        self.assertEqual(code, "invalid_schema")
        self.assertIn("unexpected fields", msg)

    def test_rejects_stale_command(self):
        """A command issued and expired hours ago is reported as ``stale_command``."""
        payload = _valid_payload()
        # One clock reading keeps issued_at and expires_at exactly an hour apart.
        now = datetime.now(timezone.utc)
        old_issued = now - timedelta(hours=3)
        old_expires = now - timedelta(hours=2)
        payload["issued_at"] = old_issued.strftime("%Y-%m-%dT%H:%M:%SZ")
        payload["expires_at"] = old_expires.strftime("%Y-%m-%dT%H:%M:%SZ")
        ok, _, code, _ = validate_command_payload(payload, payload["client_uuid"])
        self.assertFalse(ok)
        self.assertEqual(code, "stale_command")

    def test_rejects_action_outside_enum(self):
        """An action the validator does not accept is reported as ``invalid_schema``."""
        payload = _valid_payload()
        payload["action"] = "restart_service"
        ok, _, code, msg = validate_command_payload(payload, payload["client_uuid"])
        self.assertFalse(ok)
        self.assertEqual(code, "invalid_schema")
        self.assertIn("action must be one of", msg)

    def test_rejects_client_uuid_mismatch(self):
        """A payload addressed to a different client UUID is reported as ``invalid_schema``."""
        payload = _valid_payload()
        ok, _, code, msg = validate_command_payload(
            payload,
            "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa",
        )
        self.assertFalse(ok)
        self.assertEqual(code, "invalid_schema")
        self.assertIn("client_uuid", msg)
|
|
|
|
|
|
class TestCommandLifecyclePolicy(unittest.TestCase):
    """Checks of the per-action lifecycle helpers imported from ``simclient``."""

    def test_reboot_requires_recovery_completion(self):
        """``reboot_host`` requires recovery completion; ``shutdown_host`` does not."""
        self.assertTrue(command_requires_recovery_completion("reboot_host"))
        self.assertFalse(command_requires_recovery_completion("shutdown_host"))

    def test_mock_reboot_immediate_completion_enabled_for_mock_helper(self):
        """With the flag on and the mock helper path configured, the shortcut is enabled."""
        with patch("simclient.COMMAND_MOCK_REBOOT_IMMEDIATE_COMPLETE", True), \
                patch("simclient.COMMAND_HELPER_PATH", "/home/pi/scripts/mock-command-helper.sh"):
            self.assertTrue(command_mock_reboot_immediate_complete_enabled("reboot_host"))

    def test_mock_reboot_immediate_completion_disabled_for_live_helper(self):
        """With the flag on but the live helper path configured, the shortcut stays disabled."""
        with patch("simclient.COMMAND_MOCK_REBOOT_IMMEDIATE_COMPLETE", True), \
                patch("simclient.COMMAND_HELPER_PATH", "/usr/local/bin/infoscreen-cmd-helper.sh"):
            self.assertFalse(command_mock_reboot_immediate_complete_enabled("reboot_host"))
|
|
|
|
|
|
class TestMqttSecurityConfiguration(unittest.TestCase):
    """Checks of ``configure_mqtt_security`` using a fake client and patched settings."""

    def test_configure_username_password(self):
        """With empty broker credentials, MQTT_USERNAME / MQTT_PASSWORD reach the client."""
        fake_client = FakeSecurityClient()
        with patch("simclient.MQTT_USER", ""), \
                patch("simclient.MQTT_PASSWORD_BROKER", ""), \
                patch("simclient.MQTT_USERNAME", "client-user"), \
                patch("simclient.MQTT_PASSWORD", "client-pass"), \
                patch("simclient.MQTT_TLS_ENABLED", False):
            configured = configure_mqtt_security(fake_client)

        self.assertEqual(fake_client.username, "client-user")
        self.assertEqual(fake_client.password, "client-pass")
        self.assertFalse(configured["tls"])

    def test_configure_tls(self):
        """With TLS enabled, the certificate paths and insecure flag reach the client."""
        fake_client = FakeSecurityClient()
        with patch("simclient.MQTT_USER", ""), \
                patch("simclient.MQTT_PASSWORD_BROKER", ""), \
                patch("simclient.MQTT_USERNAME", ""), \
                patch("simclient.MQTT_PASSWORD", ""), \
                patch("simclient.MQTT_TLS_ENABLED", True), \
                patch("simclient.MQTT_TLS_CA_CERT", "/tmp/ca.pem"), \
                patch("simclient.MQTT_TLS_CERT", "/tmp/client.pem"), \
                patch("simclient.MQTT_TLS_KEY", "/tmp/client.key"), \
                patch("simclient.MQTT_TLS_INSECURE", True):
            configured = configure_mqtt_security(fake_client)

        self.assertTrue(configured["tls"])
        self.assertEqual(fake_client.tls_kwargs["ca_certs"], "/tmp/ca.pem")
        self.assertEqual(fake_client.tls_kwargs["certfile"], "/tmp/client.pem")
        self.assertEqual(fake_client.tls_kwargs["keyfile"], "/tmp/client.key")
        self.assertTrue(fake_client.tls_insecure)
|
|
|
|
|
|
class TestAckPublish(unittest.TestCase):
    """Checks of ``publish_command_ack``: payload content and retry behaviour."""

    def test_failed_ack_forces_non_null_error_fields(self):
        """A ``failed`` ack sent with ``None`` error fields still carries non-empty strings."""
        fake_client = FakeMqttClient(rc=0)
        ok = publish_command_ack(
            fake_client,
            "9b8d1856-ff34-4864-a726-12de072d0f77",
            NIL_COMMAND_ID,
            "failed",
            error_code=None,
            error_message=None,
        )
        self.assertTrue(ok)
        self.assertEqual(len(fake_client.calls), 2)
        payload = json.loads(fake_client.calls[0]["payload"])
        self.assertEqual(payload["status"], "failed")
        # Separate type and non-empty checks so a failure names the exact problem.
        for field in ("error_code", "error_message"):
            self.assertIsInstance(payload[field], str, field)
            self.assertTrue(payload[field], f"{field} must not be empty")

    def test_retry_on_broker_disconnect_then_success(self):
        """Two failed publishes are followed by one sleep and a successful second round."""
        # First loop (2 topics): NO_CONN, NO_CONN. Second loop: success, success.
        fake_client = SequencedMqttClient([
            mqtt.MQTT_ERR_NO_CONN,
            mqtt.MQTT_ERR_NO_CONN,
            mqtt.MQTT_ERR_SUCCESS,
            mqtt.MQTT_ERR_SUCCESS,
        ])
        future_expiry = (datetime.now(timezone.utc) + timedelta(seconds=30)).strftime("%Y-%m-%dT%H:%M:%SZ")

        # Patch sleep so the retry back-off does not slow the test down.
        with patch("simclient.time.sleep", return_value=None) as sleep_mock:
            ok = publish_command_ack(
                fake_client,
                "9b8d1856-ff34-4864-a726-12de072d0f77",
                "5d1f8b4b-7e85-44fb-8f38-3f5d5da5e2e4",
                "accepted",
                expires_at=future_expiry,
            )

        self.assertTrue(ok)
        self.assertEqual(len(fake_client.calls), 4)
        sleep_mock.assert_called_once()

    def test_stop_retry_when_expired(self):
        """With an ``expires_at`` in the past, a failed round is not retried."""
        fake_client = SequencedMqttClient([
            mqtt.MQTT_ERR_NO_CONN,
            mqtt.MQTT_ERR_NO_CONN,
        ])
        past_expiry = (datetime.now(timezone.utc) - timedelta(seconds=30)).strftime("%Y-%m-%dT%H:%M:%SZ")

        with patch("simclient.time.sleep", return_value=None) as sleep_mock:
            ok = publish_command_ack(
                fake_client,
                "9b8d1856-ff34-4864-a726-12de072d0f77",
                "5d1f8b4b-7e85-44fb-8f38-3f5d5da5e2e4",
                "accepted",
                expires_at=past_expiry,
            )

        self.assertFalse(ok)
        self.assertEqual(len(fake_client.calls), 2)
        sleep_mock.assert_not_called()
|
|
|
|
|
|
class TestProcessedCommandsState(unittest.TestCase):
    """Checks of pruning and persistence of the processed-commands state."""

    def test_prune_keeps_recent_only(self):
        """An entry processed now is kept; one processed 30 hours ago is dropped."""
        # One clock reading so both timestamps are relative to the same instant.
        now = datetime.now(timezone.utc)
        recent = now.strftime("%Y-%m-%dT%H:%M:%SZ")
        old = (now - timedelta(hours=30)).strftime("%Y-%m-%dT%H:%M:%SZ")
        commands = {
            "a": {"processed_at": recent, "status": "completed"},
            "b": {"processed_at": old, "status": "completed"},
        }
        pruned = _prune_processed_commands(commands)
        self.assertIn("a", pruned)
        self.assertNotIn("b", pruned)

    def test_load_and_persist_round_trip(self):
        """An entry written by ``persist_processed_commands`` is returned by ``load_processed_commands``."""
        with tempfile.TemporaryDirectory() as tmpdir:
            state_file = os.path.join(tmpdir, "processed_commands.json")
            # Redirect the state file into the temp dir so no real state is touched.
            with patch("simclient.PROCESSED_COMMANDS_FILE", state_file):
                persist_processed_commands({
                    "x": {
                        "status": "completed",
                        "processed_at": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"),
                    }
                })
                loaded = load_processed_commands()
                self.assertIn("x", loaded)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the test cases with unittest's runner when executed as a script.
    unittest.main()
|