""" Unit tests for reboot/shutdown command intake primitives. Run from project root (venv activated): python -m pytest tests/test_command_intake.py -v """ import os import sys import json import tempfile import unittest from datetime import datetime, timezone, timedelta from unittest.mock import patch sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) from simclient import ( # noqa: E402 NIL_COMMAND_ID, command_requires_recovery_completion, command_mock_reboot_immediate_complete_enabled, configure_mqtt_security, mqtt, validate_command_payload, publish_command_ack, _prune_processed_commands, load_processed_commands, persist_processed_commands, ) class FakePublishResult: def __init__(self, rc): self.rc = rc class FakeMqttClient: def __init__(self, rc=0): self.rc = rc self.calls = [] def publish(self, topic, payload, qos=0, retain=False): self.calls.append({ "topic": topic, "payload": payload, "qos": qos, "retain": retain, }) return FakePublishResult(self.rc) class SequencedMqttClient: def __init__(self, rc_sequence): self._rc_sequence = list(rc_sequence) self.calls = [] def publish(self, topic, payload, qos=0, retain=False): rc = self._rc_sequence.pop(0) if self._rc_sequence else 0 self.calls.append({ "topic": topic, "payload": payload, "qos": qos, "retain": retain, "rc": rc, }) return FakePublishResult(rc) class FakeSecurityClient: def __init__(self): self.username = None self.password = None self.tls_kwargs = None self.tls_insecure = None def username_pw_set(self, username, password=None): self.username = username self.password = password def tls_set(self, **kwargs): self.tls_kwargs = kwargs def tls_insecure_set(self, enabled): self.tls_insecure = enabled def _valid_payload(seconds_valid=240): now = datetime.now(timezone.utc) exp = now + timedelta(seconds=seconds_valid) return { "schema_version": "1.0", "command_id": "5d1f8b4b-7e85-44fb-8f38-3f5d5da5e2e4", "client_uuid": "9b8d1856-ff34-4864-a726-12de072d0f77", "action": "reboot_host", "issued_at": now.strftime("%Y-%m-%dT%H:%M:%SZ"), "expires_at": exp.strftime("%Y-%m-%dT%H:%M:%SZ"), "requested_by": 1, "reason": "operator_request", } class TestValidateCommandPayload(unittest.TestCase): def test_accepts_valid_payload(self): payload = _valid_payload() ok, normalized, code, msg = validate_command_payload(payload, payload["client_uuid"]) self.assertTrue(ok) self.assertIsNone(code) self.assertIsNone(msg) self.assertEqual(normalized["action"], "reboot_host") def test_rejects_extra_fields(self): payload = _valid_payload() payload["extra"] = "x" ok, _, code, msg = validate_command_payload(payload, payload["client_uuid"]) self.assertFalse(ok) self.assertEqual(code, "invalid_schema") self.assertIn("unexpected fields", msg) def test_rejects_stale_command(self): payload = _valid_payload() old_issued = datetime.now(timezone.utc) - timedelta(hours=3) old_expires = datetime.now(timezone.utc) - timedelta(hours=2) payload["issued_at"] = old_issued.strftime("%Y-%m-%dT%H:%M:%SZ") payload["expires_at"] = old_expires.strftime("%Y-%m-%dT%H:%M:%SZ") ok, _, code, _ = validate_command_payload(payload, payload["client_uuid"]) self.assertFalse(ok) self.assertEqual(code, "stale_command") def test_rejects_action_outside_enum(self): payload = _valid_payload() payload["action"] = "restart_service" ok, _, code, msg = validate_command_payload(payload, payload["client_uuid"]) self.assertFalse(ok) self.assertEqual(code, "invalid_schema") self.assertIn("action must be one of", msg) def test_rejects_client_uuid_mismatch(self): payload = _valid_payload() ok, _, code, msg = validate_command_payload( payload, "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", ) self.assertFalse(ok) self.assertEqual(code, "invalid_schema") self.assertIn("client_uuid", msg) class TestCommandLifecyclePolicy(unittest.TestCase): def test_reboot_requires_recovery_completion(self): self.assertTrue(command_requires_recovery_completion("reboot_host")) self.assertFalse(command_requires_recovery_completion("shutdown_host")) def test_mock_reboot_immediate_completion_enabled_for_mock_helper(self): with patch("simclient.COMMAND_MOCK_REBOOT_IMMEDIATE_COMPLETE", True), \ patch("simclient.COMMAND_HELPER_PATH", "/home/pi/scripts/mock-command-helper.sh"): self.assertTrue(command_mock_reboot_immediate_complete_enabled("reboot_host")) def test_mock_reboot_immediate_completion_disabled_for_live_helper(self): with patch("simclient.COMMAND_MOCK_REBOOT_IMMEDIATE_COMPLETE", True), \ patch("simclient.COMMAND_HELPER_PATH", "/usr/local/bin/infoscreen-cmd-helper.sh"): self.assertFalse(command_mock_reboot_immediate_complete_enabled("reboot_host")) class TestMqttSecurityConfiguration(unittest.TestCase): def test_configure_username_password(self): fake_client = FakeSecurityClient() with patch("simclient.MQTT_USER", ""), \ patch("simclient.MQTT_PASSWORD_BROKER", ""), \ patch("simclient.MQTT_USERNAME", "client-user"), \ patch("simclient.MQTT_PASSWORD", "client-pass"), \ patch("simclient.MQTT_TLS_ENABLED", False): configured = configure_mqtt_security(fake_client) self.assertEqual(fake_client.username, "client-user") self.assertEqual(fake_client.password, "client-pass") self.assertFalse(configured["tls"]) def test_configure_tls(self): fake_client = FakeSecurityClient() with patch("simclient.MQTT_USER", ""), \ patch("simclient.MQTT_PASSWORD_BROKER", ""), \ patch("simclient.MQTT_USERNAME", ""), \ patch("simclient.MQTT_PASSWORD", ""), \ patch("simclient.MQTT_TLS_ENABLED", True), \ patch("simclient.MQTT_TLS_CA_CERT", "/tmp/ca.pem"), \ patch("simclient.MQTT_TLS_CERT", "/tmp/client.pem"), \ patch("simclient.MQTT_TLS_KEY", "/tmp/client.key"), \ patch("simclient.MQTT_TLS_INSECURE", True): configured = configure_mqtt_security(fake_client) self.assertTrue(configured["tls"]) self.assertEqual(fake_client.tls_kwargs["ca_certs"], "/tmp/ca.pem") self.assertEqual(fake_client.tls_kwargs["certfile"], "/tmp/client.pem") self.assertEqual(fake_client.tls_kwargs["keyfile"], "/tmp/client.key") self.assertTrue(fake_client.tls_insecure) class TestAckPublish(unittest.TestCase): def test_failed_ack_forces_non_null_error_fields(self): fake_client = FakeMqttClient(rc=0) ok = publish_command_ack( fake_client, "9b8d1856-ff34-4864-a726-12de072d0f77", NIL_COMMAND_ID, "failed", error_code=None, error_message=None, ) self.assertTrue(ok) self.assertEqual(len(fake_client.calls), 2) payload = json.loads(fake_client.calls[0]["payload"]) self.assertEqual(payload["status"], "failed") self.assertTrue(isinstance(payload["error_code"], str) and payload["error_code"]) self.assertTrue(isinstance(payload["error_message"], str) and payload["error_message"]) def test_retry_on_broker_disconnect_then_success(self): # First loop (2 topics): NO_CONN, NO_CONN. Second loop: success, success. fake_client = SequencedMqttClient([ mqtt.MQTT_ERR_NO_CONN, mqtt.MQTT_ERR_NO_CONN, mqtt.MQTT_ERR_SUCCESS, mqtt.MQTT_ERR_SUCCESS, ]) future_expiry = (datetime.now(timezone.utc) + timedelta(seconds=30)).strftime("%Y-%m-%dT%H:%M:%SZ") with patch("simclient.time.sleep", return_value=None) as sleep_mock: ok = publish_command_ack( fake_client, "9b8d1856-ff34-4864-a726-12de072d0f77", "5d1f8b4b-7e85-44fb-8f38-3f5d5da5e2e4", "accepted", expires_at=future_expiry, ) self.assertTrue(ok) self.assertEqual(len(fake_client.calls), 4) sleep_mock.assert_called_once() def test_stop_retry_when_expired(self): fake_client = SequencedMqttClient([ mqtt.MQTT_ERR_NO_CONN, mqtt.MQTT_ERR_NO_CONN, ]) past_expiry = (datetime.now(timezone.utc) - timedelta(seconds=30)).strftime("%Y-%m-%dT%H:%M:%SZ") with patch("simclient.time.sleep", return_value=None) as sleep_mock: ok = publish_command_ack( fake_client, "9b8d1856-ff34-4864-a726-12de072d0f77", "5d1f8b4b-7e85-44fb-8f38-3f5d5da5e2e4", "accepted", expires_at=past_expiry, ) self.assertFalse(ok) self.assertEqual(len(fake_client.calls), 2) sleep_mock.assert_not_called() class TestProcessedCommandsState(unittest.TestCase): def test_prune_keeps_recent_only(self): recent = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") old = (datetime.now(timezone.utc) - timedelta(hours=30)).strftime("%Y-%m-%dT%H:%M:%SZ") commands = { "a": {"processed_at": recent, "status": "completed"}, "b": {"processed_at": old, "status": "completed"}, } pruned = _prune_processed_commands(commands) self.assertIn("a", pruned) self.assertNotIn("b", pruned) def test_load_and_persist_round_trip(self): with tempfile.TemporaryDirectory() as tmpdir: state_file = os.path.join(tmpdir, "processed_commands.json") with patch("simclient.PROCESSED_COMMANDS_FILE", state_file): persist_processed_commands({ "x": { "status": "completed", "processed_at": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"), } }) loaded = load_processed_commands() self.assertIn("x", loaded) if __name__ == "__main__": unittest.main()