feat(core): add async process pool, graceful signal handling, and safer config loading

- Added multiprocessing-based connection pooling for faster camera setup
- Implemented robust signal handling (SIGINT/SIGTERM) for graceful shutdown
- Upgraded MQTT client to v5 with improved logging and reconnect behavior
- Added defensive WebRTC config validation to avoid runtime errors
- Introduced unified config loader (`util.load_config`) with boolean coercion
- Improved IP resolution and logging consistency across modules
pull/106/head
Jeff Culverhouse 4 months ago
parent a69ffc1667
commit f025d60f75

@ -5,23 +5,20 @@
#
# The software is provided 'as is', without any warranty.
from amcrest import AmcrestCamera, AmcrestError, CommError, LoginError, exceptions
from amcrest import AmcrestCamera, AmcrestError, CommError, LoginError
import asyncio
from asyncio import timeout
from concurrent.futures import ProcessPoolExecutor
import base64
from datetime import datetime
import httpx
import logging
import os
import time
from util import *
from zoneinfo import ZoneInfo
import signal
from util import get_ip_address, to_gb
class AmcrestAPI(object):
class AmcrestAPI:
def __init__(self, config):
self.logger = logging.getLogger(__name__)
# we don't want to get this mess of deeper-level logging
# Quiet down noisy loggers
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("httpcore.http11").setLevel(logging.WARNING)
logging.getLogger("httpcore.connection").setLevel(logging.WARNING)
@ -29,96 +26,149 @@ class AmcrestAPI(object):
logging.getLogger("amcrest.event").setLevel(logging.WARNING)
logging.getLogger("urllib3.connectionpool").setLevel(logging.WARNING)
self.last_call_date = ''
self.timezone = config['timezone']
self.amcrest_config = config['amcrest']
self.count = len(self.amcrest_config['hosts'])
self.timezone = config["timezone"]
self.amcrest_config = config["amcrest"]
self.devices = {}
self.events = []
self.executor = ProcessPoolExecutor(
max_workers=min(8, len(self.amcrest_config["hosts"]))
)
# handle signals gracefully
signal.signal(signal.SIGINT, self._sig_handler)
signal.signal(signal.SIGTERM, self._sig_handler)
self._shutting_down = False
def _sig_handler(self, signum, frame):
if not self._shutting_down:
self._shutting_down = True
self.logger.warning("SIGINT received — shutting down process pool...")
self.shutdown()
def shutdown(self):
"""Instantly kill process pool workers."""
try:
if self.executor:
self.logger.debug("Force-terminating process pool workers...")
self.executor.shutdown(wait=False, cancel_futures=True)
self.executor = None
except Exception as e:
self.logger.warning(f"Error shutting down process pool: {e}")
def get_camera(self, host):
cfg = self.amcrest_config
return AmcrestCamera(
host, cfg["port"], cfg["username"], cfg["password"], verbose=False
).camera
# ----------------------------------------------------------------------------------------------
async def connect_to_devices(self):
self.logger.info(f'Connecting to: {self.amcrest_config["hosts"]}')
self.logger.info(f"Connecting to: {self.amcrest_config['hosts']}")
tasks = []
for host in self.amcrest_config['hosts']:
device_name = self.amcrest_config['names'].pop(0)
task = asyncio.create_task(self.get_device(host, device_name))
tasks.append(task)
await asyncio.gather(*tasks, return_exceptions=True)
# Defensive guard against shutdown signals
if getattr(self, "shutting_down", False):
self.logger.warning("Connect aborted: shutdown already in progress.")
return {}
if len(self.devices) == 0:
self.logger.error('Failed to connect to all devices, exiting')
tasks = [
asyncio.create_task(self._connect_device_threaded(host, name))
for host, name in zip(
self.amcrest_config["hosts"], self.amcrest_config["names"]
)
]
try:
results = await asyncio.gather(*tasks, return_exceptions=True)
except asyncio.CancelledError:
self.logger.warning("Device connection cancelled by signal.")
return {}
except Exception as e:
self.logger.error(f"Device connection failed: {e}", exc_info=True)
return {}
successes = [r for r in results if isinstance(r, dict) and "config" in r]
failures = [r for r in results if isinstance(r, dict) and "error" in r]
self.logger.info(
f"Device connection summary: {len(successes)} succeeded, {len(failures)} failed"
)
if not successes and not getattr(self, "shutting_down", False):
self.logger.error("Failed to connect to any devices, exiting")
exit(1)
# return just the config of each device, not the camera object
return {d: self.devices[d]['config'] for d in self.devices.keys()}
# Recreate cameras in parent process
for info in successes:
cfg = info["config"]
serial = cfg["serial_number"]
cam = self.get_camera(cfg["host_ip"])
self.devices[serial] = {"camera": cam, "config": cfg}
def get_camera(self, host):
config = self.amcrest_config
return AmcrestCamera(host, config['port'], config['username'], config['password'], verbose=False).camera
return {d: self.devices[d]["config"] for d in self.devices.keys()}
async def _connect_device_threaded(self, host, device_name):
"""Run the blocking camera connection logic in a separate process."""
loop = asyncio.get_running_loop()
return await loop.run_in_executor(
self.executor, _connect_device_worker, (host, device_name, self.amcrest_config)
)
async def get_device(self, host, device_name):
def _connect_device_sync(self, host, device_name):
"""Blocking version of connect logic that runs in a separate process."""
try:
# resolve host and setup camera by ip so we aren't making 100k DNS lookups per day
try:
host_ip = get_ip_address(host)
self.logger.info(f'nslookup {host} got us {host_ip}')
camera = self.get_camera(host_ip)
except Exception as err:
self.logger.error(f'Error with {host}: {err}')
device_type = camera.device_type.replace('type=', '').strip()
is_ad110 = device_type == 'AD110'
is_ad410 = device_type == 'AD410'
import multiprocessing
p_name = multiprocessing.current_process().name
host_ip = get_ip_address(host)
camera = self.get_camera(host_ip)
device_type = camera.device_type.replace("type=", "").strip()
is_ad110 = device_type == "AD110"
is_ad410 = device_type == "AD410"
is_doorbell = is_ad110 or is_ad410
serial_number = camera.serial_number
if not isinstance(serial_number, str):
self.logger.error(f'Error fetching serial number for {host}: {camera.serial_number}')
exit(1)
version = camera.software_information[0].replace('version=', '').strip()
version = camera.software_information[0].replace("version=", "").strip()
build = camera.software_information[1].strip()
sw_version = f'{version} ({build})'
network_config = dict(item.split('=') for item in camera.network_config.splitlines())
interface = network_config['table.Network.DefaultInterface']
ip_address = network_config[f'table.Network.{interface}.IPAddress']
mac_address = network_config[f'table.Network.{interface}.PhysicalAddress'].upper()
action = 'Connected' if camera.serial_number not in self.devices else 'Reconnected'
self.logger.info(f'{action} to {host} as {camera.serial_number}')
self.devices[serial_number] = {
'camera': camera,
'config': {
'host': host,
'host_ip': host_ip,
'device_name': device_name,
'device_type': device_type,
'device_class': camera.device_class,
'is_ad110': is_ad110,
'is_ad410': is_ad410,
'is_doorbell': is_doorbell,
'serial_number': serial_number,
'software_version': sw_version,
'hardware_version': camera.hardware_version,
'vendor': camera.vendor_information,
'network': {
'interface': interface,
'ip_address': ip_address,
'mac': mac_address,
}
},
sw_version = f"{version} ({build})"
network_config = dict(
item.split("=") for item in camera.network_config.splitlines()
)
iface = network_config["table.Network.DefaultInterface"]
ip_address = network_config[f"table.Network.{iface}.IPAddress"]
mac_address = network_config[f"table.Network.{iface}.PhysicalAddress"].upper()
print(f"[{p_name}] Connected to {host} ({ip_address}) as {serial_number}")
return {
"config": {
"host": host,
"host_ip": host_ip,
"device_name": device_name,
"device_type": device_type,
"device_class": camera.device_class,
"is_ad110": is_ad110,
"is_ad410": is_ad410,
"is_doorbell": is_doorbell,
"serial_number": serial_number,
"software_version": sw_version,
"hardware_version": camera.hardware_version,
"vendor": camera.vendor_information,
"network": {
"interface": iface,
"ip_address": ip_address,
"mac": mac_address,
},
}
}
self.get_privacy_mode(serial_number)
except LoginError as err:
self.logger.error(f'Invalid username/password to connect to device "{host}", fix in config.yaml')
except AmcrestError as err:
self.logger.error(f'Failed to connect to device "{host}", check config.yaml and restart to try again: {err}')
except Exception as e:
import traceback
err_trace = traceback.format_exc()
print(f"[child] Error connecting to {host}: {e}\n{err_trace}")
return {"error": f"{e}", "host": host}
# Storage stats -------------------------------------------------------------------------------
@ -152,7 +202,6 @@ class AmcrestAPI(object):
return privacy_mode
def set_privacy_mode(self, device_id, switch):
device = self.devices[device_id]
@ -245,7 +294,6 @@ class AmcrestAPI(object):
if tries == 3:
self.logger.error(f'Failed to communicate with device ({device_id}) to get recorded file')
# Events --------------------------------------------------------------------------------------
async def collect_all_device_events(self):
@ -321,3 +369,69 @@ class AmcrestAPI(object):
def get_next_event(self):
return self.events.pop(0) if len(self.events) > 0 else None
def _connect_device_worker(args):
"""Top-level helper so it can be pickled by ProcessPoolExecutor."""
signal.signal(signal.SIGINT, signal.SIG_IGN)
host, device_name, amcrest_cfg = args
try:
host_ip = get_ip_address(host)
camera = AmcrestCamera(
host_ip,
amcrest_cfg["port"],
amcrest_cfg["username"],
amcrest_cfg["password"],
verbose=False,
).camera
device_type = camera.device_type.replace("type=", "").strip()
is_ad110 = device_type == "AD110"
is_ad410 = device_type == "AD410"
is_doorbell = is_ad110 or is_ad410
serial_number = camera.serial_number
version = camera.software_information[0].replace("version=", "").strip()
build = camera.software_information[1].strip()
sw_version = f"{version} ({build})"
network_config = dict(
item.split("=") for item in camera.network_config.splitlines()
)
iface = network_config["table.Network.DefaultInterface"]
ip_address = network_config[f"table.Network.{iface}.IPAddress"]
mac_address = network_config[f"table.Network.{iface}.PhysicalAddress"].upper()
print(f"[worker] Connected to {host} ({ip_address}) as {serial_number}")
return {
"config": {
"host": host,
"host_ip": host_ip,
"device_name": device_name,
"device_type": device_type,
"device_class": camera.device_class,
"is_ad110": is_ad110,
"is_ad410": is_ad410,
"is_doorbell": is_doorbell,
"serial_number": serial_number,
"software_version": sw_version,
"hardware_version": camera.hardware_version,
"vendor": camera.vendor_information,
"network": {
"interface": iface,
"ip_address": ip_address,
"mac": mac_address,
},
}
}
except Exception as e:
import traceback
err_trace = traceback.format_exc()
print(f"[worker] Error connecting to {host}: {e}\n{err_trace}")
return {"error": str(e), "host": host}

@ -63,18 +63,19 @@ class AmcrestMqtt(object):
# MQTT Functions ------------------------------------------------------------------------------
def mqtt_on_connect(self, client, userdata, flags, rc, properties):
if rc != 0:
self.logger.error(f'MQTT connection issue ({rc})')
exit()
def mqtt_on_connect(self, client, userdata, flags, reason_code, properties):
if reason_code.value != 0:
self.logger.error(f'MQTT connection issue ({reason_code.getName()})')
self.running = False
return
self.logger.info(f'MQTT connected as {self.client_id}')
client.subscribe("homeassistant/status")
client.subscribe(self.get_device_sub_topic())
client.subscribe(self.get_attribute_sub_topic())
def mqtt_on_disconnect(self, client, userdata, flags, rc, properties):
self.logger.info('MQTT connection closed')
def mqtt_on_disconnect(self, client, userdata, disconnect_flags, reason_code, properties):
self.logger.warning(f'MQTT disconnected: {reason_code.getName()} (flags={disconnect_flags})')
self.mqttc.loop_stop()
if self.running and time.time() > self.mqtt_connect_time + 10:
@ -88,7 +89,6 @@ class AmcrestMqtt(object):
self.paused = False
else:
self.running = False
exit()
def mqtt_on_log(self, client, userdata, paho_log_level, msg):
if paho_log_level == mqtt.LogLevel.MQTT_LOG_ERR:
@ -169,10 +169,10 @@ class AmcrestMqtt(object):
def mqttc_create(self):
self.mqttc = mqtt.Client(
callback_api_version=mqtt.CallbackAPIVersion.VERSION2,
client_id=self.client_id,
clean_session=False,
callback_api_version=mqtt.CallbackAPIVersion.VERSION2,
reconnect_on_failure=False,
protocol=mqtt.MQTTv5,
)
if self.mqtt_config.get('tls_enabled'):
@ -183,7 +183,8 @@ class AmcrestMqtt(object):
cert_reqs=ssl.CERT_REQUIRED,
tls_version=ssl.PROTOCOL_TLS,
)
else:
self.mqttc.tls_insecure_set(self.mqtt_config.get("tls_insecure", False))
if self.mqtt_config.get('username'):
self.mqttc.username_pw_set(
username=self.mqtt_config.get('username'),
password=self.mqtt_config.get('password'),
@ -199,16 +200,24 @@ class AmcrestMqtt(object):
self.mqttc.will_set(self.get_discovery_topic('service', 'availability'), payload="offline", qos=self.mqtt_config['qos'], retain=True)
try:
self.logger.info(
f"Connecting to MQTT broker at {self.mqtt_config.get('host')}:{self.mqtt_config.get('port')} "
f"as {self.client_id}"
)
self.mqttc.connect(
self.mqtt_config.get('host'),
host=self.mqtt_config.get('host'),
port=self.mqtt_config.get('port'),
keepalive=60,
)
self.mqtt_connect_time = time.time()
self.mqttc.loop_start()
except ConnectionError as error:
self.logger.error(f'COULD NOT CONNECT TO MQTT {self.mqtt_config.get("host")}: {error}')
exit(1)
except Exception as error:
self.logger.error(
f"Failed to connect to MQTT broker {self.mqtt_config.get('host')}:{self.mqtt_config.get('port')} "
f"({type(error).__name__}: {error})",
exc_info=True,
)
self.running = False
# MQTT Topics ---------------------------------------------------------------------------------
@ -261,6 +270,18 @@ class AmcrestMqtt(object):
return f"{self.mqtt_config['prefix']}/{self.get_component_slug(device_id)}/{topic}/{subtopic}"
return f"{self.mqtt_config['discovery_prefix']}/device/{self.get_component_slug(device_id)}/{topic}/{subtopic}"
def ha_cfg_topic(self, domain: str, object_id: str) -> str:
dp = self.mqtt_config.get('discovery_prefix', 'homeassistant')
return f"{dp}/{domain}/{object_id}/config"
def service_oid(self, suffix: str) -> str:
return f"{self.service_slug}_{suffix}"
def svc_topic(self, sub: str) -> str:
# Runtime topics (your own prefix), e.g. govee2mqtt/govee-service/state
pfx = self.mqtt_config.get('prefix', 'govee2mqtt')
return f"{pfx}/amcrest-service/{sub}"
# Service Device ------------------------------------------------------------------------------
def publish_service_state(self):
@ -474,14 +495,33 @@ class AmcrestMqtt(object):
'value_template': '{{ value_json.state }}',
'unique_id': self.get_slug(device_id, 'snapshot_camera'),
}
if 'webrtc' in self.amcrest_config:
webrtc_config = self.amcrest_config['webrtc']
rtc_host = webrtc_config['host']
rtc_port = webrtc_config['port']
rtc_link = webrtc_config['link']
rtc_source = webrtc_config['sources'].pop(0)
rtc_url = f'http://{rtc_host}:{rtc_port}/{rtc_link}?src={rtc_source}'
device_config['device']['configuration_url'] = rtc_url
# --- Safe WebRTC config handling ----------------------------------------
webrtc_config = self.amcrest_config.get("webrtc")
# Handle missing, boolean, or incomplete configs gracefully
if isinstance(webrtc_config, bool) or not webrtc_config:
self.logger.debug("No valid WebRTC config found; skipping WebRTC setup.")
else:
try:
rtc_host = webrtc_config.get("host")
rtc_port = webrtc_config.get("port")
rtc_link = webrtc_config.get("link")
rtc_sources = webrtc_config.get("sources", [])
rtc_source = rtc_sources[0] if rtc_sources else None
if rtc_host and rtc_port and rtc_link and rtc_source:
rtc_url = f"http://{rtc_host}:{rtc_port}/{rtc_link}?src={rtc_source}"
device_config["device"]["configuration_url"] = rtc_url
self.logger.debug(f"Added WebRTC config URL for {device_id}: {rtc_url}")
else:
self.logger.warning(
f"Incomplete WebRTC config for {device_id}: {webrtc_config}"
)
except Exception as e:
self.logger.warning(
f"Failed to apply WebRTC config for {device_id}: {e}", exc_info=True
)
# copy the snapshot camera for the eventshot camera, with a couple of changes
components[self.get_slug(device_id, 'event_camera')] = {
@ -643,7 +683,7 @@ class AmcrestMqtt(object):
self.mqttc.publish(self.get_discovery_topic(device_id, 'config'), payload, qos=self.mqtt_config['qos'], retain=True)
# refresh * all devices -----------------------------------------------------------------------
# refresh * all devices -----------------------------------------------------------------------
def refresh_storage_all_devices(self):
self.logger.info(f'Refreshing storage info for all devices (every {self.storage_update_interval} sec)')
@ -744,9 +784,9 @@ class AmcrestMqtt(object):
def handle_service_message(self, attribute, message):
match attribute:
case 'storage_refresh':
case "storage_refresh":
self.storage_update_interval = message
self.logger.info(f'Updated STORAGE_REFRESH_INTERVAL to be {message}')
self.logger.info(f"Updated STORAGE_REFRESH_INTERVAL to be {message}")
case 'snapshot_refresh':
self.snapshot_update_interval = message
self.logger.info(f'Updated SNAPSHOT_REFRESH_INTERVAL to be {message}')
@ -839,30 +879,85 @@ class AmcrestMqtt(object):
# main loop
async def main_loop(self):
"""Main event loop for Amcrest MQTT service."""
await self.setup_devices()
loop = asyncio.get_running_loop()
# Create async tasks with descriptive names
tasks = [
asyncio.create_task(self.collect_storage_info()),
asyncio.create_task(self.collect_events()),
asyncio.create_task(self.check_event_queue()),
asyncio.create_task(self.collect_snapshots()),
asyncio.create_task(
self.collect_storage_info(), name="collect_storage_info"
),
asyncio.create_task(self.collect_events(), name="collect_events"),
asyncio.create_task(self.check_event_queue(), name="check_event_queue"),
asyncio.create_task(self.collect_snapshots(), name="collect_snapshots"),
]
# setup signal handling for tasks
# Graceful signal handler
def _signal_handler(signame):
"""Immediate, aggressive shutdown handler for Ctrl+C or SIGTERM."""
self.logger.warning(f"{signame} received — initiating shutdown NOW...")
self.running = False
# Cancel *all* asyncio tasks, even those not tracked manually
loop = asyncio.get_event_loop()
for task in asyncio.all_tasks(loop):
if not task.done():
task.cancel(f"{signame} received")
# Force-stop ProcessPoolExecutor if present
try:
if hasattr(self, "api") and hasattr(self.api, "executor"):
self.logger.debug("Force-shutting down process pool...")
self.api.executor.shutdown(wait=False, cancel_futures=True)
except Exception as e:
self.logger.debug(f"Error force-stopping process pool: {e}")
# Stop the loop immediately after a short delay
loop.call_later(0.05, loop.stop)
for sig in (signal.SIGINT, signal.SIGTERM):
loop.add_signal_handler(
sig, lambda: asyncio.create_task(self._handle_signals(sig.name, loop))
)
try:
loop.add_signal_handler(sig, _signal_handler, sig.name)
except NotImplementedError:
# Windows compatibility
self.logger.debug(f"Signal handling not supported on this platform.")
try:
results = await asyncio.gather(*tasks, return_exceptions=True)
for result in results:
if isinstance(result, Exception):
# Handle task exceptions individually
for t, result in zip(tasks, results):
if isinstance(result, asyncio.CancelledError):
self.logger.info(f"Task '{t.get_name()}' cancelled.")
elif isinstance(result, Exception):
self.logger.error(
f"Task '{t.get_name()}' raised an exception: {result}",
exc_info=True,
)
self.running = False
self.logger.error(f'Caught exception: {err}', exc_info=True)
except asyncio.CancelledError:
exit(1)
self.logger.info("Main loop cancelled; shutting down...")
except Exception as err:
self.logger.exception(f"Unhandled exception in main loop: {err}")
self.running = False
self.logger.error(f'Caught exception: {err}')
finally:
self.logger.info("All loops terminated, performing final cleanup...")
try:
# Save final state or cleanup hooks if needed
if hasattr(self, "save_state"):
self.save_state()
except Exception as e:
self.logger.warning(f"Error during save_state: {e}")
# Disconnect MQTT cleanly
if self.mqttc and self.mqttc.is_connected():
try:
self.mqttc.disconnect()
except Exception as e:
self.logger.warning(f"Error during MQTT disconnect: {e}")
self.logger.info("Main loop complete.")

193
app.py

@ -1,138 +1,61 @@
# This software is licensed under the MIT License, which allows you to use,
# copy, modify, merge, publish, distribute, and sell copies of the software,
# with the requirement to include the original copyright notice and this
# permission notice in all copies or substantial portions of the software.
#
# The software is provided 'as is', without any warranty.
#!/usr/bin/env python3
import asyncio
import argparse
from amcrest_mqtt import AmcrestMqtt
import logging
import os
import sys
import time
from util import *
import yaml
# Let's go!
version = read_version()
# Cmd-line args
argparser = argparse.ArgumentParser()
argparser.add_argument(
'-c',
'--config',
required=False,
help='Directory holding config.yaml or full path to config file',
)
args = argparser.parse_args()
# Setup config from yaml file or env
configpath = args.config or '/config'
try:
if not configpath.endswith('.yaml'):
if not configpath.endswith('/'):
configpath += '/'
configfile = configpath + 'config.yaml'
with open(configfile) as file:
config = yaml.safe_load(file)
config['config_path'] = configpath
config['config_from'] = 'file'
except:
config = {
'mqtt': {
'host': os.getenv('MQTT_HOST') or 'localhost',
'qos': int(os.getenv('MQTT_QOS') or 0),
'port': int(os.getenv('MQTT_PORT') or 1883),
'username': os.getenv('MQTT_USERNAME'),
'password': os.getenv('MQTT_PASSWORD'), # can be None
'tls_enabled': os.getenv('MQTT_TLS_ENABLED') == 'true',
'tls_ca_cert': os.getenv('MQTT_TLS_CA_CERT'),
'tls_cert': os.getenv('MQTT_TLS_CERT'),
'tls_key': os.getenv('MQTT_TLS_KEY'),
'prefix': os.getenv('MQTT_PREFIX') or 'amcrest2mqtt',
'homeassistant': os.getenv('MQTT_HOMEASSISTANT') == True,
'discovery_prefix': os.getenv('MQTT_DISCOVERY_PREFIX') or 'homeassistant',
},
'amcrest': {
'hosts': os.getenv("AMCREST_HOSTS"),
'names': os.getenv("AMCREST_NAMES"),
'port': int(os.getenv("AMCREST_PORT") or 80),
'username': os.getenv("AMCREST_USERNAME") or "admin",
'password': os.getenv("AMCREST_PASSWORD"),
'storage_update_interval': int(os.getenv("STORAGE_UPDATE_INTERVAL") or 900),
'snapshot_update_interval': int(os.getenv("SNAPSHOT_UPDATE_INTERVAL") or 300),
'webrtc': {
'host': os.getenv("AMCREST_WEBRTC_HOST"),
'port': int(os.getenv("AMCREST_WEBRTC_PORT") or 1984),
'link': os.getenv("AMCREST_WEBRTC_LINK") or 'stream.html',
'sources': os.getenv("AMCREST_WEBRTC_SOURCES"),
},
},
'debug': True if os.getenv('DEBUG') else False,
'hide_ts': True if os.getenv('HIDE_TS') else False,
'timezone': os.getenv('TZ'),
'config_from': 'env',
}
config['version'] = version
config['configpath'] = os.path.dirname(configpath)
# defaults
if 'username' not in config['mqtt']: config['mqtt']['username'] = ''
if 'password' not in config['mqtt']: config['mqtt']['password'] = ''
if 'qos' not in config['mqtt']: config['mqtt']['qos'] = 0
if 'timezone' not in config: config['timezone'] = 'UTC'
if 'debug' not in config: config['debug'] = os.getenv('DEBUG') or False
if 'hide_ts' not in config: config['hide_ts'] = os.getenv('HIDE_TS') or False
# init logging, based on config settings
logging.basicConfig(
format = '%(asctime)s.%(msecs)03d [%(levelname)s] %(name)s: %(message)s' if config['hide_ts'] == False else '[%(levelname)s] %(name)s: %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
level=logging.INFO if config['debug'] == False else logging.DEBUG
)
logger = logging.getLogger(__name__)
logger.info(f'Starting: amcrest2mqtt {version}')
logger.info(f'Config loaded from {config["config_from"]}')
# Check for required config properties
if config['amcrest']['hosts'] is None:
logger.error('Missing env var: AMCREST_HOSTS or amcrest.hosts in config')
exit(1)
config['amcrest']['host_count'] = len(config['amcrest']['hosts'])
if config['amcrest']['names'] is None:
logger.error('Missing env var: AMCREST_NAMES or amcrest.names in config')
exit(1)
config['amcrest']['name_count'] = len(config['amcrest']['names'])
if config['amcrest']['host_count'] != config['amcrest']['name_count']:
logger.error('The AMCREST_HOSTS and AMCREST_NAMES must have the same number of space-delimited hosts/names')
exit(1)
logger.info(f'Found {config["amcrest"]["host_count"]} host(s) defined to monitor')
if 'webrtc' in config['amcrest']:
webrtc = config['amcrest']['webrtc']
if 'host' not in webrtc:
logger.error('Missing HOST in webrtc config')
exit(1)
if 'sources' not in webrtc:
logger.error('Missing SOURCES in webrtc config')
exit(1)
config['amcrest']['webrtc_sources_count'] = len(config['amcrest']['webrtc']['sources'])
if config['amcrest']['host_count'] != config['amcrest']['webrtc_sources_count']:
logger.error('The AMCREST_HOSTS and AMCREST_WEBRTC_SOURCES must have the same number of space-delimited hosts/names')
exit(1)
if 'port' not in webrtc: webrtc['port'] = 1984
if 'link' not in webrtc: webrtc['link'] = 'stream.html'
if config['amcrest']['password'] is None:
logger.error('Please set the AMCREST_PASSWORD environment variable')
exit(1)
logger.debug("DEBUG logging is ON")
# Go!
with AmcrestMqtt(config) as mqtt:
asyncio.run(mqtt.main_loop())
from amcrest_mqtt import AmcrestMqtt
from util import load_config
if __name__ == "__main__":
# Parse command-line arguments
argparser = argparse.ArgumentParser()
argparser.add_argument(
"-c",
"--config",
required=False,
help="Directory or file path for config.yaml (defaults to /config/config.yaml)",
)
args = argparser.parse_args()
# Load configuration
config = load_config(args.config)
# Setup logging
logging.basicConfig(
format=(
"%(asctime)s.%(msecs)03d [%(levelname)s] %(name)s: %(message)s"
if not config["hide_ts"]
else "[%(levelname)s] %(name)s: %(message)s"
),
datefmt="%Y-%m-%d %H:%M:%S",
level=logging.DEBUG if config["debug"] else logging.INFO,
)
logger = logging.getLogger(__name__)
logger.info(f"Starting amcrest2mqtt {config['version']}")
logger.info(f"Config loaded from {config['config_from']} ({config['config_path']})")
# Run main loop safely
try:
with AmcrestMqtt(config) as mqtt:
try:
# Prefer a clean async run, but handle nested event loops
asyncio.run(mqtt.main_loop())
except RuntimeError as e:
if "asyncio.run() cannot be called from a running event loop" in str(e):
loop = asyncio.get_event_loop()
loop.run_until_complete(mqtt.main_loop())
else:
raise
except KeyboardInterrupt:
logger.info("Shutdown requested (Ctrl+C). Exiting gracefully...")
except asyncio.CancelledError:
logger.warning("Main loop cancelled.")
except Exception as e:
logger.exception(f"Unhandled exception in main loop: {e}")
finally:
try:
if "mqtt" in locals() and hasattr(mqtt, "api") and mqtt.api:
mqtt.api.shutdown()
except Exception as e:
logger.debug(f"Error during shutdown: {e}")
logger.info("amcrest2mqtt stopped.")

@ -1,44 +1,121 @@
# This software is licensed under the MIT License, which allows you to use,
# copy, modify, merge, publish, distribute, and sell copies of the software,
# with the requirement to include the original copyright notice and this
# permission notice in all copies or substantial portions of the software.
#
# The software is provided 'as is', without any warranty.
import ipaddress
import logging
import os
import socket
import yaml
# Helper functions and callbacks
def read_file(file_name):
with open(file_name, 'r') as file:
data = file.read().replace('\n', '')
def get_ip_address(hostname: str) -> str:
"""
Resolve a hostname to an IP address (IPv4 or IPv6).
Returns:
str: The resolved IP address, or the original hostname if resolution fails.
"""
if not hostname:
return hostname
try:
# Try both IPv4 and IPv6 (AF_UNSPEC)
infos = socket.getaddrinfo(
hostname, None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM
)
# Prefer IPv4 addresses if available
for family, _, _, _, sockaddr in infos:
if family == socket.AF_INET:
return sockaddr[0]
# Otherwise, fallback to first valid IPv6
return infos[0][4][0] if infos else hostname
except socket.gaierror as e:
logging.debug(f"DNS lookup failed for {hostname}: {e}")
except Exception as e:
logging.debug(f"Unexpected error resolving {hostname}: {e}")
return hostname
def to_gb(bytes_value):
"""Convert bytes to a rounded string in gigabytes."""
try:
gb = float(bytes_value) / (1024**3)
return f"{gb:.2f} GB"
except Exception:
return "0.00 GB"
def read_file(file_name, strip_newlines=True, default=None, encoding="utf-8"):
try:
with open(file_name, "r", encoding=encoding) as f:
data = f.read()
return data.replace("\n", "") if strip_newlines else data
except FileNotFoundError:
if default is not None:
return default
raise
return data
def read_version():
if os.path.isfile('./VERSION'):
return read_file('./VERSION')
base_dir = os.path.dirname(os.path.abspath(__file__))
version_path = os.path.join(base_dir, "VERSION")
try:
with open(version_path, "r") as f:
return f.read().strip() or "unknown"
except FileNotFoundError:
env_version = os.getenv("APP_VERSION")
return env_version.strip() if env_version else "unknown"
return read_file('../VERSION')
def to_gb(total):
return str(round(float(total[0]) / 1024 / 1024 / 1024, 2))
def load_config(path=None):
"""Load and normalize configuration from YAML file or directory."""
def is_ipv4(string):
try:
ipaddress.IPv4Network(string)
return True
except ValueError:
logger = logging.getLogger(__name__)
default_path = "/config/config.yaml"
# Resolve config path
config_path = path or default_path
if os.path.isdir(config_path):
config_path = os.path.join(config_path, "config.yaml")
if not os.path.exists(config_path):
raise FileNotFoundError(f"Config file not found: {config_path}")
with open(config_path, "r") as f:
config = yaml.safe_load(f) or {}
# --- Normalization helpers ------------------------------------------------
def to_bool(value):
"""Coerce common truthy/falsey forms to proper bool."""
if isinstance(value, bool):
return value
if isinstance(value, (int, float)):
return bool(value)
if value is None:
return False
return str(value).strip().lower() in ("true", "1", "yes", "on")
def get_ip_address(string):
if is_ipv4(string):
return string
try:
for i in socket.getaddrinfo(string, None):
if i[0] == socket.AddressFamily.AF_INET:
return i[4][0]
except socket.gaierror as e:
raise Exception(f"Failed to resolve {string}: {e}")
raise Exception(f"Failed to find IP address for {string}")
def normalize_section(section, bool_keys):
"""Normalize booleans within a nested section."""
if not isinstance(config.get(section), dict):
return
for key in bool_keys:
if key in config[section]:
config[section][key] = to_bool(config[section][key])
# --- Global defaults + normalization -------------------------------------
config.setdefault("version", "1.0.0")
config.setdefault("debug", False)
config.setdefault("hide_ts", False)
config.setdefault("timezone", "UTC")
# normalize top-level flags
config["debug"] = to_bool(config.get("debug"))
config["hide_ts"] = to_bool(config.get("hide_ts"))
# Example: normalize booleans within sections
normalize_section("mqtt", ["tls", "retain", "clean_session"])
normalize_section("amcrest", ["webrtc", "verify_ssl"])
normalize_section("service", ["enabled", "auto_restart"])
# Add metadata for debugging/logging
config["config_path"] = os.path.abspath(config_path)
config["config_from"] = "file" if os.path.exists(config_path) else "defaults"
return config

Loading…
Cancel
Save