From f025d60f75913361a13047f886ee730a9a0579df Mon Sep 17 00:00:00 2001 From: Jeff Culverhouse Date: Fri, 10 Oct 2025 03:40:50 -0400 Subject: [PATCH] feat(core): add async process pool, graceful signal handling, and safer config loading - Added multiprocessing-based connection pooling for faster camera setup - Implemented robust signal handling (SIGINT/SIGTERM) for graceful shutdown - Upgraded MQTT client to v5 with improved logging and reconnect behavior - Added defensive WebRTC config validation to avoid runtime errors - Introduced unified config loader (`util.load_config`) with boolean coercion - Improved IP resolution and logging consistency across modules --- amcrest_api.py | 286 +++++++++++++++++++++++++++++++++--------------- amcrest_mqtt.py | 171 ++++++++++++++++++++++------- app.py | 193 ++++++++++---------------------- util.py | 143 ++++++++++++++++++------ 4 files changed, 501 insertions(+), 292 deletions(-) diff --git a/amcrest_api.py b/amcrest_api.py index 86e51f6..745004a 100644 --- a/amcrest_api.py +++ b/amcrest_api.py @@ -5,23 +5,20 @@ # # The software is provided 'as is', without any warranty. -from amcrest import AmcrestCamera, AmcrestError, CommError, LoginError, exceptions +from amcrest import AmcrestCamera, AmcrestError, CommError, LoginError import asyncio -from asyncio import timeout +from concurrent.futures import ProcessPoolExecutor import base64 -from datetime import datetime -import httpx import logging -import os -import time -from util import * -from zoneinfo import ZoneInfo +import signal +from util import get_ip_address, to_gb -class AmcrestAPI(object): + +class AmcrestAPI: def __init__(self, config): self.logger = logging.getLogger(__name__) - # we don't want to get this mess of deeper-level logging + # Quiet down noisy loggers logging.getLogger("httpx").setLevel(logging.WARNING) logging.getLogger("httpcore.http11").setLevel(logging.WARNING) logging.getLogger("httpcore.connection").setLevel(logging.WARNING) @@ -29,96 +26,149 @@ class AmcrestAPI(object): logging.getLogger("amcrest.event").setLevel(logging.WARNING) logging.getLogger("urllib3.connectionpool").setLevel(logging.WARNING) - self.last_call_date = '' - self.timezone = config['timezone'] - - self.amcrest_config = config['amcrest'] - - self.count = len(self.amcrest_config['hosts']) + self.timezone = config["timezone"] + self.amcrest_config = config["amcrest"] self.devices = {} self.events = [] + self.executor = ProcessPoolExecutor( + max_workers=min(8, len(self.amcrest_config["hosts"])) + ) + # handle signals gracefully + signal.signal(signal.SIGINT, self._sig_handler) + signal.signal(signal.SIGTERM, self._sig_handler) + self._shutting_down = False + + def _sig_handler(self, signum, frame): + if not self._shutting_down: + self._shutting_down = True + self.logger.warning("SIGINT received — shutting down process pool...") + self.shutdown() + + def shutdown(self): + """Instantly kill process pool workers.""" + try: + if self.executor: + self.logger.debug("Force-terminating process pool workers...") + self.executor.shutdown(wait=False, cancel_futures=True) + self.executor = None + except Exception as e: + self.logger.warning(f"Error shutting down process pool: {e}") + + def get_camera(self, host): + cfg = self.amcrest_config + return AmcrestCamera( + host, cfg["port"], cfg["username"], cfg["password"], verbose=False + ).camera + + # ---------------------------------------------------------------------------------------------- + async def connect_to_devices(self): - self.logger.info(f'Connecting to: {self.amcrest_config["hosts"]}') + self.logger.info(f"Connecting to: {self.amcrest_config['hosts']}") - tasks = [] - for host in self.amcrest_config['hosts']: - device_name = self.amcrest_config['names'].pop(0) - task = asyncio.create_task(self.get_device(host, device_name)) - tasks.append(task) - await asyncio.gather(*tasks, return_exceptions=True) + # Defensive guard against shutdown signals + if getattr(self, "shutting_down", False): + self.logger.warning("Connect aborted: shutdown already in progress.") + return {} - if len(self.devices) == 0: - self.logger.error('Failed to connect to all devices, exiting') + tasks = [ + asyncio.create_task(self._connect_device_threaded(host, name)) + for host, name in zip( + self.amcrest_config["hosts"], self.amcrest_config["names"] + ) + ] + + try: + results = await asyncio.gather(*tasks, return_exceptions=True) + except asyncio.CancelledError: + self.logger.warning("Device connection cancelled by signal.") + return {} + except Exception as e: + self.logger.error(f"Device connection failed: {e}", exc_info=True) + return {} + + successes = [r for r in results if isinstance(r, dict) and "config" in r] + failures = [r for r in results if isinstance(r, dict) and "error" in r] + + self.logger.info( + f"Device connection summary: {len(successes)} succeeded, {len(failures)} failed" + ) + + if not successes and not getattr(self, "shutting_down", False): + self.logger.error("Failed to connect to any devices, exiting") exit(1) - # return just the config of each device, not the camera object - return {d: self.devices[d]['config'] for d in self.devices.keys()} + # Recreate cameras in parent process + for info in successes: + cfg = info["config"] + serial = cfg["serial_number"] + cam = self.get_camera(cfg["host_ip"]) + self.devices[serial] = {"camera": cam, "config": cfg} - def get_camera(self, host): - config = self.amcrest_config - return AmcrestCamera(host, config['port'], config['username'], config['password'], verbose=False).camera + return {d: self.devices[d]["config"] for d in self.devices.keys()} + + async def _connect_device_threaded(self, host, device_name): + """Run the blocking camera connection logic in a separate process.""" + loop = asyncio.get_running_loop() + return await loop.run_in_executor( + self.executor, _connect_device_worker, (host, device_name, self.amcrest_config) + ) - async def get_device(self, host, device_name): + def _connect_device_sync(self, host, device_name): + """Blocking version of connect logic that runs in a separate process.""" try: - # resolve host and setup camera by ip so we aren't making 100k DNS lookups per day - try: - host_ip = get_ip_address(host) - self.logger.info(f'nslookup {host} got us {host_ip}') - camera = self.get_camera(host_ip) - except Exception as err: - self.logger.error(f'Error with {host}: {err}') - - device_type = camera.device_type.replace('type=', '').strip() - is_ad110 = device_type == 'AD110' - is_ad410 = device_type == 'AD410' + import multiprocessing + p_name = multiprocessing.current_process().name + + host_ip = get_ip_address(host) + camera = self.get_camera(host_ip) + + device_type = camera.device_type.replace("type=", "").strip() + is_ad110 = device_type == "AD110" + is_ad410 = device_type == "AD410" is_doorbell = is_ad110 or is_ad410 serial_number = camera.serial_number - if not isinstance(serial_number, str): - self.logger.error(f'Error fetching serial number for {host}: {camera.serial_number}') - exit(1) - - version = camera.software_information[0].replace('version=', '').strip() + version = camera.software_information[0].replace("version=", "").strip() build = camera.software_information[1].strip() - sw_version = f'{version} ({build})' - - network_config = dict(item.split('=') for item in camera.network_config.splitlines()) - interface = network_config['table.Network.DefaultInterface'] - ip_address = network_config[f'table.Network.{interface}.IPAddress'] - mac_address = network_config[f'table.Network.{interface}.PhysicalAddress'].upper() - - action = 'Connected' if camera.serial_number not in self.devices else 'Reconnected' - self.logger.info(f'{action} to {host} as {camera.serial_number}') - - self.devices[serial_number] = { - 'camera': camera, - 'config': { - 'host': host, - 'host_ip': host_ip, - 'device_name': device_name, - 'device_type': device_type, - 'device_class': camera.device_class, - 'is_ad110': is_ad110, - 'is_ad410': is_ad410, - 'is_doorbell': is_doorbell, - 'serial_number': serial_number, - 'software_version': sw_version, - 'hardware_version': camera.hardware_version, - 'vendor': camera.vendor_information, - 'network': { - 'interface': interface, - 'ip_address': ip_address, - 'mac': mac_address, - } - }, + sw_version = f"{version} ({build})" + + network_config = dict( + item.split("=") for item in camera.network_config.splitlines() + ) + iface = network_config["table.Network.DefaultInterface"] + ip_address = network_config[f"table.Network.{iface}.IPAddress"] + mac_address = network_config[f"table.Network.{iface}.PhysicalAddress"].upper() + + print(f"[{p_name}] Connected to {host} ({ip_address}) as {serial_number}") + + return { + "config": { + "host": host, + "host_ip": host_ip, + "device_name": device_name, + "device_type": device_type, + "device_class": camera.device_class, + "is_ad110": is_ad110, + "is_ad410": is_ad410, + "is_doorbell": is_doorbell, + "serial_number": serial_number, + "software_version": sw_version, + "hardware_version": camera.hardware_version, + "vendor": camera.vendor_information, + "network": { + "interface": iface, + "ip_address": ip_address, + "mac": mac_address, + }, + } } - self.get_privacy_mode(serial_number) - except LoginError as err: - self.logger.error(f'Invalid username/password to connect to device "{host}", fix in config.yaml') - except AmcrestError as err: - self.logger.error(f'Failed to connect to device "{host}", check config.yaml and restart to try again: {err}') + except Exception as e: + import traceback + err_trace = traceback.format_exc() + print(f"[child] Error connecting to {host}: {e}\n{err_trace}") + return {"error": f"{e}", "host": host} # Storage stats ------------------------------------------------------------------------------- @@ -152,7 +202,6 @@ class AmcrestAPI(object): return privacy_mode - def set_privacy_mode(self, device_id, switch): device = self.devices[device_id] @@ -245,7 +294,6 @@ class AmcrestAPI(object): if tries == 3: self.logger.error(f'Failed to communicate with device ({device_id}) to get recorded file') - # Events -------------------------------------------------------------------------------------- async def collect_all_device_events(self): @@ -320,4 +368,70 @@ class AmcrestAPI(object): self.logger.error(f'Failed to process event from {device_id}: {err}', exc_info=True) def get_next_event(self): - return self.events.pop(0) if len(self.events) > 0 else None \ No newline at end of file + return self.events.pop(0) if len(self.events) > 0 else None + + +def _connect_device_worker(args): + """Top-level helper so it can be pickled by ProcessPoolExecutor.""" + + signal.signal(signal.SIGINT, signal.SIG_IGN) + + host, device_name, amcrest_cfg = args + + try: + host_ip = get_ip_address(host) + camera = AmcrestCamera( + host_ip, + amcrest_cfg["port"], + amcrest_cfg["username"], + amcrest_cfg["password"], + verbose=False, + ).camera + + device_type = camera.device_type.replace("type=", "").strip() + is_ad110 = device_type == "AD110" + is_ad410 = device_type == "AD410" + is_doorbell = is_ad110 or is_ad410 + + serial_number = camera.serial_number + version = camera.software_information[0].replace("version=", "").strip() + build = camera.software_information[1].strip() + sw_version = f"{version} ({build})" + + network_config = dict( + item.split("=") for item in camera.network_config.splitlines() + ) + iface = network_config["table.Network.DefaultInterface"] + ip_address = network_config[f"table.Network.{iface}.IPAddress"] + mac_address = network_config[f"table.Network.{iface}.PhysicalAddress"].upper() + + print(f"[worker] Connected to {host} ({ip_address}) as {serial_number}") + + return { + "config": { + "host": host, + "host_ip": host_ip, + "device_name": device_name, + "device_type": device_type, + "device_class": camera.device_class, + "is_ad110": is_ad110, + "is_ad410": is_ad410, + "is_doorbell": is_doorbell, + "serial_number": serial_number, + "software_version": sw_version, + "hardware_version": camera.hardware_version, + "vendor": camera.vendor_information, + "network": { + "interface": iface, + "ip_address": ip_address, + "mac": mac_address, + }, + } + } + + except Exception as e: + import traceback + + err_trace = traceback.format_exc() + print(f"[worker] Error connecting to {host}: {e}\n{err_trace}") + return {"error": str(e), "host": host} diff --git a/amcrest_mqtt.py b/amcrest_mqtt.py index fff7cd1..9410290 100644 --- a/amcrest_mqtt.py +++ b/amcrest_mqtt.py @@ -63,18 +63,19 @@ class AmcrestMqtt(object): # MQTT Functions ------------------------------------------------------------------------------ - def mqtt_on_connect(self, client, userdata, flags, rc, properties): - if rc != 0: - self.logger.error(f'MQTT connection issue ({rc})') - exit() + def mqtt_on_connect(self, client, userdata, flags, reason_code, properties): + if reason_code.value != 0: + self.logger.error(f'MQTT connection issue ({reason_code.getName()})') + self.running = False + return self.logger.info(f'MQTT connected as {self.client_id}') client.subscribe("homeassistant/status") client.subscribe(self.get_device_sub_topic()) client.subscribe(self.get_attribute_sub_topic()) - def mqtt_on_disconnect(self, client, userdata, flags, rc, properties): - self.logger.info('MQTT connection closed') + def mqtt_on_disconnect(self, client, userdata, disconnect_flags, reason_code, properties): + self.logger.warning(f'MQTT disconnected: {reason_code.getName()} (flags={disconnect_flags})') self.mqttc.loop_stop() if self.running and time.time() > self.mqtt_connect_time + 10: @@ -88,7 +89,6 @@ class AmcrestMqtt(object): self.paused = False else: self.running = False - exit() def mqtt_on_log(self, client, userdata, paho_log_level, msg): if paho_log_level == mqtt.LogLevel.MQTT_LOG_ERR: @@ -169,10 +169,10 @@ class AmcrestMqtt(object): def mqttc_create(self): self.mqttc = mqtt.Client( - callback_api_version=mqtt.CallbackAPIVersion.VERSION2, client_id=self.client_id, - clean_session=False, + callback_api_version=mqtt.CallbackAPIVersion.VERSION2, reconnect_on_failure=False, + protocol=mqtt.MQTTv5, ) if self.mqtt_config.get('tls_enabled'): @@ -183,7 +183,8 @@ class AmcrestMqtt(object): cert_reqs=ssl.CERT_REQUIRED, tls_version=ssl.PROTOCOL_TLS, ) - else: + self.mqttc.tls_insecure_set(self.mqtt_config.get("tls_insecure", False)) + if self.mqtt_config.get('username'): self.mqttc.username_pw_set( username=self.mqtt_config.get('username'), password=self.mqtt_config.get('password'), @@ -199,16 +200,24 @@ class AmcrestMqtt(object): self.mqttc.will_set(self.get_discovery_topic('service', 'availability'), payload="offline", qos=self.mqtt_config['qos'], retain=True) try: + self.logger.info( + f"Connecting to MQTT broker at {self.mqtt_config.get('host')}:{self.mqtt_config.get('port')} " + f"as {self.client_id}" + ) self.mqttc.connect( - self.mqtt_config.get('host'), + host=self.mqtt_config.get('host'), port=self.mqtt_config.get('port'), keepalive=60, ) self.mqtt_connect_time = time.time() self.mqttc.loop_start() - except ConnectionError as error: - self.logger.error(f'COULD NOT CONNECT TO MQTT {self.mqtt_config.get("host")}: {error}') - exit(1) + except Exception as error: + self.logger.error( + f"Failed to connect to MQTT broker {self.mqtt_config.get('host')}:{self.mqtt_config.get('port')} " + f"({type(error).__name__}: {error})", + exc_info=True, + ) + self.running = False # MQTT Topics --------------------------------------------------------------------------------- @@ -261,6 +270,18 @@ class AmcrestMqtt(object): return f"{self.mqtt_config['prefix']}/{self.get_component_slug(device_id)}/{topic}/{subtopic}" return f"{self.mqtt_config['discovery_prefix']}/device/{self.get_component_slug(device_id)}/{topic}/{subtopic}" + def ha_cfg_topic(self, domain: str, object_id: str) -> str: + dp = self.mqtt_config.get('discovery_prefix', 'homeassistant') + return f"{dp}/{domain}/{object_id}/config" + + def service_oid(self, suffix: str) -> str: + return f"{self.service_slug}_{suffix}" + + def svc_topic(self, sub: str) -> str: + # Runtime topics (your own prefix), e.g. govee2mqtt/govee-service/state + pfx = self.mqtt_config.get('prefix', 'govee2mqtt') + return f"{pfx}/amcrest-service/{sub}" + # Service Device ------------------------------------------------------------------------------ def publish_service_state(self): @@ -474,14 +495,33 @@ class AmcrestMqtt(object): 'value_template': '{{ value_json.state }}', 'unique_id': self.get_slug(device_id, 'snapshot_camera'), } - if 'webrtc' in self.amcrest_config: - webrtc_config = self.amcrest_config['webrtc'] - rtc_host = webrtc_config['host'] - rtc_port = webrtc_config['port'] - rtc_link = webrtc_config['link'] - rtc_source = webrtc_config['sources'].pop(0) - rtc_url = f'http://{rtc_host}:{rtc_port}/{rtc_link}?src={rtc_source}' - device_config['device']['configuration_url'] = rtc_url + # --- Safe WebRTC config handling ---------------------------------------- + webrtc_config = self.amcrest_config.get("webrtc") + + # Handle missing, boolean, or incomplete configs gracefully + if isinstance(webrtc_config, bool) or not webrtc_config: + self.logger.debug("No valid WebRTC config found; skipping WebRTC setup.") + else: + try: + rtc_host = webrtc_config.get("host") + rtc_port = webrtc_config.get("port") + rtc_link = webrtc_config.get("link") + rtc_sources = webrtc_config.get("sources", []) + rtc_source = rtc_sources[0] if rtc_sources else None + + if rtc_host and rtc_port and rtc_link and rtc_source: + rtc_url = f"http://{rtc_host}:{rtc_port}/{rtc_link}?src={rtc_source}" + device_config["device"]["configuration_url"] = rtc_url + self.logger.debug(f"Added WebRTC config URL for {device_id}: {rtc_url}") + else: + self.logger.warning( + f"Incomplete WebRTC config for {device_id}: {webrtc_config}" + ) + + except Exception as e: + self.logger.warning( + f"Failed to apply WebRTC config for {device_id}: {e}", exc_info=True + ) # copy the snapshot camera for the eventshot camera, with a couple of changes components[self.get_slug(device_id, 'event_camera')] = { @@ -643,7 +683,7 @@ class AmcrestMqtt(object): self.mqttc.publish(self.get_discovery_topic(device_id, 'config'), payload, qos=self.mqtt_config['qos'], retain=True) - # refresh * all devices ----------------------------------------------------------------------- + # refresh * all devices ----------------------------------------------------------------------- def refresh_storage_all_devices(self): self.logger.info(f'Refreshing storage info for all devices (every {self.storage_update_interval} sec)') @@ -744,9 +784,9 @@ class AmcrestMqtt(object): def handle_service_message(self, attribute, message): match attribute: - case 'storage_refresh': + case "storage_refresh": self.storage_update_interval = message - self.logger.info(f'Updated STORAGE_REFRESH_INTERVAL to be {message}') + self.logger.info(f"Updated STORAGE_REFRESH_INTERVAL to be {message}") case 'snapshot_refresh': self.snapshot_update_interval = message self.logger.info(f'Updated SNAPSHOT_REFRESH_INTERVAL to be {message}') @@ -839,30 +879,85 @@ class AmcrestMqtt(object): # main loop async def main_loop(self): + """Main event loop for Amcrest MQTT service.""" await self.setup_devices() loop = asyncio.get_running_loop() + + # Create async tasks with descriptive names tasks = [ - asyncio.create_task(self.collect_storage_info()), - asyncio.create_task(self.collect_events()), - asyncio.create_task(self.check_event_queue()), - asyncio.create_task(self.collect_snapshots()), + asyncio.create_task( + self.collect_storage_info(), name="collect_storage_info" + ), + asyncio.create_task(self.collect_events(), name="collect_events"), + asyncio.create_task(self.check_event_queue(), name="check_event_queue"), + asyncio.create_task(self.collect_snapshots(), name="collect_snapshots"), ] - # setup signal handling for tasks + # Graceful signal handler + def _signal_handler(signame): + """Immediate, aggressive shutdown handler for Ctrl+C or SIGTERM.""" + self.logger.warning(f"{signame} received — initiating shutdown NOW...") + + self.running = False + + # Cancel *all* asyncio tasks, even those not tracked manually + loop = asyncio.get_event_loop() + for task in asyncio.all_tasks(loop): + if not task.done(): + task.cancel(f"{signame} received") + + # Force-stop ProcessPoolExecutor if present + try: + if hasattr(self, "api") and hasattr(self.api, "executor"): + self.logger.debug("Force-shutting down process pool...") + self.api.executor.shutdown(wait=False, cancel_futures=True) + except Exception as e: + self.logger.debug(f"Error force-stopping process pool: {e}") + + # Stop the loop immediately after a short delay + loop.call_later(0.05, loop.stop) + for sig in (signal.SIGINT, signal.SIGTERM): - loop.add_signal_handler( - sig, lambda: asyncio.create_task(self._handle_signals(sig.name, loop)) - ) + try: + loop.add_signal_handler(sig, _signal_handler, sig.name) + except NotImplementedError: + # Windows compatibility + self.logger.debug(f"Signal handling not supported on this platform.") try: results = await asyncio.gather(*tasks, return_exceptions=True) - for result in results: - if isinstance(result, Exception): + + # Handle task exceptions individually + for t, result in zip(tasks, results): + if isinstance(result, asyncio.CancelledError): + self.logger.info(f"Task '{t.get_name()}' cancelled.") + elif isinstance(result, Exception): + self.logger.error( + f"Task '{t.get_name()}' raised an exception: {result}", + exc_info=True, + ) self.running = False - self.logger.error(f'Caught exception: {err}', exc_info=True) except asyncio.CancelledError: - exit(1) + self.logger.info("Main loop cancelled; shutting down...") except Exception as err: + self.logger.exception(f"Unhandled exception in main loop: {err}") self.running = False - self.logger.error(f'Caught exception: {err}') \ No newline at end of file + finally: + self.logger.info("All loops terminated, performing final cleanup...") + + try: + # Save final state or cleanup hooks if needed + if hasattr(self, "save_state"): + self.save_state() + except Exception as e: + self.logger.warning(f"Error during save_state: {e}") + + # Disconnect MQTT cleanly + if self.mqttc and self.mqttc.is_connected(): + try: + self.mqttc.disconnect() + except Exception as e: + self.logger.warning(f"Error during MQTT disconnect: {e}") + + self.logger.info("Main loop complete.") diff --git a/app.py b/app.py index 1ab2deb..54caa10 100644 --- a/app.py +++ b/app.py @@ -1,138 +1,61 @@ -# This software is licensed under the MIT License, which allows you to use, -# copy, modify, merge, publish, distribute, and sell copies of the software, -# with the requirement to include the original copyright notice and this -# permission notice in all copies or substantial portions of the software. -# -# The software is provided 'as is', without any warranty. - +#!/usr/bin/env python3 import asyncio import argparse -from amcrest_mqtt import AmcrestMqtt import logging -import os -import sys -import time -from util import * -import yaml - -# Let's go! -version = read_version() - -# Cmd-line args -argparser = argparse.ArgumentParser() -argparser.add_argument( - '-c', - '--config', - required=False, - help='Directory holding config.yaml or full path to config file', -) -args = argparser.parse_args() - -# Setup config from yaml file or env -configpath = args.config or '/config' -try: - if not configpath.endswith('.yaml'): - if not configpath.endswith('/'): - configpath += '/' - configfile = configpath + 'config.yaml' - with open(configfile) as file: - config = yaml.safe_load(file) - config['config_path'] = configpath - config['config_from'] = 'file' -except: - config = { - 'mqtt': { - 'host': os.getenv('MQTT_HOST') or 'localhost', - 'qos': int(os.getenv('MQTT_QOS') or 0), - 'port': int(os.getenv('MQTT_PORT') or 1883), - 'username': os.getenv('MQTT_USERNAME'), - 'password': os.getenv('MQTT_PASSWORD'), # can be None - 'tls_enabled': os.getenv('MQTT_TLS_ENABLED') == 'true', - 'tls_ca_cert': os.getenv('MQTT_TLS_CA_CERT'), - 'tls_cert': os.getenv('MQTT_TLS_CERT'), - 'tls_key': os.getenv('MQTT_TLS_KEY'), - 'prefix': os.getenv('MQTT_PREFIX') or 'amcrest2mqtt', - 'homeassistant': os.getenv('MQTT_HOMEASSISTANT') == True, - 'discovery_prefix': os.getenv('MQTT_DISCOVERY_PREFIX') or 'homeassistant', - }, - 'amcrest': { - 'hosts': os.getenv("AMCREST_HOSTS"), - 'names': os.getenv("AMCREST_NAMES"), - 'port': int(os.getenv("AMCREST_PORT") or 80), - 'username': os.getenv("AMCREST_USERNAME") or "admin", - 'password': os.getenv("AMCREST_PASSWORD"), - 'storage_update_interval': int(os.getenv("STORAGE_UPDATE_INTERVAL") or 900), - 'snapshot_update_interval': int(os.getenv("SNAPSHOT_UPDATE_INTERVAL") or 300), - 'webrtc': { - 'host': os.getenv("AMCREST_WEBRTC_HOST"), - 'port': int(os.getenv("AMCREST_WEBRTC_PORT") or 1984), - 'link': os.getenv("AMCREST_WEBRTC_LINK") or 'stream.html', - 'sources': os.getenv("AMCREST_WEBRTC_SOURCES"), - }, - }, - 'debug': True if os.getenv('DEBUG') else False, - 'hide_ts': True if os.getenv('HIDE_TS') else False, - 'timezone': os.getenv('TZ'), - 'config_from': 'env', - } -config['version'] = version -config['configpath'] = os.path.dirname(configpath) - -# defaults -if 'username' not in config['mqtt']: config['mqtt']['username'] = '' -if 'password' not in config['mqtt']: config['mqtt']['password'] = '' -if 'qos' not in config['mqtt']: config['mqtt']['qos'] = 0 -if 'timezone' not in config: config['timezone'] = 'UTC' -if 'debug' not in config: config['debug'] = os.getenv('DEBUG') or False -if 'hide_ts' not in config: config['hide_ts'] = os.getenv('HIDE_TS') or False - -# init logging, based on config settings -logging.basicConfig( - format = '%(asctime)s.%(msecs)03d [%(levelname)s] %(name)s: %(message)s' if config['hide_ts'] == False else '[%(levelname)s] %(name)s: %(message)s', - datefmt='%Y-%m-%d %H:%M:%S', - level=logging.INFO if config['debug'] == False else logging.DEBUG -) -logger = logging.getLogger(__name__) -logger.info(f'Starting: amcrest2mqtt {version}') -logger.info(f'Config loaded from {config["config_from"]}') - -# Check for required config properties -if config['amcrest']['hosts'] is None: - logger.error('Missing env var: AMCREST_HOSTS or amcrest.hosts in config') - exit(1) -config['amcrest']['host_count'] = len(config['amcrest']['hosts']) - -if config['amcrest']['names'] is None: - logger.error('Missing env var: AMCREST_NAMES or amcrest.names in config') - exit(1) -config['amcrest']['name_count'] = len(config['amcrest']['names']) - -if config['amcrest']['host_count'] != config['amcrest']['name_count']: - logger.error('The AMCREST_HOSTS and AMCREST_NAMES must have the same number of space-delimited hosts/names') - exit(1) -logger.info(f'Found {config["amcrest"]["host_count"]} host(s) defined to monitor') - -if 'webrtc' in config['amcrest']: - webrtc = config['amcrest']['webrtc'] - if 'host' not in webrtc: - logger.error('Missing HOST in webrtc config') - exit(1) - if 'sources' not in webrtc: - logger.error('Missing SOURCES in webrtc config') - exit(1) - config['amcrest']['webrtc_sources_count'] = len(config['amcrest']['webrtc']['sources']) - if config['amcrest']['host_count'] != config['amcrest']['webrtc_sources_count']: - logger.error('The AMCREST_HOSTS and AMCREST_WEBRTC_SOURCES must have the same number of space-delimited hosts/names') - exit(1) - if 'port' not in webrtc: webrtc['port'] = 1984 - if 'link' not in webrtc: webrtc['link'] = 'stream.html' - -if config['amcrest']['password'] is None: - logger.error('Please set the AMCREST_PASSWORD environment variable') - exit(1) - -logger.debug("DEBUG logging is ON") - -# Go! -with AmcrestMqtt(config) as mqtt: - asyncio.run(mqtt.main_loop()) \ No newline at end of file +from amcrest_mqtt import AmcrestMqtt +from util import load_config + +if __name__ == "__main__": + # Parse command-line arguments + argparser = argparse.ArgumentParser() + argparser.add_argument( + "-c", + "--config", + required=False, + help="Directory or file path for config.yaml (defaults to /config/config.yaml)", + ) + args = argparser.parse_args() + + # Load configuration + config = load_config(args.config) + + # Setup logging + logging.basicConfig( + format=( + "%(asctime)s.%(msecs)03d [%(levelname)s] %(name)s: %(message)s" + if not config["hide_ts"] + else "[%(levelname)s] %(name)s: %(message)s" + ), + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.DEBUG if config["debug"] else logging.INFO, + ) + + logger = logging.getLogger(__name__) + logger.info(f"Starting amcrest2mqtt {config['version']}") + logger.info(f"Config loaded from {config['config_from']} ({config['config_path']})") + + # Run main loop safely + try: + with AmcrestMqtt(config) as mqtt: + try: + # Prefer a clean async run, but handle nested event loops + asyncio.run(mqtt.main_loop()) + except RuntimeError as e: + if "asyncio.run() cannot be called from a running event loop" in str(e): + loop = asyncio.get_event_loop() + loop.run_until_complete(mqtt.main_loop()) + else: + raise + except KeyboardInterrupt: + logger.info("Shutdown requested (Ctrl+C). Exiting gracefully...") + except asyncio.CancelledError: + logger.warning("Main loop cancelled.") + except Exception as e: + logger.exception(f"Unhandled exception in main loop: {e}") + finally: + try: + if "mqtt" in locals() and hasattr(mqtt, "api") and mqtt.api: + mqtt.api.shutdown() + except Exception as e: + logger.debug(f"Error during shutdown: {e}") + logger.info("amcrest2mqtt stopped.") \ No newline at end of file diff --git a/util.py b/util.py index b61f84b..d520781 100644 --- a/util.py +++ b/util.py @@ -1,44 +1,121 @@ -# This software is licensed under the MIT License, which allows you to use, -# copy, modify, merge, publish, distribute, and sell copies of the software, -# with the requirement to include the original copyright notice and this -# permission notice in all copies or substantial portions of the software. -# -# The software is provided 'as is', without any warranty. - -import ipaddress +import logging import os import socket +import yaml -# Helper functions and callbacks -def read_file(file_name): - with open(file_name, 'r') as file: - data = file.read().replace('\n', '') +def get_ip_address(hostname: str) -> str: + """ + Resolve a hostname to an IP address (IPv4 or IPv6). + + Returns: + str: The resolved IP address, or the original hostname if resolution fails. + """ + if not hostname: + return hostname + + try: + # Try both IPv4 and IPv6 (AF_UNSPEC) + infos = socket.getaddrinfo( + hostname, None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM + ) + # Prefer IPv4 addresses if available + for family, _, _, _, sockaddr in infos: + if family == socket.AF_INET: + return sockaddr[0] + # Otherwise, fallback to first valid IPv6 + return infos[0][4][0] if infos else hostname + except socket.gaierror as e: + logging.debug(f"DNS lookup failed for {hostname}: {e}") + except Exception as e: + logging.debug(f"Unexpected error resolving {hostname}: {e}") + + return hostname + + +def to_gb(bytes_value): + """Convert bytes to a rounded string in gigabytes.""" + try: + gb = float(bytes_value) / (1024**3) + return f"{gb:.2f} GB" + except Exception: + return "0.00 GB" + + +def read_file(file_name, strip_newlines=True, default=None, encoding="utf-8"): + try: + with open(file_name, "r", encoding=encoding) as f: + data = f.read() + return data.replace("\n", "") if strip_newlines else data + except FileNotFoundError: + if default is not None: + return default + raise - return data def read_version(): - if os.path.isfile('./VERSION'): - return read_file('./VERSION') + base_dir = os.path.dirname(os.path.abspath(__file__)) + version_path = os.path.join(base_dir, "VERSION") + try: + with open(version_path, "r") as f: + return f.read().strip() or "unknown" + except FileNotFoundError: + env_version = os.getenv("APP_VERSION") + return env_version.strip() if env_version else "unknown" - return read_file('../VERSION') -def to_gb(total): - return str(round(float(total[0]) / 1024 / 1024 / 1024, 2)) +def load_config(path=None): + """Load and normalize configuration from YAML file or directory.""" -def is_ipv4(string): - try: - ipaddress.IPv4Network(string) - return True - except ValueError: + logger = logging.getLogger(__name__) + default_path = "/config/config.yaml" + + # Resolve config path + config_path = path or default_path + if os.path.isdir(config_path): + config_path = os.path.join(config_path, "config.yaml") + + if not os.path.exists(config_path): + raise FileNotFoundError(f"Config file not found: {config_path}") + + with open(config_path, "r") as f: + config = yaml.safe_load(f) or {} + + # --- Normalization helpers ------------------------------------------------ + def to_bool(value): + """Coerce common truthy/falsey forms to proper bool.""" + if isinstance(value, bool): + return value + if isinstance(value, (int, float)): + return bool(value) + if value is None: return False + return str(value).strip().lower() in ("true", "1", "yes", "on") -def get_ip_address(string): - if is_ipv4(string): - return string - try: - for i in socket.getaddrinfo(string, None): - if i[0] == socket.AddressFamily.AF_INET: - return i[4][0] - except socket.gaierror as e: - raise Exception(f"Failed to resolve {string}: {e}") - raise Exception(f"Failed to find IP address for {string}") \ No newline at end of file + def normalize_section(section, bool_keys): + """Normalize booleans within a nested section.""" + if not isinstance(config.get(section), dict): + return + for key in bool_keys: + if key in config[section]: + config[section][key] = to_bool(config[section][key]) + + # --- Global defaults + normalization ------------------------------------- + config.setdefault("version", "1.0.0") + config.setdefault("debug", False) + config.setdefault("hide_ts", False) + config.setdefault("timezone", "UTC") + + # normalize top-level flags + config["debug"] = to_bool(config.get("debug")) + config["hide_ts"] = to_bool(config.get("hide_ts")) + + # Example: normalize booleans within sections + normalize_section("mqtt", ["tls", "retain", "clean_session"]) + normalize_section("amcrest", ["webrtc", "verify_ssl"]) + normalize_section("service", ["enabled", "auto_restart"]) + + # Add metadata for debugging/logging + config["config_path"] = os.path.abspath(config_path) + config["config_from"] = "file" if os.path.exists(config_path) else "defaults" + + return config