From 693eb1c5428c6484eca1c74e3b860c6f4530459e Mon Sep 17 00:00:00 2001 From: Jeff Culverhouse Date: Sun, 16 Nov 2025 05:29:23 -0500 Subject: [PATCH] refactor: use new BaseMqttMixin from mqtt_helper --- src/amcrest2mqtt/app.py | 2 +- src/amcrest2mqtt/mixins/mqtt.py | 146 +++----------------------------- uv.lock | 2 +- 3 files changed, 15 insertions(+), 135 deletions(-) diff --git a/src/amcrest2mqtt/app.py b/src/amcrest2mqtt/app.py index 013a1bb..1fbb806 100644 --- a/src/amcrest2mqtt/app.py +++ b/src/amcrest2mqtt/app.py @@ -7,9 +7,9 @@ import asyncio import argparse from json_logging import setup_logging, get_logger +from mqtt_helper import MqttError from .core import Amcrest2Mqtt from .mixins.helpers import ConfigError -from .mixins.mqtt import MqttError def build_parser() -> argparse.ArgumentParser: diff --git a/src/amcrest2mqtt/mixins/mqtt.py b/src/amcrest2mqtt/mixins/mqtt.py index 48244a8..66053ec 100644 --- a/src/amcrest2mqtt/mixins/mqtt.py +++ b/src/amcrest2mqtt/mixins/mqtt.py @@ -1,137 +1,24 @@ # SPDX-License-Identifier: MIT # Copyright (c) 2025 Jeff Culverhouse -import asyncio -from datetime import datetime, timedelta import json -import paho.mqtt.client as mqtt -from paho.mqtt.client import Client, MQTTMessage, ConnectFlags, DisconnectFlags -from paho.mqtt.enums import LogLevel -from paho.mqtt.properties import Properties -from paho.mqtt.packettypes import PacketTypes -from paho.mqtt.reasoncodes import ReasonCode -from paho.mqtt.enums import CallbackAPIVersion -import ssl -from typing import TYPE_CHECKING, Any, Callable, Coroutine, TypeVar +from typing import TYPE_CHECKING, Any + +from mqtt_helper import BaseMqttMixin +from paho.mqtt.client import Client, MQTTMessage if TYPE_CHECKING: from amcrest2mqtt.interface import AmcrestServiceProtocol as Amcrest2Mqtt -_T = TypeVar("_T") - - -class MqttError(ValueError): - """Raised when the connection to the MQTT server fails""" - - pass - - -class MqttMixin: - async def mqttc_create(self: Amcrest2Mqtt) -> None: - # lets use a new client_id for each connection attempt - self.client_id = self.mqtt_helper.client_id() - self.mqttc = mqtt.Client( - client_id=self.client_id, - callback_api_version=CallbackAPIVersion.VERSION2, - reconnect_on_failure=False, - protocol=mqtt.MQTTv5, - ) - - if self.mqtt_config.get("tls_enabled"): - self.mqttc.tls_set( - ca_certs=self.mqtt_config.get("tls_ca_cert"), - certfile=self.mqtt_config.get("tls_cert"), - keyfile=self.mqtt_config.get("tls_key"), - cert_reqs=ssl.CERT_REQUIRED, - tls_version=ssl.PROTOCOL_TLS, - ) - if self.mqtt_config.get("username") or self.mqtt_config.get("password"): - self.mqttc.username_pw_set( - username=self.mqtt_config.get("username", ""), - password=self.mqtt_config.get("password", ""), - ) - - self.mqttc.on_connect = self._wrap_async(self.mqtt_on_connect) - self.mqttc.on_disconnect = self._wrap_async(self.mqtt_on_disconnect) - self.mqttc.on_message = self._wrap_async(self.mqtt_on_message) - self.mqttc.on_subscribe = self._wrap_async(self.mqtt_on_subscribe) - self.mqttc.on_log = self._wrap_async(self.mqtt_on_log) - - # Define a "last will" message (LWT): - self.mqttc.will_set(self.mqtt_helper.avty_t("service"), "offline", qos=1, retain=True) - try: - host = self.mqtt_config["host"] - port = self.mqtt_config["port"] - self.logger.info(f"connecting to MQTT broker at {host}:{port} as client_id: {self.client_id}") - - props = Properties(PacketTypes.CONNECT) - props.SessionExpiryInterval = 0 - - self.mqttc.connect(host=host, port=port, keepalive=60, properties=props) - self.logger.info(f"connected to {host} MQTT broker") - - self.mqtt_connect_time = datetime.now() - self.mqttc.loop_start() - except ConnectionError as err: - self.logger.error(f"failed to connect to MQTT host {host}: {err}") - self.running = False - raise SystemExit(1) - except Exception as err: - self.logger.error(f"network problem trying to connect to MQTT host {host}: {err}") - self.running = False - raise SystemExit(1) - - def _wrap_async( - self: Amcrest2Mqtt, - coro_func: Callable[..., Coroutine[Any, Any, _T]], - ) -> Callable[..., None]: - def wrapper(*args: Any, **kwargs: Any) -> None: - self.loop.call_soon_threadsafe(lambda: asyncio.create_task(coro_func(*args, **kwargs))) - - return wrapper - - async def mqtt_on_connect( - self: Amcrest2Mqtt, client: Client, userdata: dict[str, Any], flags: ConnectFlags, reason_code: ReasonCode, properties: Properties | None - ) -> None: - # send our helper the client - self.mqtt_helper.set_client(client) - - if reason_code.value != 0: - raise MqttError(f"MQTT failed to connect ({reason_code.getName()})") - - await self.publish_service_discovery() - await self.publish_service_availability() - await self.publish_service_state() - - self.logger.debug("subscribing to topics on MQTT") - client.subscribe("homeassistant/status") - client.subscribe(f"{self.mqtt_helper.service_slug}/service/+/set") - client.subscribe(f"{self.mqtt_helper.service_slug}/service/+/command") - client.subscribe(f"{self.mqtt_helper.service_slug}/+/switch/+/set") - client.subscribe(f"{self.mqtt_helper.service_slug}/+/button/+/set") - - async def mqtt_on_disconnect( - self: Amcrest2Mqtt, client: Client, userdata: Any, flags: DisconnectFlags, reason_code: ReasonCode, properties: Properties | None - ) -> None: - # clear the client on our helper - self.mqtt_helper.clear_client() - - if reason_code.value != 0: - self.logger.error(f"Mqtt lost connection ({reason_code.getName()})") - else: - self.logger.info("closed Mqtt connection") - - if self.running and (self.mqtt_connect_time is None or datetime.now() > self.mqtt_connect_time + timedelta(seconds=10)): - await self.mqttc_create() - else: - self.logger.info("Mqtt disconnect — stopping service loop") - self.running = False - - async def mqtt_on_log(self: Amcrest2Mqtt, client: Client, userdata: Any, paho_log_level: int, msg: str) -> None: - if paho_log_level == LogLevel.MQTT_LOG_ERR: - self.logger.error(f"Mqtt logged: {msg}") - if paho_log_level == LogLevel.MQTT_LOG_WARNING: - self.logger.warning(f"Mqtt logged: {msg}") +class MqttMixin(BaseMqttMixin): + def mqtt_subscription_topics(self: Amcrest2Mqtt) -> list[str]: + return [ + "homeassistant/status", + f"{self.mqtt_helper.service_slug}/service/+/set", + f"{self.mqtt_helper.service_slug}/service/+/command", + f"{self.mqtt_helper.service_slug}/+/switch/+/set", + f"{self.mqtt_helper.service_slug}/+/button/+/set", + ] async def mqtt_on_message(self: Amcrest2Mqtt, client: Client, userdata: Any, msg: MQTTMessage) -> None: topic = msg.topic @@ -205,10 +92,3 @@ class MqttMixin: except ValueError as err: self.logger.warning(f"Ignoring malformed topic {topic}: {err}") return [] - - async def mqtt_on_subscribe( - self: Amcrest2Mqtt, client: Client, userdata: Any, mid: int, reason_code_list: list[ReasonCode], properties: Properties - ) -> None: - reason_names = [rc.getName() for rc in reason_code_list] - joined = "; ".join(reason_names) if reason_names else "none" - self.logger.debug(f"Mqtt subscribed (mid={mid}): {joined}") diff --git a/uv.lock b/uv.lock index c3392dc..5fed360 100644 --- a/uv.lock +++ b/uv.lock @@ -410,7 +410,7 @@ sdist = { url = "https://files.pythonhosted.org/packages/93/4b/979db9e44be09f71e [[package]] name = "mqtt-helper-graystorm" version = "0.1.0" -source = { git = "https://github.com/weirdtangent/mqtt-helper.git?branch=main#c4034f68f2492173ec0ff13d94eeab47d2bc7c09" } +source = { git = "https://github.com/weirdtangent/mqtt-helper.git?branch=main#e81f183a38bdd965099cacfdc1286781f5d7d8f6" } dependencies = [ { name = "logging" }, { name = "paho-mqtt" },