refactor: use new BaseMqttMixin from mqtt_helper

pull/106/head
Jeff Culverhouse 3 months ago
parent 8a9602af3b
commit 693eb1c542

@ -7,9 +7,9 @@
import asyncio import asyncio
import argparse import argparse
from json_logging import setup_logging, get_logger from json_logging import setup_logging, get_logger
from mqtt_helper import MqttError
from .core import Amcrest2Mqtt from .core import Amcrest2Mqtt
from .mixins.helpers import ConfigError from .mixins.helpers import ConfigError
from .mixins.mqtt import MqttError
def build_parser() -> argparse.ArgumentParser: def build_parser() -> argparse.ArgumentParser:

@ -1,137 +1,24 @@
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
# Copyright (c) 2025 Jeff Culverhouse # Copyright (c) 2025 Jeff Culverhouse
import asyncio
from datetime import datetime, timedelta
import json import json
import paho.mqtt.client as mqtt from typing import TYPE_CHECKING, Any
from paho.mqtt.client import Client, MQTTMessage, ConnectFlags, DisconnectFlags
from paho.mqtt.enums import LogLevel from mqtt_helper import BaseMqttMixin
from paho.mqtt.properties import Properties from paho.mqtt.client import Client, MQTTMessage
from paho.mqtt.packettypes import PacketTypes
from paho.mqtt.reasoncodes import ReasonCode
from paho.mqtt.enums import CallbackAPIVersion
import ssl
from typing import TYPE_CHECKING, Any, Callable, Coroutine, TypeVar
if TYPE_CHECKING: if TYPE_CHECKING:
from amcrest2mqtt.interface import AmcrestServiceProtocol as Amcrest2Mqtt from amcrest2mqtt.interface import AmcrestServiceProtocol as Amcrest2Mqtt
_T = TypeVar("_T")
class MqttError(ValueError):
"""Raised when the connection to the MQTT server fails"""
pass
class MqttMixin:
async def mqttc_create(self: Amcrest2Mqtt) -> None:
# lets use a new client_id for each connection attempt
self.client_id = self.mqtt_helper.client_id()
self.mqttc = mqtt.Client(
client_id=self.client_id,
callback_api_version=CallbackAPIVersion.VERSION2,
reconnect_on_failure=False,
protocol=mqtt.MQTTv5,
)
if self.mqtt_config.get("tls_enabled"):
self.mqttc.tls_set(
ca_certs=self.mqtt_config.get("tls_ca_cert"),
certfile=self.mqtt_config.get("tls_cert"),
keyfile=self.mqtt_config.get("tls_key"),
cert_reqs=ssl.CERT_REQUIRED,
tls_version=ssl.PROTOCOL_TLS,
)
if self.mqtt_config.get("username") or self.mqtt_config.get("password"):
self.mqttc.username_pw_set(
username=self.mqtt_config.get("username", ""),
password=self.mqtt_config.get("password", ""),
)
self.mqttc.on_connect = self._wrap_async(self.mqtt_on_connect)
self.mqttc.on_disconnect = self._wrap_async(self.mqtt_on_disconnect)
self.mqttc.on_message = self._wrap_async(self.mqtt_on_message)
self.mqttc.on_subscribe = self._wrap_async(self.mqtt_on_subscribe)
self.mqttc.on_log = self._wrap_async(self.mqtt_on_log)
# Define a "last will" message (LWT):
self.mqttc.will_set(self.mqtt_helper.avty_t("service"), "offline", qos=1, retain=True)
try: class MqttMixin(BaseMqttMixin):
host = self.mqtt_config["host"] def mqtt_subscription_topics(self: Amcrest2Mqtt) -> list[str]:
port = self.mqtt_config["port"] return [
self.logger.info(f"connecting to MQTT broker at {host}:{port} as client_id: {self.client_id}") "homeassistant/status",
f"{self.mqtt_helper.service_slug}/service/+/set",
props = Properties(PacketTypes.CONNECT) f"{self.mqtt_helper.service_slug}/service/+/command",
props.SessionExpiryInterval = 0 f"{self.mqtt_helper.service_slug}/+/switch/+/set",
f"{self.mqtt_helper.service_slug}/+/button/+/set",
self.mqttc.connect(host=host, port=port, keepalive=60, properties=props) ]
self.logger.info(f"connected to {host} MQTT broker")
self.mqtt_connect_time = datetime.now()
self.mqttc.loop_start()
except ConnectionError as err:
self.logger.error(f"failed to connect to MQTT host {host}: {err}")
self.running = False
raise SystemExit(1)
except Exception as err:
self.logger.error(f"network problem trying to connect to MQTT host {host}: {err}")
self.running = False
raise SystemExit(1)
def _wrap_async(
self: Amcrest2Mqtt,
coro_func: Callable[..., Coroutine[Any, Any, _T]],
) -> Callable[..., None]:
def wrapper(*args: Any, **kwargs: Any) -> None:
self.loop.call_soon_threadsafe(lambda: asyncio.create_task(coro_func(*args, **kwargs)))
return wrapper
async def mqtt_on_connect(
self: Amcrest2Mqtt, client: Client, userdata: dict[str, Any], flags: ConnectFlags, reason_code: ReasonCode, properties: Properties | None
) -> None:
# send our helper the client
self.mqtt_helper.set_client(client)
if reason_code.value != 0:
raise MqttError(f"MQTT failed to connect ({reason_code.getName()})")
await self.publish_service_discovery()
await self.publish_service_availability()
await self.publish_service_state()
self.logger.debug("subscribing to topics on MQTT")
client.subscribe("homeassistant/status")
client.subscribe(f"{self.mqtt_helper.service_slug}/service/+/set")
client.subscribe(f"{self.mqtt_helper.service_slug}/service/+/command")
client.subscribe(f"{self.mqtt_helper.service_slug}/+/switch/+/set")
client.subscribe(f"{self.mqtt_helper.service_slug}/+/button/+/set")
async def mqtt_on_disconnect(
self: Amcrest2Mqtt, client: Client, userdata: Any, flags: DisconnectFlags, reason_code: ReasonCode, properties: Properties | None
) -> None:
# clear the client on our helper
self.mqtt_helper.clear_client()
if reason_code.value != 0:
self.logger.error(f"Mqtt lost connection ({reason_code.getName()})")
else:
self.logger.info("closed Mqtt connection")
if self.running and (self.mqtt_connect_time is None or datetime.now() > self.mqtt_connect_time + timedelta(seconds=10)):
await self.mqttc_create()
else:
self.logger.info("Mqtt disconnect — stopping service loop")
self.running = False
async def mqtt_on_log(self: Amcrest2Mqtt, client: Client, userdata: Any, paho_log_level: int, msg: str) -> None:
if paho_log_level == LogLevel.MQTT_LOG_ERR:
self.logger.error(f"Mqtt logged: {msg}")
if paho_log_level == LogLevel.MQTT_LOG_WARNING:
self.logger.warning(f"Mqtt logged: {msg}")
async def mqtt_on_message(self: Amcrest2Mqtt, client: Client, userdata: Any, msg: MQTTMessage) -> None: async def mqtt_on_message(self: Amcrest2Mqtt, client: Client, userdata: Any, msg: MQTTMessage) -> None:
topic = msg.topic topic = msg.topic
@ -205,10 +92,3 @@ class MqttMixin:
except ValueError as err: except ValueError as err:
self.logger.warning(f"Ignoring malformed topic {topic}: {err}") self.logger.warning(f"Ignoring malformed topic {topic}: {err}")
return [] return []
async def mqtt_on_subscribe(
self: Amcrest2Mqtt, client: Client, userdata: Any, mid: int, reason_code_list: list[ReasonCode], properties: Properties
) -> None:
reason_names = [rc.getName() for rc in reason_code_list]
joined = "; ".join(reason_names) if reason_names else "none"
self.logger.debug(f"Mqtt subscribed (mid={mid}): {joined}")

@ -410,7 +410,7 @@ sdist = { url = "https://files.pythonhosted.org/packages/93/4b/979db9e44be09f71e
[[package]] [[package]]
name = "mqtt-helper-graystorm" name = "mqtt-helper-graystorm"
version = "0.1.0" version = "0.1.0"
source = { git = "https://github.com/weirdtangent/mqtt-helper.git?branch=main#c4034f68f2492173ec0ff13d94eeab47d2bc7c09" } source = { git = "https://github.com/weirdtangent/mqtt-helper.git?branch=main#e81f183a38bdd965099cacfdc1286781f5d7d8f6" }
dependencies = [ dependencies = [
{ name = "logging" }, { name = "logging" },
{ name = "paho-mqtt" }, { name = "paho-mqtt" },

Loading…
Cancel
Save