chore: move safe_publish to our helper

pull/106/head
Jeff Culverhouse 3 months ago
parent 52193bfc8b
commit 11e18b1b73

@ -104,7 +104,6 @@ class AmcrestServiceProtocol(Protocol):
def mqtt_on_message(self, client: Client, userdata: Any, msg: MQTTMessage) -> None: ... def mqtt_on_message(self, client: Client, userdata: Any, msg: MQTTMessage) -> None: ...
def mqtt_on_subscribe(self, client: Client, userdata: Any, mid: int, reason_code_list: list[ReasonCode], properties: Properties) -> None: ... def mqtt_on_subscribe(self, client: Client, userdata: Any, mid: int, reason_code_list: list[ReasonCode], properties: Properties) -> None: ...
def mqtt_on_log(self, client: Client, userdata: Any, paho_log_level: int, msg: str) -> None: ... def mqtt_on_log(self, client: Client, userdata: Any, paho_log_level: int, msg: str) -> None: ...
def mqtt_safe_publish(self, topic: str, payload: str | bool | int | dict | None, **kwargs: Any) -> None: ...
def mqttc_create(self) -> None: ... def mqttc_create(self) -> None: ...
def publish_device_availability(self, device_id: str, online: bool = True) -> None: ... def publish_device_availability(self, device_id: str, online: bool = True) -> None: ...
def publish_device_discovery(self, device_id: str) -> None: ... def publish_device_discovery(self, device_id: str) -> None: ...

@ -3,14 +3,14 @@
from datetime import datetime, timedelta from datetime import datetime, timedelta
import json import json
import paho.mqtt.client as mqtt import paho.mqtt.client as mqtt
from paho.mqtt.client import Client, MQTTMessage, PayloadType, ConnectFlags, DisconnectFlags from paho.mqtt.client import Client, MQTTMessage, ConnectFlags, DisconnectFlags
from paho.mqtt.enums import LogLevel from paho.mqtt.enums import LogLevel
from paho.mqtt.properties import Properties from paho.mqtt.properties import Properties
from paho.mqtt.packettypes import PacketTypes from paho.mqtt.packettypes import PacketTypes
from paho.mqtt.reasoncodes import ReasonCode from paho.mqtt.reasoncodes import ReasonCode
from paho.mqtt.enums import CallbackAPIVersion from paho.mqtt.enums import CallbackAPIVersion
import ssl import ssl
from typing import TYPE_CHECKING, Any, cast from typing import TYPE_CHECKING, Any
if TYPE_CHECKING: if TYPE_CHECKING:
from amcrest2mqtt.interface import AmcrestServiceProtocol as Amcrest2Mqtt from amcrest2mqtt.interface import AmcrestServiceProtocol as Amcrest2Mqtt
@ -79,6 +79,9 @@ class MqttMixin:
def mqtt_on_connect( def mqtt_on_connect(
self: Amcrest2Mqtt, client: Client, userdata: dict[str, Any], flags: ConnectFlags, reason_code: ReasonCode, properties: Properties | None self: Amcrest2Mqtt, client: Client, userdata: dict[str, Any], flags: ConnectFlags, reason_code: ReasonCode, properties: Properties | None
) -> None: ) -> None:
# send our helper the client
self.mqtt_helper.set_client(client)
if reason_code.value != 0: if reason_code.value != 0:
raise MqttError(f"MQTT failed to connect ({reason_code.getName()})") raise MqttError(f"MQTT failed to connect ({reason_code.getName()})")
@ -95,6 +98,9 @@ class MqttMixin:
def mqtt_on_disconnect( def mqtt_on_disconnect(
self: Amcrest2Mqtt, client: Client, userdata: Any, flags: DisconnectFlags, reason_code: ReasonCode, properties: Properties | None self: Amcrest2Mqtt, client: Client, userdata: Any, flags: DisconnectFlags, reason_code: ReasonCode, properties: Properties | None
) -> None: ) -> None:
# clear the client on our helper
self.mqtt_helper.clear_client()
if reason_code.value != 0: if reason_code.value != 0:
self.logger.error(f"MQTT lost connection ({reason_code.getName()})") self.logger.error(f"MQTT lost connection ({reason_code.getName()})")
else: else:
@ -202,19 +208,3 @@ class MqttMixin:
reason_names = [rc.getName() for rc in reason_code_list] reason_names = [rc.getName() for rc in reason_code_list]
joined = "; ".join(reason_names) if reason_names else "none" joined = "; ".join(reason_names) if reason_names else "none"
self.logger.debug(f"MQTT subscribed (mid={mid}): {joined}") self.logger.debug(f"MQTT subscribed (mid={mid}): {joined}")
def mqtt_safe_publish(self: Amcrest2Mqtt, topic: str, payload: str | bool | int | dict | None, **kwargs: Any) -> None:
if not topic:
raise ValueError("Cannot post to a blank topic")
if isinstance(payload, dict) and ("component" in payload or "//////" in payload):
self.logger.warning("Questionable payload includes 'component' or string of slashes - wont't send to HA")
self.logger.warning(f"topic: {topic}")
self.logger.warning(f"payload: {payload}")
raise ValueError("Possible invalid payload. topic: {topic} payload: {payload}")
try:
if payload is None:
self.mqttc.publish(topic, "null", **kwargs)
else:
self.mqttc.publish(topic, cast(PayloadType, payload), **kwargs)
except Exception as e:
self.logger.warning(f"MQTT publish failed for {topic} with {payload[:120] if isinstance(payload, str) else payload}: {e}")

