feat(core): add async process pool, graceful signal handling, and safer config loading

- Added multiprocessing-based connection pooling for faster camera setup
- Implemented robust signal handling (SIGINT/SIGTERM) for graceful shutdown
- Upgraded MQTT client to v5 with improved logging and reconnect behavior
- Added defensive WebRTC config validation to avoid runtime errors
- Introduced unified config loader (`util.load_config`) with boolean coercion
- Improved IP resolution and logging consistency across modules
pull/106/head
Jeff Culverhouse 4 months ago
parent a69ffc1667
commit f025d60f75

@ -5,23 +5,20 @@
# #
# The software is provided 'as is', without any warranty. # The software is provided 'as is', without any warranty.
from amcrest import AmcrestCamera, AmcrestError, CommError, LoginError, exceptions from amcrest import AmcrestCamera, AmcrestError, CommError, LoginError
import asyncio import asyncio
from asyncio import timeout from concurrent.futures import ProcessPoolExecutor
import base64 import base64
from datetime import datetime
import httpx
import logging import logging
import os import signal
import time from util import get_ip_address, to_gb
from util import *
from zoneinfo import ZoneInfo
class AmcrestAPI(object):
class AmcrestAPI:
def __init__(self, config): def __init__(self, config):
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
# we don't want to get this mess of deeper-level logging # Quiet down noisy loggers
logging.getLogger("httpx").setLevel(logging.WARNING) logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("httpcore.http11").setLevel(logging.WARNING) logging.getLogger("httpcore.http11").setLevel(logging.WARNING)
logging.getLogger("httpcore.connection").setLevel(logging.WARNING) logging.getLogger("httpcore.connection").setLevel(logging.WARNING)
@ -29,96 +26,149 @@ class AmcrestAPI(object):
logging.getLogger("amcrest.event").setLevel(logging.WARNING) logging.getLogger("amcrest.event").setLevel(logging.WARNING)
logging.getLogger("urllib3.connectionpool").setLevel(logging.WARNING) logging.getLogger("urllib3.connectionpool").setLevel(logging.WARNING)
self.last_call_date = '' self.timezone = config["timezone"]
self.timezone = config['timezone'] self.amcrest_config = config["amcrest"]
self.amcrest_config = config['amcrest']
self.count = len(self.amcrest_config['hosts'])
self.devices = {} self.devices = {}
self.events = [] self.events = []
self.executor = ProcessPoolExecutor(
max_workers=min(8, len(self.amcrest_config["hosts"]))
)
# handle signals gracefully
signal.signal(signal.SIGINT, self._sig_handler)
signal.signal(signal.SIGTERM, self._sig_handler)
self._shutting_down = False
def _sig_handler(self, signum, frame):
if not self._shutting_down:
self._shutting_down = True
self.logger.warning("SIGINT received — shutting down process pool...")
self.shutdown()
def shutdown(self):
"""Instantly kill process pool workers."""
try:
if self.executor:
self.logger.debug("Force-terminating process pool workers...")
self.executor.shutdown(wait=False, cancel_futures=True)
self.executor = None
except Exception as e:
self.logger.warning(f"Error shutting down process pool: {e}")
def get_camera(self, host):
cfg = self.amcrest_config
return AmcrestCamera(
host, cfg["port"], cfg["username"], cfg["password"], verbose=False
).camera
# ----------------------------------------------------------------------------------------------
async def connect_to_devices(self): async def connect_to_devices(self):
self.logger.info(f'Connecting to: {self.amcrest_config["hosts"]}') self.logger.info(f"Connecting to: {self.amcrest_config['hosts']}")
# Defensive guard against shutdown signals
if getattr(self, "shutting_down", False):
self.logger.warning("Connect aborted: shutdown already in progress.")
return {}
tasks = [] tasks = [
for host in self.amcrest_config['hosts']: asyncio.create_task(self._connect_device_threaded(host, name))
device_name = self.amcrest_config['names'].pop(0) for host, name in zip(
task = asyncio.create_task(self.get_device(host, device_name)) self.amcrest_config["hosts"], self.amcrest_config["names"]
tasks.append(task) )
await asyncio.gather(*tasks, return_exceptions=True) ]
if len(self.devices) == 0: try:
self.logger.error('Failed to connect to all devices, exiting') results = await asyncio.gather(*tasks, return_exceptions=True)
except asyncio.CancelledError:
self.logger.warning("Device connection cancelled by signal.")
return {}
except Exception as e:
self.logger.error(f"Device connection failed: {e}", exc_info=True)
return {}
successes = [r for r in results if isinstance(r, dict) and "config" in r]
failures = [r for r in results if isinstance(r, dict) and "error" in r]
self.logger.info(
f"Device connection summary: {len(successes)} succeeded, {len(failures)} failed"
)
if not successes and not getattr(self, "shutting_down", False):
self.logger.error("Failed to connect to any devices, exiting")
exit(1) exit(1)
# return just the config of each device, not the camera object # Recreate cameras in parent process
return {d: self.devices[d]['config'] for d in self.devices.keys()} for info in successes:
cfg = info["config"]
serial = cfg["serial_number"]
cam = self.get_camera(cfg["host_ip"])
self.devices[serial] = {"camera": cam, "config": cfg}
def get_camera(self, host): return {d: self.devices[d]["config"] for d in self.devices.keys()}
config = self.amcrest_config
return AmcrestCamera(host, config['port'], config['username'], config['password'], verbose=False).camera
async def get_device(self, host, device_name): async def _connect_device_threaded(self, host, device_name):
try: """Run the blocking camera connection logic in a separate process."""
# resolve host and setup camera by ip so we aren't making 100k DNS lookups per day loop = asyncio.get_running_loop()
return await loop.run_in_executor(
self.executor, _connect_device_worker, (host, device_name, self.amcrest_config)
)
def _connect_device_sync(self, host, device_name):
"""Blocking version of connect logic that runs in a separate process."""
try: try:
import multiprocessing
p_name = multiprocessing.current_process().name
host_ip = get_ip_address(host) host_ip = get_ip_address(host)
self.logger.info(f'nslookup {host} got us {host_ip}')
camera = self.get_camera(host_ip) camera = self.get_camera(host_ip)
except Exception as err:
self.logger.error(f'Error with {host}: {err}')
device_type = camera.device_type.replace('type=', '').strip() device_type = camera.device_type.replace("type=", "").strip()
is_ad110 = device_type == 'AD110' is_ad110 = device_type == "AD110"
is_ad410 = device_type == 'AD410' is_ad410 = device_type == "AD410"
is_doorbell = is_ad110 or is_ad410 is_doorbell = is_ad110 or is_ad410
serial_number = camera.serial_number serial_number = camera.serial_number
if not isinstance(serial_number, str): version = camera.software_information[0].replace("version=", "").strip()
self.logger.error(f'Error fetching serial number for {host}: {camera.serial_number}')
exit(1)
version = camera.software_information[0].replace('version=', '').strip()
build = camera.software_information[1].strip() build = camera.software_information[1].strip()
sw_version = f'{version} ({build})' sw_version = f"{version} ({build})"
network_config = dict(item.split('=') for item in camera.network_config.splitlines()) network_config = dict(
interface = network_config['table.Network.DefaultInterface'] item.split("=") for item in camera.network_config.splitlines()
ip_address = network_config[f'table.Network.{interface}.IPAddress'] )
mac_address = network_config[f'table.Network.{interface}.PhysicalAddress'].upper() iface = network_config["table.Network.DefaultInterface"]
ip_address = network_config[f"table.Network.{iface}.IPAddress"]
action = 'Connected' if camera.serial_number not in self.devices else 'Reconnected' mac_address = network_config[f"table.Network.{iface}.PhysicalAddress"].upper()
self.logger.info(f'{action} to {host} as {camera.serial_number}')
print(f"[{p_name}] Connected to {host} ({ip_address}) as {serial_number}")
self.devices[serial_number] = {
'camera': camera, return {
'config': { "config": {
'host': host, "host": host,
'host_ip': host_ip, "host_ip": host_ip,
'device_name': device_name, "device_name": device_name,
'device_type': device_type, "device_type": device_type,
'device_class': camera.device_class, "device_class": camera.device_class,
'is_ad110': is_ad110, "is_ad110": is_ad110,
'is_ad410': is_ad410, "is_ad410": is_ad410,
'is_doorbell': is_doorbell, "is_doorbell": is_doorbell,
'serial_number': serial_number, "serial_number": serial_number,
'software_version': sw_version, "software_version": sw_version,
'hardware_version': camera.hardware_version, "hardware_version": camera.hardware_version,
'vendor': camera.vendor_information, "vendor": camera.vendor_information,
'network': { "network": {
'interface': interface, "interface": iface,
'ip_address': ip_address, "ip_address": ip_address,
'mac': mac_address, "mac": mac_address,
}
}, },
} }
self.get_privacy_mode(serial_number) }
except LoginError as err: except Exception as e:
self.logger.error(f'Invalid username/password to connect to device "{host}", fix in config.yaml') import traceback
except AmcrestError as err: err_trace = traceback.format_exc()
self.logger.error(f'Failed to connect to device "{host}", check config.yaml and restart to try again: {err}') print(f"[child] Error connecting to {host}: {e}\n{err_trace}")
return {"error": f"{e}", "host": host}
# Storage stats ------------------------------------------------------------------------------- # Storage stats -------------------------------------------------------------------------------
@ -152,7 +202,6 @@ class AmcrestAPI(object):
return privacy_mode return privacy_mode
def set_privacy_mode(self, device_id, switch): def set_privacy_mode(self, device_id, switch):
device = self.devices[device_id] device = self.devices[device_id]
@ -245,7 +294,6 @@ class AmcrestAPI(object):
if tries == 3: if tries == 3:
self.logger.error(f'Failed to communicate with device ({device_id}) to get recorded file') self.logger.error(f'Failed to communicate with device ({device_id}) to get recorded file')
# Events -------------------------------------------------------------------------------------- # Events --------------------------------------------------------------------------------------
async def collect_all_device_events(self): async def collect_all_device_events(self):
@ -321,3 +369,69 @@ class AmcrestAPI(object):
def get_next_event(self): def get_next_event(self):
return self.events.pop(0) if len(self.events) > 0 else None return self.events.pop(0) if len(self.events) > 0 else None
def _connect_device_worker(args):
"""Top-level helper so it can be pickled by ProcessPoolExecutor."""
signal.signal(signal.SIGINT, signal.SIG_IGN)
host, device_name, amcrest_cfg = args
try:
host_ip = get_ip_address(host)
camera = AmcrestCamera(
host_ip,
amcrest_cfg["port"],
amcrest_cfg["username"],
amcrest_cfg["password"],
verbose=False,
).camera
device_type = camera.device_type.replace("type=", "").strip()
is_ad110 = device_type == "AD110"
is_ad410 = device_type == "AD410"
is_doorbell = is_ad110 or is_ad410
serial_number = camera.serial_number
version = camera.software_information[0].replace("version=", "").strip()
build = camera.software_information[1].strip()
sw_version = f"{version} ({build})"
network_config = dict(
item.split("=") for item in camera.network_config.splitlines()
)
iface = network_config["table.Network.DefaultInterface"]
ip_address = network_config[f"table.Network.{iface}.IPAddress"]
mac_address = network_config[f"table.Network.{iface}.PhysicalAddress"].upper()
print(f"[worker] Connected to {host} ({ip_address}) as {serial_number}")
return {
"config": {
"host": host,
"host_ip": host_ip,
"device_name": device_name,
"device_type": device_type,
"device_class": camera.device_class,
"is_ad110": is_ad110,
"is_ad410": is_ad410,
"is_doorbell": is_doorbell,
"serial_number": serial_number,
"software_version": sw_version,
"hardware_version": camera.hardware_version,
"vendor": camera.vendor_information,
"network": {
"interface": iface,
"ip_address": ip_address,
"mac": mac_address,
},
}
}
except Exception as e:
import traceback
err_trace = traceback.format_exc()
print(f"[worker] Error connecting to {host}: {e}\n{err_trace}")
return {"error": str(e), "host": host}

