refactor: use new BaseMqttMixin from mqtt_helper

pull/106/head
Jeff Culverhouse 3 months ago
parent 8a9602af3b
commit 693eb1c542

@ -7,9 +7,9 @@
import asyncio
import argparse
from json_logging import setup_logging, get_logger
from mqtt_helper import MqttError
from .core import Amcrest2Mqtt
from .mixins.helpers import ConfigError
from .mixins.mqtt import MqttError
def build_parser() -> argparse.ArgumentParser:

@ -1,137 +1,24 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 Jeff Culverhouse
import asyncio
from datetime import datetime, timedelta
import json
import paho.mqtt.client as mqtt
from paho.mqtt.client import Client, MQTTMessage, ConnectFlags, DisconnectFlags
from paho.mqtt.enums import LogLevel
from paho.mqtt.properties import Properties
from paho.mqtt.packettypes import PacketTypes
from paho.mqtt.reasoncodes import ReasonCode
from paho.mqtt.enums import CallbackAPIVersion
import ssl
from typing import TYPE_CHECKING, Any, Callable, Coroutine, TypeVar
from typing import TYPE_CHECKING, Any
from mqtt_helper import BaseMqttMixin
from paho.mqtt.client import Client, MQTTMessage
if TYPE_CHECKING:
from amcrest2mqtt.interface import AmcrestServiceProtocol as Amcrest2Mqtt
_T = TypeVar("_T")
class MqttError(ValueError):
"""Raised when the connection to the MQTT server fails"""
pass
class MqttMixin:
async def mqttc_create(self: Amcrest2Mqtt) -> None:
# lets use a new client_id for each connection attempt
self.client_id = self.mqtt_helper.client_id()
self.mqttc = mqtt.Client(
client_id=self.client_id,
callback_api_version=CallbackAPIVersion.VERSION2,
reconnect_on_failure=False,
protocol=mqtt.MQTTv5,
)
if self.mqtt_config.get("tls_enabled"):
self.mqttc.tls_set(
ca_certs=self.mqtt_config.get("tls_ca_cert"),
certfile=self.mqtt_config.get("tls_cert"),
keyfile=self.mqtt_config.get("tls_key"),
cert_reqs=ssl.CERT_REQUIRED,
tls_version=ssl.PROTOCOL_TLS,
)
if self.mqtt_config.get("username") or self.mqtt_config.get("password"):
self.mqttc.username_pw_set(
username=self.mqtt_config.get("username", ""),
password=self.mqtt_config.get("password", ""),
)
self.mqttc.on_connect = self._wrap_async(self.mqtt_on_connect)
self.mqttc.on_disconnect = self._wrap_async(self.mqtt_on_disconnect)
self.mqttc.on_message = self._wrap_async(self.mqtt_on_message)
self.mqttc.on_subscribe = self._wrap_async(self.mqtt_on_subscribe)
self.mqttc.on_log = self._wrap_async(self.mqtt_on_log)
# Define a "last will" message (LWT):
self.mqttc.will_set(self.mqtt_helper.avty_t("service"), "offline", qos=1, retain=True)
try:
host = self.mqtt_config["host"]
port = self.mqtt_config["port"]
self.logger.info(f"connecting to MQTT broker at {host}:{port} as client_id: {self.client_id}")
props = Properties(PacketTypes.CONNECT)
props.SessionExpiryInterval = 0
self.mqttc.connect(host=host, port=port, keepalive=60, properties=props)
self.logger.info(f"connected to {host} MQTT broker")
self.mqtt_connect_time = datetime.now()
self.mqttc.loop_start()
except ConnectionError as err:
self.logger.error(f"failed to connect to MQTT host {host}: {err}")
self.running = False
raise SystemExit(1)
except Exception as err:
self.logger.error(f"network problem trying to connect to MQTT host {host}: {err}")
self.running = False
raise SystemExit(1)
def _wrap_async(
self: Amcrest2Mqtt,
coro_func: Callable[..., Coroutine[Any, Any, _T]],
) -> Callable[..., None]:
def wrapper(*args: Any, **kwargs: Any) -> None:
self.loop.call_soon_threadsafe(lambda: asyncio.create_task(coro_func(*args, **kwargs)))
return wrapper
async def mqtt_on_connect(
self: Amcrest2Mqtt, client: Client, userdata: dict[str, Any], flags: ConnectFlags, reason_code: ReasonCode, properties: Properties | None
) -> None:
# send our helper the client
self.mqtt_helper.set_client(client)
if reason_code.value != 0:
raise MqttError(f"MQTT failed to connect ({reason_code.getName()})")
await self.publish_service_discovery()
await self.publish_service_availability()
await self.publish_service_state()
self.logger.debug("subscribing to topics on MQTT")
client.subscribe("homeassistant/status")
client.subscribe(f"{self.mqtt_helper.service_slug}/service/+/set")
client.subscribe(f"{self.mqtt_helper.service_slug}/service/+/command")
client.subscribe(f"{self.mqtt_helper.service_slug}/+/switch/+/set")
client.subscribe(f"{self.mqtt_helper.service_slug}/+/button/+/set")
async def mqtt_on_disconnect(
self: Amcrest2Mqtt, client: Client, userdata: Any, flags: DisconnectFlags, reason_code: ReasonCode, properties: Properties | None
) -> None:
# clear the client on our helper
self.mqtt_helper.clear_client()
if reason_code.value != 0:
self.logger.error(f"Mqtt lost connection ({reason_code.getName()})")
else:
self.logger.info("closed Mqtt connection")
if self.running and (self.mqtt_connect_time is None or datetime.now() > self.mqtt_connect_time + timedelta(seconds=10)):
await self.mqttc_create()
else:
self.logger.info("Mqtt disconnect — stopping service loop")
self.running = False
async def mqtt_on_log(self: Amcrest2Mqtt, client: Client, userdata: Any, paho_log_level: int, msg: str) -> None:
if paho_log_level == LogLevel.MQTT_LOG_ERR:
self.logger.error(f"Mqtt logged: {msg}")
if paho_log_level == LogLevel.MQTT_LOG_WARNING:
self.logger.warning(f"Mqtt logged: {msg}")
class MqttMixin(BaseMqttMixin):
def mqtt_subscription_topics(self: Amcrest2Mqtt) -> list[str]:
return [
"homeassistant/status",
f"{self.mqtt_helper.service_slug}/service/+/set",
f"{self.mqtt_helper.service_slug}/service/+/command",
f"{self.mqtt_helper.service_slug}/+/switch/+/set",
f"{self.mqtt_helper.service_slug}/+/button/+/set",
]
async def mqtt_on_message(self: Amcrest2Mqtt, client: Client, userdata: Any, msg: MQTTMessage) -> None:
topic = msg.topic
@ -205,10 +92,3 @@ class MqttMixin:
except ValueError as err:
self.logger.warning(f"Ignoring malformed topic {topic}: {err}")
return []
async def mqtt_on_subscribe(
self: Amcrest2Mqtt, client: Client, userdata: Any, mid: int, reason_code_list: list[ReasonCode], properties: Properties
) -> None:
reason_names = [rc.getName() for rc in reason_code_list]
joined = "; ".join(reason_names) if reason_names else "none"
self.logger.debug(f"Mqtt subscribed (mid={mid}): {joined}")

@ -410,7 +410,7 @@ sdist = { url = "https://files.pythonhosted.org/packages/93/4b/979db9e44be09f71e
[[package]]
name = "mqtt-helper-graystorm"
version = "0.1.0"
source = { git = "https://github.com/weirdtangent/mqtt-helper.git?branch=main#c4034f68f2492173ec0ff13d94eeab47d2bc7c09" }
source = { git = "https://github.com/weirdtangent/mqtt-helper.git?branch=main#e81f183a38bdd965099cacfdc1286781f5d7d8f6" }
dependencies = [
{ name = "logging" },
{ name = "paho-mqtt" },

Loading…
Cancel
Save