feat(core): add async process pool, graceful signal handling, and safer config loading
- Added multiprocessing-based connection pooling for faster camera setup - Implemented robust signal handling (SIGINT/SIGTERM) for graceful shutdown - Upgraded MQTT client to v5 with improved logging and reconnect behavior - Added defensive WebRTC config validation to avoid runtime errors - Introduced unified config loader (`util.load_config`) with boolean coercion - Improved IP resolution and logging consistency across modulespull/106/head
parent
a69ffc1667
commit
f025d60f75
@ -1,138 +1,61 @@
|
|||||||
# This software is licensed under the MIT License, which allows you to use,
|
#!/usr/bin/env python3
|
||||||
# copy, modify, merge, publish, distribute, and sell copies of the software,
|
|
||||||
# with the requirement to include the original copyright notice and this
|
|
||||||
# permission notice in all copies or substantial portions of the software.
|
|
||||||
#
|
|
||||||
# The software is provided 'as is', without any warranty.
|
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import argparse
|
import argparse
|
||||||
from amcrest_mqtt import AmcrestMqtt
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
from amcrest_mqtt import AmcrestMqtt
|
||||||
import sys
|
from util import load_config
|
||||||
import time
|
|
||||||
from util import *
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
# Let's go!
|
|
||||||
version = read_version()
|
|
||||||
|
|
||||||
# Cmd-line args
|
if __name__ == "__main__":
|
||||||
|
# Parse command-line arguments
|
||||||
argparser = argparse.ArgumentParser()
|
argparser = argparse.ArgumentParser()
|
||||||
argparser.add_argument(
|
argparser.add_argument(
|
||||||
'-c',
|
"-c",
|
||||||
'--config',
|
"--config",
|
||||||
required=False,
|
required=False,
|
||||||
help='Directory holding config.yaml or full path to config file',
|
help="Directory or file path for config.yaml (defaults to /config/config.yaml)",
|
||||||
)
|
)
|
||||||
args = argparser.parse_args()
|
args = argparser.parse_args()
|
||||||
|
|
||||||
# Setup config from yaml file or env
|
# Load configuration
|
||||||
configpath = args.config or '/config'
|
config = load_config(args.config)
|
||||||
try:
|
|
||||||
if not configpath.endswith('.yaml'):
|
|
||||||
if not configpath.endswith('/'):
|
|
||||||
configpath += '/'
|
|
||||||
configfile = configpath + 'config.yaml'
|
|
||||||
with open(configfile) as file:
|
|
||||||
config = yaml.safe_load(file)
|
|
||||||
config['config_path'] = configpath
|
|
||||||
config['config_from'] = 'file'
|
|
||||||
except:
|
|
||||||
config = {
|
|
||||||
'mqtt': {
|
|
||||||
'host': os.getenv('MQTT_HOST') or 'localhost',
|
|
||||||
'qos': int(os.getenv('MQTT_QOS') or 0),
|
|
||||||
'port': int(os.getenv('MQTT_PORT') or 1883),
|
|
||||||
'username': os.getenv('MQTT_USERNAME'),
|
|
||||||
'password': os.getenv('MQTT_PASSWORD'), # can be None
|
|
||||||
'tls_enabled': os.getenv('MQTT_TLS_ENABLED') == 'true',
|
|
||||||
'tls_ca_cert': os.getenv('MQTT_TLS_CA_CERT'),
|
|
||||||
'tls_cert': os.getenv('MQTT_TLS_CERT'),
|
|
||||||
'tls_key': os.getenv('MQTT_TLS_KEY'),
|
|
||||||
'prefix': os.getenv('MQTT_PREFIX') or 'amcrest2mqtt',
|
|
||||||
'homeassistant': os.getenv('MQTT_HOMEASSISTANT') == True,
|
|
||||||
'discovery_prefix': os.getenv('MQTT_DISCOVERY_PREFIX') or 'homeassistant',
|
|
||||||
},
|
|
||||||
'amcrest': {
|
|
||||||
'hosts': os.getenv("AMCREST_HOSTS"),
|
|
||||||
'names': os.getenv("AMCREST_NAMES"),
|
|
||||||
'port': int(os.getenv("AMCREST_PORT") or 80),
|
|
||||||
'username': os.getenv("AMCREST_USERNAME") or "admin",
|
|
||||||
'password': os.getenv("AMCREST_PASSWORD"),
|
|
||||||
'storage_update_interval': int(os.getenv("STORAGE_UPDATE_INTERVAL") or 900),
|
|
||||||
'snapshot_update_interval': int(os.getenv("SNAPSHOT_UPDATE_INTERVAL") or 300),
|
|
||||||
'webrtc': {
|
|
||||||
'host': os.getenv("AMCREST_WEBRTC_HOST"),
|
|
||||||
'port': int(os.getenv("AMCREST_WEBRTC_PORT") or 1984),
|
|
||||||
'link': os.getenv("AMCREST_WEBRTC_LINK") or 'stream.html',
|
|
||||||
'sources': os.getenv("AMCREST_WEBRTC_SOURCES"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
'debug': True if os.getenv('DEBUG') else False,
|
|
||||||
'hide_ts': True if os.getenv('HIDE_TS') else False,
|
|
||||||
'timezone': os.getenv('TZ'),
|
|
||||||
'config_from': 'env',
|
|
||||||
}
|
|
||||||
config['version'] = version
|
|
||||||
config['configpath'] = os.path.dirname(configpath)
|
|
||||||
|
|
||||||
# defaults
|
|
||||||
if 'username' not in config['mqtt']: config['mqtt']['username'] = ''
|
|
||||||
if 'password' not in config['mqtt']: config['mqtt']['password'] = ''
|
|
||||||
if 'qos' not in config['mqtt']: config['mqtt']['qos'] = 0
|
|
||||||
if 'timezone' not in config: config['timezone'] = 'UTC'
|
|
||||||
if 'debug' not in config: config['debug'] = os.getenv('DEBUG') or False
|
|
||||||
if 'hide_ts' not in config: config['hide_ts'] = os.getenv('HIDE_TS') or False
|
|
||||||
|
|
||||||
# init logging, based on config settings
|
# Setup logging
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
format = '%(asctime)s.%(msecs)03d [%(levelname)s] %(name)s: %(message)s' if config['hide_ts'] == False else '[%(levelname)s] %(name)s: %(message)s',
|
format=(
|
||||||
datefmt='%Y-%m-%d %H:%M:%S',
|
"%(asctime)s.%(msecs)03d [%(levelname)s] %(name)s: %(message)s"
|
||||||
level=logging.INFO if config['debug'] == False else logging.DEBUG
|
if not config["hide_ts"]
|
||||||
|
else "[%(levelname)s] %(name)s: %(message)s"
|
||||||
|
),
|
||||||
|
datefmt="%Y-%m-%d %H:%M:%S",
|
||||||
|
level=logging.DEBUG if config["debug"] else logging.INFO,
|
||||||
)
|
)
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
logger.info(f'Starting: amcrest2mqtt {version}')
|
|
||||||
logger.info(f'Config loaded from {config["config_from"]}')
|
|
||||||
|
|
||||||
# Check for required config properties
|
|
||||||
if config['amcrest']['hosts'] is None:
|
|
||||||
logger.error('Missing env var: AMCREST_HOSTS or amcrest.hosts in config')
|
|
||||||
exit(1)
|
|
||||||
config['amcrest']['host_count'] = len(config['amcrest']['hosts'])
|
|
||||||
|
|
||||||
if config['amcrest']['names'] is None:
|
logger = logging.getLogger(__name__)
|
||||||
logger.error('Missing env var: AMCREST_NAMES or amcrest.names in config')
|
logger.info(f"Starting amcrest2mqtt {config['version']}")
|
||||||
exit(1)
|
logger.info(f"Config loaded from {config['config_from']} ({config['config_path']})")
|
||||||
config['amcrest']['name_count'] = len(config['amcrest']['names'])
|
|
||||||
|
|
||||||
if config['amcrest']['host_count'] != config['amcrest']['name_count']:
|
|
||||||
logger.error('The AMCREST_HOSTS and AMCREST_NAMES must have the same number of space-delimited hosts/names')
|
|
||||||
exit(1)
|
|
||||||
logger.info(f'Found {config["amcrest"]["host_count"]} host(s) defined to monitor')
|
|
||||||
|
|
||||||
if 'webrtc' in config['amcrest']:
|
|
||||||
webrtc = config['amcrest']['webrtc']
|
|
||||||
if 'host' not in webrtc:
|
|
||||||
logger.error('Missing HOST in webrtc config')
|
|
||||||
exit(1)
|
|
||||||
if 'sources' not in webrtc:
|
|
||||||
logger.error('Missing SOURCES in webrtc config')
|
|
||||||
exit(1)
|
|
||||||
config['amcrest']['webrtc_sources_count'] = len(config['amcrest']['webrtc']['sources'])
|
|
||||||
if config['amcrest']['host_count'] != config['amcrest']['webrtc_sources_count']:
|
|
||||||
logger.error('The AMCREST_HOSTS and AMCREST_WEBRTC_SOURCES must have the same number of space-delimited hosts/names')
|
|
||||||
exit(1)
|
|
||||||
if 'port' not in webrtc: webrtc['port'] = 1984
|
|
||||||
if 'link' not in webrtc: webrtc['link'] = 'stream.html'
|
|
||||||
|
|
||||||
if config['amcrest']['password'] is None:
|
|
||||||
logger.error('Please set the AMCREST_PASSWORD environment variable')
|
|
||||||
exit(1)
|
|
||||||
|
|
||||||
logger.debug("DEBUG logging is ON")
|
|
||||||
|
|
||||||
# Go!
|
# Run main loop safely
|
||||||
|
try:
|
||||||
with AmcrestMqtt(config) as mqtt:
|
with AmcrestMqtt(config) as mqtt:
|
||||||
|
try:
|
||||||
|
# Prefer a clean async run, but handle nested event loops
|
||||||
asyncio.run(mqtt.main_loop())
|
asyncio.run(mqtt.main_loop())
|
||||||
|
except RuntimeError as e:
|
||||||
|
if "asyncio.run() cannot be called from a running event loop" in str(e):
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
loop.run_until_complete(mqtt.main_loop())
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger.info("Shutdown requested (Ctrl+C). Exiting gracefully...")
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.warning("Main loop cancelled.")
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"Unhandled exception in main loop: {e}")
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
if "mqtt" in locals() and hasattr(mqtt, "api") and mqtt.api:
|
||||||
|
mqtt.api.shutdown()
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Error during shutdown: {e}")
|
||||||
|
logger.info("amcrest2mqtt stopped.")
|
||||||
@ -1,44 +1,121 @@
|
|||||||
# This software is licensed under the MIT License, which allows you to use,
|
import logging
|
||||||
# copy, modify, merge, publish, distribute, and sell copies of the software,
|
|
||||||
# with the requirement to include the original copyright notice and this
|
|
||||||
# permission notice in all copies or substantial portions of the software.
|
|
||||||
#
|
|
||||||
# The software is provided 'as is', without any warranty.
|
|
||||||
|
|
||||||
import ipaddress
|
|
||||||
import os
|
import os
|
||||||
import socket
|
import socket
|
||||||
|
import yaml
|
||||||
|
|
||||||
# Helper functions and callbacks
|
def get_ip_address(hostname: str) -> str:
|
||||||
def read_file(file_name):
|
"""
|
||||||
with open(file_name, 'r') as file:
|
Resolve a hostname to an IP address (IPv4 or IPv6).
|
||||||
data = file.read().replace('\n', '')
|
|
||||||
|
|
||||||
return data
|
Returns:
|
||||||
|
str: The resolved IP address, or the original hostname if resolution fails.
|
||||||
|
"""
|
||||||
|
if not hostname:
|
||||||
|
return hostname
|
||||||
|
|
||||||
def read_version():
|
try:
|
||||||
if os.path.isfile('./VERSION'):
|
# Try both IPv4 and IPv6 (AF_UNSPEC)
|
||||||
return read_file('./VERSION')
|
infos = socket.getaddrinfo(
|
||||||
|
hostname, None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM
|
||||||
|
)
|
||||||
|
# Prefer IPv4 addresses if available
|
||||||
|
for family, _, _, _, sockaddr in infos:
|
||||||
|
if family == socket.AF_INET:
|
||||||
|
return sockaddr[0]
|
||||||
|
# Otherwise, fallback to first valid IPv6
|
||||||
|
return infos[0][4][0] if infos else hostname
|
||||||
|
except socket.gaierror as e:
|
||||||
|
logging.debug(f"DNS lookup failed for {hostname}: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
logging.debug(f"Unexpected error resolving {hostname}: {e}")
|
||||||
|
|
||||||
return read_file('../VERSION')
|
return hostname
|
||||||
|
|
||||||
def to_gb(total):
|
|
||||||
return str(round(float(total[0]) / 1024 / 1024 / 1024, 2))
|
|
||||||
|
|
||||||
def is_ipv4(string):
|
def to_gb(bytes_value):
|
||||||
|
"""Convert bytes to a rounded string in gigabytes."""
|
||||||
try:
|
try:
|
||||||
ipaddress.IPv4Network(string)
|
gb = float(bytes_value) / (1024**3)
|
||||||
return True
|
return f"{gb:.2f} GB"
|
||||||
except ValueError:
|
except Exception:
|
||||||
return False
|
return "0.00 GB"
|
||||||
|
|
||||||
|
|
||||||
def get_ip_address(string):
|
def read_file(file_name, strip_newlines=True, default=None, encoding="utf-8"):
|
||||||
if is_ipv4(string):
|
|
||||||
return string
|
|
||||||
try:
|
try:
|
||||||
for i in socket.getaddrinfo(string, None):
|
with open(file_name, "r", encoding=encoding) as f:
|
||||||
if i[0] == socket.AddressFamily.AF_INET:
|
data = f.read()
|
||||||
return i[4][0]
|
return data.replace("\n", "") if strip_newlines else data
|
||||||
except socket.gaierror as e:
|
except FileNotFoundError:
|
||||||
raise Exception(f"Failed to resolve {string}: {e}")
|
if default is not None:
|
||||||
raise Exception(f"Failed to find IP address for {string}")
|
return default
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def read_version():
|
||||||
|
base_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
version_path = os.path.join(base_dir, "VERSION")
|
||||||
|
try:
|
||||||
|
with open(version_path, "r") as f:
|
||||||
|
return f.read().strip() or "unknown"
|
||||||
|
except FileNotFoundError:
|
||||||
|
env_version = os.getenv("APP_VERSION")
|
||||||
|
return env_version.strip() if env_version else "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
def load_config(path=None):
|
||||||
|
"""Load and normalize configuration from YAML file or directory."""
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
default_path = "/config/config.yaml"
|
||||||
|
|
||||||
|
# Resolve config path
|
||||||
|
config_path = path or default_path
|
||||||
|
if os.path.isdir(config_path):
|
||||||
|
config_path = os.path.join(config_path, "config.yaml")
|
||||||
|
|
||||||
|
if not os.path.exists(config_path):
|
||||||
|
raise FileNotFoundError(f"Config file not found: {config_path}")
|
||||||
|
|
||||||
|
with open(config_path, "r") as f:
|
||||||
|
config = yaml.safe_load(f) or {}
|
||||||
|
|
||||||
|
# --- Normalization helpers ------------------------------------------------
|
||||||
|
def to_bool(value):
|
||||||
|
"""Coerce common truthy/falsey forms to proper bool."""
|
||||||
|
if isinstance(value, bool):
|
||||||
|
return value
|
||||||
|
if isinstance(value, (int, float)):
|
||||||
|
return bool(value)
|
||||||
|
if value is None:
|
||||||
|
return False
|
||||||
|
return str(value).strip().lower() in ("true", "1", "yes", "on")
|
||||||
|
|
||||||
|
def normalize_section(section, bool_keys):
|
||||||
|
"""Normalize booleans within a nested section."""
|
||||||
|
if not isinstance(config.get(section), dict):
|
||||||
|
return
|
||||||
|
for key in bool_keys:
|
||||||
|
if key in config[section]:
|
||||||
|
config[section][key] = to_bool(config[section][key])
|
||||||
|
|
||||||
|
# --- Global defaults + normalization -------------------------------------
|
||||||
|
config.setdefault("version", "1.0.0")
|
||||||
|
config.setdefault("debug", False)
|
||||||
|
config.setdefault("hide_ts", False)
|
||||||
|
config.setdefault("timezone", "UTC")
|
||||||
|
|
||||||
|
# normalize top-level flags
|
||||||
|
config["debug"] = to_bool(config.get("debug"))
|
||||||
|
config["hide_ts"] = to_bool(config.get("hide_ts"))
|
||||||
|
|
||||||
|
# Example: normalize booleans within sections
|
||||||
|
normalize_section("mqtt", ["tls", "retain", "clean_session"])
|
||||||
|
normalize_section("amcrest", ["webrtc", "verify_ssl"])
|
||||||
|
normalize_section("service", ["enabled", "auto_restart"])
|
||||||
|
|
||||||
|
# Add metadata for debugging/logging
|
||||||
|
config["config_path"] = os.path.abspath(config_path)
|
||||||
|
config["config_from"] = "file" if os.path.exists(config_path) else "defaults"
|
||||||
|
|
||||||
|
return config
|
||||||
|
|||||||
Loading…
Reference in New Issue