@ -18,7 +18,7 @@ class PublishMixin:
) )
self.logger.info("Publishing service entity") self.logger.info("Publishing service entity")
self.mqtt_safe_publish( self.mqtt_helper.safe_publish(
topic=self.mqtt_helper.disc_t("binary_sensor", "service"), topic=self.mqtt_helper.disc_t("binary_sensor", "service"),
payload=json.dumps( payload=json.dumps(
{ {
@ -40,7 +40,7 @@ class PublishMixin:
retain=True, retain=True,
) )
self.mqtt_safe_publish( self.mqtt_helper.safe_publish(
topic=self.mqtt_helper.disc_t("sensor", "api_calls"), topic=self.mqtt_helper.disc_t("sensor", "api_calls"),
payload=json.dumps( payload=json.dumps(
{ {
@ -58,7 +58,7 @@ class PublishMixin:
qos=self.mqtt_config["qos"], qos=self.mqtt_config["qos"],
retain=True, retain=True,
) )
self.mqtt_safe_publish( self.mqtt_helper.safe_publish(
topic=self.mqtt_helper.disc_t("binary_sensor", "rate_limited"), topic=self.mqtt_helper.disc_t("binary_sensor", "rate_limited"),
payload=json.dumps( payload=json.dumps(
{ {
@ -77,7 +77,7 @@ class PublishMixin:
qos=self.mqtt_config["qos"], qos=self.mqtt_config["qos"],
retain=True, retain=True,
) )
self.mqtt_safe_publish( self.mqtt_helper.safe_publish(
topic=self.mqtt_helper.disc_t("number", "storage_refresh"), topic=self.mqtt_helper.disc_t("number", "storage_refresh"),
payload=json.dumps( payload=json.dumps(
{ {
@ -98,7 +98,7 @@ class PublishMixin:
qos=self.mqtt_config["qos"], qos=self.mqtt_config["qos"],
retain=True, retain=True,
) )
self.mqtt_safe_publish( self.mqtt_helper.safe_publish(
topic=self.mqtt_helper.disc_t("number", "device_list_refresh"), topic=self.mqtt_helper.disc_t("number", "device_list_refresh"),
payload=json.dumps( payload=json.dumps(
{ {
@ -119,7 +119,7 @@ class PublishMixin:
qos=self.mqtt_config["qos"], qos=self.mqtt_config["qos"],
retain=True, retain=True,
) )
self.mqtt_safe_publish( self.mqtt_helper.safe_publish(
topic=self.mqtt_helper.disc_t("number", "snapshot_refresh"), topic=self.mqtt_helper.disc_t("number", "snapshot_refresh"),
payload=json.dumps( payload=json.dumps(
{ {
@ -140,7 +140,7 @@ class PublishMixin:
qos=self.mqtt_config["qos"], qos=self.mqtt_config["qos"],
retain=True, retain=True,
) )
self.mqtt_safe_publish( self.mqtt_helper.safe_publish(
topic=self.mqtt_helper.disc_t("button", "refresh_device_list"), topic=self.mqtt_helper.disc_t("button", "refresh_device_list"),
payload=json.dumps( payload=json.dumps(
{ {
@ -158,7 +158,7 @@ class PublishMixin:
self.logger.debug(f"[HA] Discovery published for {self.service} ({self.mqtt_helper.service_slug})") self.logger.debug(f"[HA] Discovery published for {self.service} ({self.mqtt_helper.service_slug})")
def publish_service_availability(self: Amcrest2Mqtt, status: str = "online") -> None: def publish_service_availability(self: Amcrest2Mqtt, status: str = "online") -> None:
self.mqtt_safe_publish(self.mqtt_helper.avty_t("service"), status, qos=self.qos, retain=True) self.mqtt_helper.safe_publish(self.mqtt_helper.avty_t("service"), status, qos=self.qos, retain=True)
def publish_service_state(self: Amcrest2Mqtt) -> None: def publish_service_state(self: Amcrest2Mqtt) -> None:
service = { service = {
@ -177,7 +177,7 @@ class PublishMixin:
else: else:
payload = json.dumps(value) payload = json.dumps(value)
self.mqtt_safe_publish( self.mqtt_helper.safe_publish(
self.mqtt_helper.stat_t("service", "service", key), self.mqtt_helper.stat_t("service", "service", key),
payload, payload,
qos=self.mqtt_config["qos"], qos=self.mqtt_config["qos"],
@ -198,7 +198,7 @@ class PublishMixin:
payload = {k: v for k, v in defn.items() if k != "component_type"} payload = {k: v for k, v in defn.items() if k != "component_type"}
# Publish discovery # Publish discovery
self.mqtt_safe_publish(topic, json.dumps(payload), retain=True) self.mqtt_helper.safe_publish(topic, json.dumps(payload), retain=True)
# Mark discovered in state (per published entity) # Mark discovered in state (per published entity)
self.states.setdefault(eff_device_id, {}).setdefault("internal", {})["discovered"] = 1 self.states.setdefault(eff_device_id, {}).setdefault("internal", {})["discovered"] = 1
@ -215,7 +215,7 @@ class PublishMixin:
payload = "online" if online else "offline" payload = "online" if online else "offline"
avty_t = self.get_device_availability_topic(device_id) avty_t = self.get_device_availability_topic(device_id)
self.mqtt_safe_publish(avty_t, payload, retain=True) self.mqtt_helper.safe_publish(avty_t, payload, retain=True)
def publish_device_state(self: Amcrest2Mqtt, device_id: str) -> None: def publish_device_state(self: Amcrest2Mqtt, device_id: str) -> None:
def _publish_one(dev_id: str, defn: str | dict[str, Any], suffix: str = "") -> None: def _publish_one(dev_id: str, defn: str | dict[str, Any], suffix: str = "") -> None:
@ -230,9 +230,9 @@ class PublishMixin:
meta = self.states[dev_id].get("meta") meta = self.states[dev_id].get("meta")
if isinstance(meta, dict) and "last_update" in meta: if isinstance(meta, dict) and "last_update" in meta:
flat["last_update"] = meta["last_update"] flat["last_update"] = meta["last_update"]
self.mqtt_safe_publish(topic, json.dumps(flat), retain=True) self.mqtt_helper.safe_publish(topic, json.dumps(flat), retain=True)
else: else:
self.mqtt_safe_publish(topic, defn, retain=True) self.mqtt_helper.safe_publish(topic, defn, retain=True)
if not self.is_discovered(device_id): if not self.is_discovered(device_id):
self.logger.debug(f"[device state] Discovery not complete for {device_id} yet, holding off on sending state") self.logger.debug(f"[device state] Discovery not complete for {device_id} yet, holding off on sending state")

@ -320,7 +320,10 @@ wheels = [
[[package]] [[package]]
name = "mqtt-helper-graystorm" name = "mqtt-helper-graystorm"
version = "0.1.0" version = "0.1.0"
source = { git = "https://github.com/weirdtangent/mqtt-helper.git?branch=main#8772e3b3aab4cefc1ace2fdc0f96911609f22eb3" } source = { git = "https://github.com/weirdtangent/mqtt-helper.git?branch=main#576567322c874c16b69b5f2d996313b58744ccf6" }
dependencies = [
{ name = "paho-mqtt" },
]
[[package]] [[package]]
name = "mypy" name = "mypy"

Loading…
Cancel
Save