@ -63,18 +63,19 @@ class AmcrestMqtt(object):
# MQTT Functions ------------------------------------------------------------------------------ # MQTT Functions ------------------------------------------------------------------------------
def mqtt_on_connect(self, client, userdata, flags, rc, properties): def mqtt_on_connect(self, client, userdata, flags, reason_code, properties):
if rc != 0: if reason_code.value != 0:
self.logger.error(f'MQTT connection issue ({rc})') self.logger.error(f'MQTT connection issue ({reason_code.getName()})')
exit() self.running = False
return
self.logger.info(f'MQTT connected as {self.client_id}') self.logger.info(f'MQTT connected as {self.client_id}')
client.subscribe("homeassistant/status") client.subscribe("homeassistant/status")
client.subscribe(self.get_device_sub_topic()) client.subscribe(self.get_device_sub_topic())
client.subscribe(self.get_attribute_sub_topic()) client.subscribe(self.get_attribute_sub_topic())
def mqtt_on_disconnect(self, client, userdata, flags, rc, properties): def mqtt_on_disconnect(self, client, userdata, disconnect_flags, reason_code, properties):
self.logger.info('MQTT connection closed') self.logger.warning(f'MQTT disconnected: {reason_code.getName()} (flags={disconnect_flags})')
self.mqttc.loop_stop() self.mqttc.loop_stop()
if self.running and time.time() > self.mqtt_connect_time + 10: if self.running and time.time() > self.mqtt_connect_time + 10:
@ -88,7 +89,6 @@ class AmcrestMqtt(object):
self.paused = False self.paused = False
else: else:
self.running = False self.running = False
exit()
def mqtt_on_log(self, client, userdata, paho_log_level, msg): def mqtt_on_log(self, client, userdata, paho_log_level, msg):
if paho_log_level == mqtt.LogLevel.MQTT_LOG_ERR: if paho_log_level == mqtt.LogLevel.MQTT_LOG_ERR:
@ -169,10 +169,10 @@ class AmcrestMqtt(object):
def mqttc_create(self): def mqttc_create(self):
self.mqttc = mqtt.Client( self.mqttc = mqtt.Client(
callback_api_version=mqtt.CallbackAPIVersion.VERSION2,
client_id=self.client_id, client_id=self.client_id,
clean_session=False, callback_api_version=mqtt.CallbackAPIVersion.VERSION2,
reconnect_on_failure=False, reconnect_on_failure=False,
protocol=mqtt.MQTTv5,
) )
if self.mqtt_config.get('tls_enabled'): if self.mqtt_config.get('tls_enabled'):
@ -183,7 +183,8 @@ class AmcrestMqtt(object):
cert_reqs=ssl.CERT_REQUIRED, cert_reqs=ssl.CERT_REQUIRED,
tls_version=ssl.PROTOCOL_TLS, tls_version=ssl.PROTOCOL_TLS,
) )
else: self.mqttc.tls_insecure_set(self.mqtt_config.get("tls_insecure", False))
if self.mqtt_config.get('username'):
self.mqttc.username_pw_set( self.mqttc.username_pw_set(
username=self.mqtt_config.get('username'), username=self.mqtt_config.get('username'),
password=self.mqtt_config.get('password'), password=self.mqtt_config.get('password'),
@ -199,16 +200,24 @@ class AmcrestMqtt(object):
self.mqttc.will_set(self.get_discovery_topic('service', 'availability'), payload="offline", qos=self.mqtt_config['qos'], retain=True) self.mqttc.will_set(self.get_discovery_topic('service', 'availability'), payload="offline", qos=self.mqtt_config['qos'], retain=True)
try: try:
self.logger.info(
f"Connecting to MQTT broker at {self.mqtt_config.get('host')}:{self.mqtt_config.get('port')} "
f"as {self.client_id}"
)
self.mqttc.connect( self.mqttc.connect(
self.mqtt_config.get('host'), host=self.mqtt_config.get('host'),
port=self.mqtt_config.get('port'), port=self.mqtt_config.get('port'),
keepalive=60, keepalive=60,
) )
self.mqtt_connect_time = time.time() self.mqtt_connect_time = time.time()
self.mqttc.loop_start() self.mqttc.loop_start()
except ConnectionError as error: except Exception as error:
self.logger.error(f'COULD NOT CONNECT TO MQTT {self.mqtt_config.get("host")}: {error}') self.logger.error(
exit(1) f"Failed to connect to MQTT broker {self.mqtt_config.get('host')}:{self.mqtt_config.get('port')} "
f"({type(error).__name__}: {error})",
exc_info=True,
)
self.running = False
# MQTT Topics --------------------------------------------------------------------------------- # MQTT Topics ---------------------------------------------------------------------------------
@ -261,6 +270,18 @@ class AmcrestMqtt(object):
return f"{self.mqtt_config['prefix']}/{self.get_component_slug(device_id)}/{topic}/{subtopic}" return f"{self.mqtt_config['prefix']}/{self.get_component_slug(device_id)}/{topic}/{subtopic}"
return f"{self.mqtt_config['discovery_prefix']}/device/{self.get_component_slug(device_id)}/{topic}/{subtopic}" return f"{self.mqtt_config['discovery_prefix']}/device/{self.get_component_slug(device_id)}/{topic}/{subtopic}"
def ha_cfg_topic(self, domain: str, object_id: str) -> str:
dp = self.mqtt_config.get('discovery_prefix', 'homeassistant')
return f"{dp}/{domain}/{object_id}/config"
def service_oid(self, suffix: str) -> str:
return f"{self.service_slug}_{suffix}"
def svc_topic(self, sub: str) -> str:
# Runtime topics (your own prefix), e.g. govee2mqtt/govee-service/state
pfx = self.mqtt_config.get('prefix', 'govee2mqtt')
return f"{pfx}/amcrest-service/{sub}"
# Service Device ------------------------------------------------------------------------------ # Service Device ------------------------------------------------------------------------------
def publish_service_state(self): def publish_service_state(self):
@ -474,14 +495,33 @@ class AmcrestMqtt(object):
'value_template': '{{ value_json.state }}', 'value_template': '{{ value_json.state }}',
'unique_id': self.get_slug(device_id, 'snapshot_camera'), 'unique_id': self.get_slug(device_id, 'snapshot_camera'),
} }
if 'webrtc' in self.amcrest_config: # --- Safe WebRTC config handling ----------------------------------------
webrtc_config = self.amcrest_config['webrtc'] webrtc_config = self.amcrest_config.get("webrtc")
rtc_host = webrtc_config['host']
rtc_port = webrtc_config['port'] # Handle missing, boolean, or incomplete configs gracefully
rtc_link = webrtc_config['link'] if isinstance(webrtc_config, bool) or not webrtc_config:
rtc_source = webrtc_config['sources'].pop(0) self.logger.debug("No valid WebRTC config found; skipping WebRTC setup.")
rtc_url = f'http://{rtc_host}:{rtc_port}/{rtc_link}?src={rtc_source}' else:
device_config['device']['configuration_url'] = rtc_url try:
rtc_host = webrtc_config.get("host")
rtc_port = webrtc_config.get("port")
rtc_link = webrtc_config.get("link")
rtc_sources = webrtc_config.get("sources", [])
rtc_source = rtc_sources[0] if rtc_sources else None
if rtc_host and rtc_port and rtc_link and rtc_source:
rtc_url = f"http://{rtc_host}:{rtc_port}/{rtc_link}?src={rtc_source}"
device_config["device"]["configuration_url"] = rtc_url
self.logger.debug(f"Added WebRTC config URL for {device_id}: {rtc_url}")
else:
self.logger.warning(
f"Incomplete WebRTC config for {device_id}: {webrtc_config}"
)
except Exception as e:
self.logger.warning(
f"Failed to apply WebRTC config for {device_id}: {e}", exc_info=True
)
# copy the snapshot camera for the eventshot camera, with a couple of changes # copy the snapshot camera for the eventshot camera, with a couple of changes
components[self.get_slug(device_id, 'event_camera')] = { components[self.get_slug(device_id, 'event_camera')] = {
@ -744,9 +784,9 @@ class AmcrestMqtt(object):
def handle_service_message(self, attribute, message): def handle_service_message(self, attribute, message):
match attribute: match attribute:
case 'storage_refresh': case "storage_refresh":
self.storage_update_interval = message self.storage_update_interval = message
self.logger.info(f'Updated STORAGE_REFRESH_INTERVAL to be {message}') self.logger.info(f"Updated STORAGE_REFRESH_INTERVAL to be {message}")
case 'snapshot_refresh': case 'snapshot_refresh':
self.snapshot_update_interval = message self.snapshot_update_interval = message
self.logger.info(f'Updated SNAPSHOT_REFRESH_INTERVAL to be {message}') self.logger.info(f'Updated SNAPSHOT_REFRESH_INTERVAL to be {message}')
@ -839,30 +879,85 @@ class AmcrestMqtt(object):
# main loop # main loop
async def main_loop(self): async def main_loop(self):
"""Main event loop for Amcrest MQTT service."""
await self.setup_devices() await self.setup_devices()
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
# Create async tasks with descriptive names
tasks = [ tasks = [
asyncio.create_task(self.collect_storage_info()), asyncio.create_task(
asyncio.create_task(self.collect_events()), self.collect_storage_info(), name="collect_storage_info"
asyncio.create_task(self.check_event_queue()), ),
asyncio.create_task(self.collect_snapshots()), asyncio.create_task(self.collect_events(), name="collect_events"),
asyncio.create_task(self.check_event_queue(), name="check_event_queue"),
asyncio.create_task(self.collect_snapshots(), name="collect_snapshots"),
] ]
# setup signal handling for tasks # Graceful signal handler
def _signal_handler(signame):
"""Immediate, aggressive shutdown handler for Ctrl+C or SIGTERM."""
self.logger.warning(f"{signame} received — initiating shutdown NOW...")
self.running = False
# Cancel *all* asyncio tasks, even those not tracked manually
loop = asyncio.get_event_loop()
for task in asyncio.all_tasks(loop):
if not task.done():
task.cancel(f"{signame} received")
# Force-stop ProcessPoolExecutor if present
try:
if hasattr(self, "api") and hasattr(self.api, "executor"):
self.logger.debug("Force-shutting down process pool...")
self.api.executor.shutdown(wait=False, cancel_futures=True)
except Exception as e:
self.logger.debug(f"Error force-stopping process pool: {e}")
# Stop the loop immediately after a short delay
loop.call_later(0.05, loop.stop)
for sig in (signal.SIGINT, signal.SIGTERM): for sig in (signal.SIGINT, signal.SIGTERM):
loop.add_signal_handler( try:
sig, lambda: asyncio.create_task(self._handle_signals(sig.name, loop)) loop.add_signal_handler(sig, _signal_handler, sig.name)
) except NotImplementedError:
# Windows compatibility
self.logger.debug(f"Signal handling not supported on this platform.")
try: try:
results = await asyncio.gather(*tasks, return_exceptions=True) results = await asyncio.gather(*tasks, return_exceptions=True)
for result in results:
if isinstance(result, Exception): # Handle task exceptions individually
for t, result in zip(tasks, results):
if isinstance(result, asyncio.CancelledError):
self.logger.info(f"Task '{t.get_name()}' cancelled.")
elif isinstance(result, Exception):
self.logger.error(
f"Task '{t.get_name()}' raised an exception: {result}",
exc_info=True,
)
self.running = False self.running = False
self.logger.error(f'Caught exception: {err}', exc_info=True)
except asyncio.CancelledError: except asyncio.CancelledError:
exit(1) self.logger.info("Main loop cancelled; shutting down...")
except Exception as err: except Exception as err:
self.logger.exception(f"Unhandled exception in main loop: {err}")
self.running = False self.running = False
self.logger.error(f'Caught exception: {err}') finally:
self.logger.info("All loops terminated, performing final cleanup...")
try:
# Save final state or cleanup hooks if needed
if hasattr(self, "save_state"):
self.save_state()
except Exception as e:
self.logger.warning(f"Error during save_state: {e}")
# Disconnect MQTT cleanly
if self.mqttc and self.mqttc.is_connected():
try:
self.mqttc.disconnect()
except Exception as e:
self.logger.warning(f"Error during MQTT disconnect: {e}")
self.logger.info("Main loop complete.")

189
app.py

@ -1,138 +1,61 @@
# This software is licensed under the MIT License, which allows you to use, #!/usr/bin/env python3
# copy, modify, merge, publish, distribute, and sell copies of the software,
# with the requirement to include the original copyright notice and this
# permission notice in all copies or substantial portions of the software.
#
# The software is provided 'as is', without any warranty.
import asyncio import asyncio
import argparse import argparse
from amcrest_mqtt import AmcrestMqtt
import logging import logging
import os from amcrest_mqtt import AmcrestMqtt
import sys from util import load_config
import time
from util import * if __name__ == "__main__":
import yaml # Parse command-line arguments
argparser = argparse.ArgumentParser()
# Let's go! argparser.add_argument(
version = read_version() "-c",
"--config",
# Cmd-line args
argparser = argparse.ArgumentParser()
argparser.add_argument(
'-c',
'--config',
required=False, required=False,
help='Directory holding config.yaml or full path to config file', help="Directory or file path for config.yaml (defaults to /config/config.yaml)",
) )
args = argparser.parse_args() args = argparser.parse_args()
# Setup config from yaml file or env # Load configuration
configpath = args.config or '/config' config = load_config(args.config)
try:
if not configpath.endswith('.yaml'): # Setup logging
if not configpath.endswith('/'): logging.basicConfig(
configpath += '/' format=(
configfile = configpath + 'config.yaml' "%(asctime)s.%(msecs)03d [%(levelname)s] %(name)s: %(message)s"
with open(configfile) as file: if not config["hide_ts"]
config = yaml.safe_load(file) else "[%(levelname)s] %(name)s: %(message)s"
config['config_path'] = configpath ),
config['config_from'] = 'file' datefmt="%Y-%m-%d %H:%M:%S",
except: level=logging.DEBUG if config["debug"] else logging.INFO,
config = { )
'mqtt': {
'host': os.getenv('MQTT_HOST') or 'localhost', logger = logging.getLogger(__name__)
'qos': int(os.getenv('MQTT_QOS') or 0), logger.info(f"Starting amcrest2mqtt {config['version']}")
'port': int(os.getenv('MQTT_PORT') or 1883), logger.info(f"Config loaded from {config['config_from']} ({config['config_path']})")
'username': os.getenv('MQTT_USERNAME'),
'password': os.getenv('MQTT_PASSWORD'), # can be None # Run main loop safely
'tls_enabled': os.getenv('MQTT_TLS_ENABLED') == 'true', try:
'tls_ca_cert': os.getenv('MQTT_TLS_CA_CERT'), with AmcrestMqtt(config) as mqtt:
'tls_cert': os.getenv('MQTT_TLS_CERT'), try:
'tls_key': os.getenv('MQTT_TLS_KEY'), # Prefer a clean async run, but handle nested event loops
'prefix': os.getenv('MQTT_PREFIX') or 'amcrest2mqtt',
'homeassistant': os.getenv('MQTT_HOMEASSISTANT') == True,
'discovery_prefix': os.getenv('MQTT_DISCOVERY_PREFIX') or 'homeassistant',
},
'amcrest': {
'hosts': os.getenv("AMCREST_HOSTS"),
'names': os.getenv("AMCREST_NAMES"),
'port': int(os.getenv("AMCREST_PORT") or 80),
'username': os.getenv("AMCREST_USERNAME") or "admin",
'password': os.getenv("AMCREST_PASSWORD"),
'storage_update_interval': int(os.getenv("STORAGE_UPDATE_INTERVAL") or 900),
'snapshot_update_interval': int(os.getenv("SNAPSHOT_UPDATE_INTERVAL") or 300),
'webrtc': {
'host': os.getenv("AMCREST_WEBRTC_HOST"),
'port': int(os.getenv("AMCREST_WEBRTC_PORT") or 1984),
'link': os.getenv("AMCREST_WEBRTC_LINK") or 'stream.html',
'sources': os.getenv("AMCREST_WEBRTC_SOURCES"),
},
},
'debug': True if os.getenv('DEBUG') else False,
'hide_ts': True if os.getenv('HIDE_TS') else False,
'timezone': os.getenv('TZ'),
'config_from': 'env',
}
config['version'] = version
config['configpath'] = os.path.dirname(configpath)
# defaults
if 'username' not in config['mqtt']: config['mqtt']['username'] = ''
if 'password' not in config['mqtt']: config['mqtt']['password'] = ''
if 'qos' not in config['mqtt']: config['mqtt']['qos'] = 0
if 'timezone' not in config: config['timezone'] = 'UTC'
if 'debug' not in config: config['debug'] = os.getenv('DEBUG') or False
if 'hide_ts' not in config: config['hide_ts'] = os.getenv('HIDE_TS') or False
# init logging, based on config settings
logging.basicConfig(
format = '%(asctime)s.%(msecs)03d [%(levelname)s] %(name)s: %(message)s' if config['hide_ts'] == False else '[%(levelname)s] %(name)s: %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
level=logging.INFO if config['debug'] == False else logging.DEBUG
)
logger = logging.getLogger(__name__)
logger.info(f'Starting: amcrest2mqtt {version}')
logger.info(f'Config loaded from {config["config_from"]}')
# Check for required config properties
if config['amcrest']['hosts'] is None:
logger.error('Missing env var: AMCREST_HOSTS or amcrest.hosts in config')
exit(1)
config['amcrest']['host_count'] = len(config['amcrest']['hosts'])
if config['amcrest']['names'] is None:
logger.error('Missing env var: AMCREST_NAMES or amcrest.names in config')
exit(1)
config['amcrest']['name_count'] = len(config['amcrest']['names'])
if config['amcrest']['host_count'] != config['amcrest']['name_count']:
logger.error('The AMCREST_HOSTS and AMCREST_NAMES must have the same number of space-delimited hosts/names')
exit(1)
logger.info(f'Found {config["amcrest"]["host_count"]} host(s) defined to monitor')
if 'webrtc' in config['amcrest']:
webrtc = config['amcrest']['webrtc']
if 'host' not in webrtc:
logger.error('Missing HOST in webrtc config')
exit(1)
if 'sources' not in webrtc:
logger.error('Missing SOURCES in webrtc config')
exit(1)
config['amcrest']['webrtc_sources_count'] = len(config['amcrest']['webrtc']['sources'])
if config['amcrest']['host_count'] != config['amcrest']['webrtc_sources_count']:
logger.error('The AMCREST_HOSTS and AMCREST_WEBRTC_SOURCES must have the same number of space-delimited hosts/names')
exit(1)
if 'port' not in webrtc: webrtc['port'] = 1984
if 'link' not in webrtc: webrtc['link'] = 'stream.html'
if config['amcrest']['password'] is None:
logger.error('Please set the AMCREST_PASSWORD environment variable')
exit(1)
logger.debug("DEBUG logging is ON")
# Go!
with AmcrestMqtt(config) as mqtt:
asyncio.run(mqtt.main_loop()) asyncio.run(mqtt.main_loop())
except RuntimeError as e:
if "asyncio.run() cannot be called from a running event loop" in str(e):
loop = asyncio.get_event_loop()
loop.run_until_complete(mqtt.main_loop())
else:
raise
except KeyboardInterrupt:
logger.info("Shutdown requested (Ctrl+C). Exiting gracefully...")
except asyncio.CancelledError:
logger.warning("Main loop cancelled.")
except Exception as e:
logger.exception(f"Unhandled exception in main loop: {e}")
finally:
try:
if "mqtt" in locals() and hasattr(mqtt, "api") and mqtt.api:
mqtt.api.shutdown()
except Exception as e:
logger.debug(f"Error during shutdown: {e}")
logger.info("amcrest2mqtt stopped.")

@ -1,44 +1,121 @@
# This software is licensed under the MIT License, which allows you to use, import logging
# copy, modify, merge, publish, distribute, and sell copies of the software,
# with the requirement to include the original copyright notice and this
# permission notice in all copies or substantial portions of the software.
#
# The software is provided 'as is', without any warranty.
import ipaddress
import os import os
import socket import socket
import yaml
# Helper functions and callbacks def get_ip_address(hostname: str) -> str:
def read_file(file_name): """
with open(file_name, 'r') as file: Resolve a hostname to an IP address (IPv4 or IPv6).
data = file.read().replace('\n', '')
return data Returns:
str: The resolved IP address, or the original hostname if resolution fails.
"""
if not hostname:
return hostname
def read_version(): try:
if os.path.isfile('./VERSION'): # Try both IPv4 and IPv6 (AF_UNSPEC)
return read_file('./VERSION') infos = socket.getaddrinfo(
hostname, None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM
)
# Prefer IPv4 addresses if available
for family, _, _, _, sockaddr in infos:
if family == socket.AF_INET:
return sockaddr[0]
# Otherwise, fallback to first valid IPv6
return infos[0][4][0] if infos else hostname
except socket.gaierror as e:
logging.debug(f"DNS lookup failed for {hostname}: {e}")
except Exception as e:
logging.debug(f"Unexpected error resolving {hostname}: {e}")
return read_file('../VERSION') return hostname
def to_gb(total):
return str(round(float(total[0]) / 1024 / 1024 / 1024, 2))
def is_ipv4(string): def to_gb(bytes_value):
"""Convert bytes to a rounded string in gigabytes."""
try: try:
ipaddress.IPv4Network(string) gb = float(bytes_value) / (1024**3)
return True return f"{gb:.2f} GB"
except ValueError: except Exception:
return False return "0.00 GB"
def get_ip_address(string): def read_file(file_name, strip_newlines=True, default=None, encoding="utf-8"):
if is_ipv4(string):
return string
try: try:
for i in socket.getaddrinfo(string, None): with open(file_name, "r", encoding=encoding) as f:
if i[0] == socket.AddressFamily.AF_INET: data = f.read()
return i[4][0] return data.replace("\n", "") if strip_newlines else data
except socket.gaierror as e: except FileNotFoundError:
raise Exception(f"Failed to resolve {string}: {e}") if default is not None:
raise Exception(f"Failed to find IP address for {string}") return default
raise
def read_version():
base_dir = os.path.dirname(os.path.abspath(__file__))
version_path = os.path.join(base_dir, "VERSION")
try:
with open(version_path, "r") as f:
return f.read().strip() or "unknown"
except FileNotFoundError:
env_version = os.getenv("APP_VERSION")
return env_version.strip() if env_version else "unknown"
def load_config(path=None):
"""Load and normalize configuration from YAML file or directory."""
logger = logging.getLogger(__name__)
default_path = "/config/config.yaml"
# Resolve config path
config_path = path or default_path
if os.path.isdir(config_path):
config_path = os.path.join(config_path, "config.yaml")
if not os.path.exists(config_path):
raise FileNotFoundError(f"Config file not found: {config_path}")
with open(config_path, "r") as f:
config = yaml.safe_load(f) or {}
# --- Normalization helpers ------------------------------------------------
def to_bool(value):
"""Coerce common truthy/falsey forms to proper bool."""
if isinstance(value, bool):
return value
if isinstance(value, (int, float)):
return bool(value)
if value is None:
return False
return str(value).strip().lower() in ("true", "1", "yes", "on")
def normalize_section(section, bool_keys):
"""Normalize booleans within a nested section."""
if not isinstance(config.get(section), dict):
return
for key in bool_keys:
if key in config[section]:
config[section][key] = to_bool(config[section][key])
# --- Global defaults + normalization -------------------------------------
config.setdefault("version", "1.0.0")
config.setdefault("debug", False)
config.setdefault("hide_ts", False)
config.setdefault("timezone", "UTC")
# normalize top-level flags
config["debug"] = to_bool(config.get("debug"))
config["hide_ts"] = to_bool(config.get("hide_ts"))
# Example: normalize booleans within sections
normalize_section("mqtt", ["tls", "retain", "clean_session"])
normalize_section("amcrest", ["webrtc", "verify_ssl"])
normalize_section("service", ["enabled", "auto_restart"])
# Add metadata for debugging/logging
config["config_path"] = os.path.abspath(config_path)
config["config_from"] = "file" if os.path.exists(config_path) else "defaults"
return config

Loading…
Cancel
Save