# Source code for sbws.util.config

"""Util functions to manage sbws configuration files."""

import logging
import logging.config
import os
from configparser import (
    ConfigParser,
    ExtendedInterpolation,
    InterpolationMissingOptionError,
)
from string import Template
from tempfile import NamedTemporaryFile
from urllib.parse import urlparse

from sbws.globals import (
    DEFAULT_CONFIG_PATH,
    DEFAULT_LOG_CONFIG_PATH,
    SUPERVISED_RUN_DPATH,
    SUPERVISED_USER_CONFIG_PATH,
    USER_CONFIG_PATH,
)
from sbws.util.iso3166 import ISO_3166_ALPHA_2

# ASCII lowercase letters, ASCII uppercase letters and decimal digits.
# Combined with _SYMBOLS_NO_QUOTES it forms the alphabet accepted for the
# scanner nickname (see _validate_nickname).
_ALPHANUM = "abcdefghijklmnopqrstuvwxyz"
_ALPHANUM += _ALPHANUM.upper()
_ALPHANUM += "0123456789"

# Punctuation additionally accepted in a nickname. As the name says, neither
# the single nor the double quote character is part of it.
_SYMBOLS_NO_QUOTES = "!@#$%^&*()-_=+\\|[]{}:;/?.,<>"

# Alphabet used by _validate_fingerprint; note it only has uppercase digits.
_HEX = "0123456789ABCDEF"

# Values (lowercase) accepted by _validate_logging for the level options.
_LOG_LEVELS = ["debug", "info", "warning", "error", "critical"]

# Module-level logger, named after this module.
log = logging.getLogger(__name__)


def _expand_path(path):
    """Return ``path`` with environment variables and ``~`` expanded.

    Environment variables are substituted first, the ``~`` home directory
    shorthand afterwards. Inside a configuration file the ``$`` of an
    environment variable has to be escaped with a second ``$``, for
    example: ``$$XDG_RUNTIME_DIR/foo.bar``
    """
    with_env_vars = os.path.expandvars(path)
    return os.path.expanduser(with_env_vars)


def _extend_config(conf, fname):
    """Read the configuration file ``fname`` into ``conf`` and return it.

    ``conf`` is modified in place; the returned object is the same parser
    that was passed in. The file is opened in text mode and ``fname`` is
    handed to the parser as the source name.
    """
    log.debug("Reading config file %s", fname)
    with open(fname, mode="rt") as config_file:
        conf.read_file(config_file, source=fname)
    return conf


def _get_default_config():
    """Build a new parser and load the default configuration file into it.

    The parser uses extended interpolation and registers a ``path``
    converter (``_expand_path``), which makes ``getpath()`` available.
    """
    parser_options = {
        "interpolation": ExtendedInterpolation(),
        "converters": {"path": _expand_path},
    }
    default_conf = ConfigParser(**parser_options)
    return _extend_config(default_conf, DEFAULT_CONFIG_PATH)


def _obtain_user_conf_path():
    """Return the path where the user configuration file is looked for.

    The supervised path is chosen only when the environment variable
    ``SUPERVISED`` is exactly ``"1"``; otherwise the regular user path.
    """
    supervised = os.environ.get("SUPERVISED") == "1"
    return SUPERVISED_USER_CONFIG_PATH if supervised else USER_CONFIG_PATH


def _get_user_config(args, conf=None):
    """Extend ``conf`` with the user configuration file, if there is one.

    The file given in ``args.config`` takes precedence; when that option is
    not set, the default user path (see ``_obtain_user_conf_path``) is tried.
    If no file is found the configuration is returned as it was.

    :param args: parsed command line arguments; only ``config`` is read.
    :param conf: parser to extend. A new one is created when it is falsy.
    :returns: the parser, extended with the file that was found, if any.
    """
    if conf:
        assert isinstance(conf, ConfigParser)
    else:
        conf = ConfigParser(
            interpolation=ExtendedInterpolation(),
            converters={"path": _expand_path},
        )

    # XXX: The logger is not configured at this stage,
    # sbws should start with a logger before reading configurations.
    # That is why the messages below are printed instead of logged.
    requested_path = args.config
    if requested_path:
        if os.path.isfile(requested_path):
            print(
                "Using configuration provided as argument %s" % requested_path
            )
            return _extend_config(conf, requested_path)
        print(
            "Configuration file %s not found, using defaults."
            % requested_path
        )
        return conf

    fallback_path = _obtain_user_conf_path()
    if not os.path.isfile(fallback_path):
        log.debug("No user config found, using defaults.")
        return conf
    print("Using configuration file %s" % fallback_path)
    return _extend_config(conf, fallback_path)


def _get_default_logging_config(conf=None):
    """Load the default logging configuration file into ``conf``.

    :param conf: parser to extend. When it is falsy, a new parser with
        extended interpolation and the ``path`` converter is created.
    :returns: the parser extended with the default logging configuration.
    """
    if conf:
        assert isinstance(conf, ConfigParser)
        target = conf
    else:
        target = ConfigParser(
            interpolation=ExtendedInterpolation(),
            converters={"path": _expand_path},
        )
    return _extend_config(target, DEFAULT_LOG_CONFIG_PATH)


def get_config(args):
    """Get ConfigParser interpolating all configuration files.

    The same parser is filled in three steps: the default configuration,
    the default logging configuration and finally the user configuration
    (see ``_get_user_config`` for how the user file is located).

    :param args: parsed command line arguments, passed on to
        ``_get_user_config``.
    :returns: the resulting ``ConfigParser``.
    """
    conf = _get_default_config()
    conf = _get_default_logging_config(conf=conf)
    conf = _get_user_config(args, conf=conf)
    return conf
def _can_log_to_file(conf):
    """
    Checks all the known reasons for why we might not be able to log to a
    file, and returns whether or not we think we will be able to do so. This
    is useful because if we can't log to a file, we might want to force
    logging to stdout.

    If we can't log to file, return False and the reason. Otherwise return
    True and an empty string.

    :param conf: parser providing ``getpath()`` and the ``paths`` section.
    :returns: ``(False, exception)`` when ``paths/log_dname`` can not be
        interpolated, ``(True, "")`` otherwise.
    """
    # We won't be able to get paths.log_dname from the config when we are
    # first initializing sbws because it depends on paths.sbws_home (by
    # default). If there is an issue getting this option, tell the caller
    # that we can't log to file.
    try:
        conf.getpath("paths", "log_dname")
    except InterpolationMissingOptionError as e:
        return False, e
    return True, ""
def configure_logging(args, conf):
    """Configure Python's logging system from the sbws configuration.

    The handlers to enable are chosen from the ``[logging]`` options, the
    handler, formatter and logger sections are filled in accordingly, and
    the resulting configuration is written to a temporary file that is
    loaded with ``logging.config.fileConfig``.

    ``conf`` is modified in place. When logging to file is enabled, the log
    directory (``paths/log_dname``) is created if it does not exist.

    :param args: parsed command line arguments; ``command`` and
        ``log_level`` are read.
    :param conf: the ``ConfigParser`` holding the sbws configuration.
    """
    assert isinstance(conf, ConfigParser)
    logger = "logger_sbws"
    # Set the correct handler(s) based on [logging] options
    handlers = set()
    can_log_to_file, _ = _can_log_to_file(conf)
    if not can_log_to_file or conf.getboolean("logging", "to_stdout"):
        # always add to_stdout if we cannot log to file
        handlers.add("to_stdout")
    if can_log_to_file and conf.getboolean("logging", "to_file"):
        handlers.add("to_file")
    if conf.getboolean("logging", "to_syslog"):
        handlers.add("to_syslog")
    # Collect the handlers in the appropriate config option
    conf[logger]["handlers"] = ",".join(handlers)
    if "to_file" in handlers:
        # This is weird.
        #
        # Python's logging library expects 'args' to be a tuple ... but it
        # has to be stored as a string and it evals() the string.
        #
        # The first argument is the file name to which it should log. Set it
        # to the sbws command (like 'scanner' or 'generate') if possible, or
        # to 'sbws' failing that.
        dname = conf.getpath("paths", "log_dname")
        os.makedirs(dname, exist_ok=True)
        fname = os.path.join(dname, "{}.log".format(args.command or "sbws"))
        # The second argument is the file mode, and it should be left alone
        mode = "a"
        # The third is the maximum file size (in bytes) each log file should
        # be
        max_bytes = conf.getint("logging", "to_file_max_bytes")
        # And the fourth is the number of backups to keep
        num_backups = conf.getint("logging", "to_file_num_backups")
        # Now store those things as a string in the config. So dumb.
        conf["handler_to_file"]["args"] = str(
            (fname, mode, max_bytes, num_backups)
        )
    # Set some stuff that needs config parser's interpolation
    conf["formatter_to_file"]["format"] = conf["logging"]["to_file_format"]
    conf["formatter_to_stdout"]["format"] = conf["logging"]["to_stdout_format"]
    conf["formatter_to_syslog"]["format"] = conf["logging"]["to_syslog_format"]
    conf[logger]["level"] = conf["logging"]["level"].upper()
    conf["handler_to_file"]["level"] = conf["logging"]["to_file_level"].upper()
    conf["handler_to_stdout"]["level"] = conf["logging"][
        "to_stdout_level"
    ].upper()
    conf["handler_to_syslog"]["level"] = conf["logging"][
        "to_syslog_level"
    ].upper()
    # If there's a log_level cli argument, the user would expect that level
    # in the standard output.
    # conf['logging']['level'] sets the lower level, but it's still needed to
    # set the stdout level.
    # It also must be set up in the end, since cli arguments have higher
    # priority.
    if args.log_level:
        conf["logging"]["level"] = args.log_level.upper()
        conf["handler_to_stdout"]["level"] = conf["logging"]["level"]
    # Now we configure the standard python logging system
    with NamedTemporaryFile("w+t") as fd:
        conf.write(fd)
        # Seeking also flushes what was written, so that the file can be
        # read back by name below.
        fd.seek(0, 0)
        logging.config.fileConfig(fd.name)
def validate_config(conf):
    """Checks the given conf for bad values or bad combinations of values.

    If there's something wrong, returns False and a list of error messages.
    Otherwise, return True and an empty list.

    Note that some of the section validators also adjust ``conf`` (see
    ``_validate_tor`` and ``_validate_paths``).
    """
    errors = []
    errors.extend(_validate_general(conf))
    errors.extend(_validate_cleanup(conf))
    errors.extend(_validate_scanner(conf))
    errors.extend(_validate_tor(conf))
    errors.extend(_validate_paths(conf))
    errors.extend(_validate_destinations(conf))
    errors.extend(_validate_relayprioritizer(conf))
    errors.extend(_validate_logging(conf))
    return not errors, errors
def _validate_cleanup(conf):
    """Return the list of error messages for the ``cleanup`` section."""
    errors = []
    sec = "cleanup"
    err_tmpl = Template("$sec/$key ($val): $e")
    ints = {
        "data_files_compress_after_days": {"minimum": 1, "maximum": None},
        "data_files_delete_after_days": {"minimum": 1, "maximum": None},
        "v3bw_files_compress_after_days": {"minimum": 1, "maximum": None},
        "v3bw_files_delete_after_days": {"minimum": 1, "maximum": None},
    }
    all_valid_keys = list(ints.keys())
    errors.extend(_validate_section_keys(conf, sec, all_valid_keys, err_tmpl))
    errors.extend(_validate_section_ints(conf, sec, ints, err_tmpl))
    return errors


def _validate_general(conf):
    """Return the list of error messages for the ``general`` section."""
    errors = []
    sec = "general"
    err_tmpl = Template("$sec/$key ($val): $e")
    ints = {
        "data_period": {"minimum": 1, "maximum": None},
        "circuit_timeout": {"minimum": 1, "maximum": None},
    }
    floats = {
        "http_timeout": {"minimum": 0.0, "maximum": None},
    }
    bools = {
        "reset_bw_ipv4_changes": {},
        "reset_bw_ipv6_changes": {},
    }
    all_valid_keys = (
        list(ints.keys()) + list(floats.keys()) + list(bools.keys())
    )
    errors.extend(_validate_section_keys(conf, sec, all_valid_keys, err_tmpl))
    errors.extend(_validate_section_ints(conf, sec, ints, err_tmpl))
    errors.extend(_validate_section_floats(conf, sec, floats, err_tmpl))
    errors.extend(_validate_section_bools(conf, sec, bools, err_tmpl))
    return errors


def _obtain_sbws_home(conf):
    """Drop the ``.sbws`` component from the system-wide default home.

    When ``paths/sbws_home`` is ``/var/lib/sbws/.sbws`` it is replaced, in
    ``conf``, by its parent directory.
    """
    sbws_home = conf.getpath("paths", "sbws_home")
    # No need for .sbws when this is the default home
    if sbws_home == "/var/lib/sbws/.sbws":
        conf["paths"]["sbws_home"] = os.path.dirname(sbws_home)


def _obtain_run_dpath(conf):
    """Set runtime directory when sbws is run by a system service.

    With ``SUPERVISED=1`` in the environment ``tor/run_dpath`` is set to the
    supervised run directory; otherwise, when ``XDG_RUNTIME_DIR`` is set, to
    ``sbws/tor`` inside it. In any other case ``conf`` is left untouched.
    """
    xdg = os.environ.get("XDG_RUNTIME_DIR")
    if os.environ.get("SUPERVISED") == "1":
        conf["tor"]["run_dpath"] = SUPERVISED_RUN_DPATH
    elif xdg is not None:
        conf["tor"]["run_dpath"] = os.path.join(xdg, "sbws", "tor")


def _validate_paths(conf):
    """Return the list of error messages for the ``paths`` section.

    Only the presence of the keys is checked, not their values. As a side
    effect ``paths/sbws_home`` may be adjusted (see ``_obtain_sbws_home``).
    """
    _obtain_sbws_home(conf)
    errors = []
    sec = "paths"
    err_tmpl = Template("$sec/$key ($val): $e")
    unvalidated_keys = [
        "datadir",
        "sbws_home",
        "v3bw_fname",
        "v3bw_dname",
        "state_fname",
        "log_dname",
    ]
    all_valid_keys = unvalidated_keys
    allow_missing = ["sbws_home"]
    errors.extend(
        _validate_section_keys(
            conf, sec, all_valid_keys, err_tmpl, allow_missing=allow_missing
        )
    )
    return errors


def _validate_country(conf, sec, key, err_tmpl):
    """Check that ``conf[sec][key]`` is an ISO 3166 alpha-2 country code.

    :returns: a list with one error message when the key is missing or its
        value is not a known code, an empty list otherwise.
    """
    errors = []
    if conf[sec].get(key, None) is None:
        errors.append(
            err_tmpl.substitute(
                sec=sec,
                key=key,
                val=None,
                e="Missing country in configuration file.",
            )
        )
        return errors
    # Look up the key that was asked for, not a hard-coded "country".
    valid = conf[sec][key] in ISO_3166_ALPHA_2
    if not valid:
        errors.append(
            err_tmpl.substitute(
                sec=sec,
                key=key,
                val=conf[sec][key],
                e="Not a valid ISO 3166 alpha-2 country code.",
            )
        )
    return errors


def _validate_scanner(conf):
    """Return the list of error messages for the ``scanner`` section."""
    errors = []
    sec = "scanner"
    err_tmpl = Template("$sec/$key ($val): $e")
    ints = {
        "num_rtts": {"minimum": 0, "maximum": 100},
        "num_downloads": {"minimum": 1, "maximum": 100},
        "initial_read_request": {"minimum": 1, "maximum": None},
        "measurement_threads": {"minimum": 1, "maximum": None},
        "min_download_size": {"minimum": 1, "maximum": None},
        "max_download_size": {"minimum": 1, "maximum": None},
    }
    floats = {
        "download_toofast": {"minimum": 0.001, "maximum": None},
        "download_min": {"minimum": 0.001, "maximum": None},
        "download_target": {"minimum": 0.001, "maximum": None},
        "download_max": {"minimum": 0.001, "maximum": None},
    }
    all_valid_keys = (
        list(ints.keys()) + list(floats.keys()) + ["nickname", "country"]
    )
    errors.extend(_validate_section_keys(conf, sec, all_valid_keys, err_tmpl))
    errors.extend(_validate_section_ints(conf, sec, ints, err_tmpl))
    errors.extend(_validate_section_floats(conf, sec, floats, err_tmpl))
    valid, error_msg = _validate_nickname(conf[sec], "nickname")
    if not valid:
        errors.append(
            err_tmpl.substitute(
                sec=sec, key="nickname", val=conf[sec]["nickname"], e=error_msg
            )
        )
    errors.extend(_validate_country(conf, sec, "country", err_tmpl))
    return errors


def _validate_tor(conf):
    """Return the list of error messages for the ``tor`` section.

    Only the presence of the keys is checked, not their values. As a side
    effect ``tor/run_dpath`` may be set (see ``_obtain_run_dpath``).
    """
    _obtain_run_dpath(conf)
    errors = []
    sec = "tor"
    err_tmpl = Template("$sec/$key ($val): $e")
    unvalidated_keys = [
        "datadir",
        "run_dpath",
        "control_socket",
        "pid",
        "log",
        "external_control_port",
        "extra_lines",
    ]
    all_valid_keys = unvalidated_keys
    errors.extend(_validate_section_keys(conf, sec, all_valid_keys, err_tmpl))
    return errors


def _validate_relayprioritizer(conf):
    """Return the list of error messages for ``relayprioritizer``."""
    errors = []
    sec = "relayprioritizer"
    err_tmpl = Template("$sec/$key ($val): $e")
    ints = {
        "min_relays": {"minimum": 1, "maximum": None},
    }
    floats = {
        "fraction_relays": {"minimum": 0.0, "maximum": 1.0},
    }
    bools = {
        "measure_authorities": {},
    }
    all_valid_keys = (
        list(ints.keys()) + list(floats.keys()) + list(bools.keys())
    )
    errors.extend(_validate_section_keys(conf, sec, all_valid_keys, err_tmpl))
    errors.extend(_validate_section_ints(conf, sec, ints, err_tmpl))
    errors.extend(_validate_section_floats(conf, sec, floats, err_tmpl))
    errors.extend(_validate_section_bools(conf, sec, bools, err_tmpl))
    return errors


def _validate_logging(conf):
    """Return the list of error messages for the ``logging`` section.

    The format options are only checked for presence.
    """
    errors = []
    sec = "logging"
    err_tmpl = Template("$sec/$key ($val): $e")
    enums = {
        "level": {"choices": _LOG_LEVELS},
        "to_file_level": {"choices": _LOG_LEVELS},
        "to_stdout_level": {"choices": _LOG_LEVELS},
        "to_syslog_level": {"choices": _LOG_LEVELS},
    }
    bools = {
        "to_file": {},
        "to_stdout": {},
        "to_syslog": {},
    }
    ints = {
        "to_file_max_bytes": {"minimum": 0, "maximum": None},
        "to_file_num_backups": {"minimum": 0, "maximum": None},
    }
    unvalidated = [
        "format",
        "to_file_format",
        "to_stdout_format",
        "to_syslog_format",
    ]
    all_valid_keys = (
        list(bools.keys())
        + list(enums.keys())
        + list(ints.keys())
        + unvalidated
    )
    errors.extend(_validate_section_keys(conf, sec, all_valid_keys, err_tmpl))
    errors.extend(_validate_section_bools(conf, sec, bools, err_tmpl))
    errors.extend(_validate_section_enums(conf, sec, enums, err_tmpl))
    # The integer options were declared above but never checked before.
    errors.extend(_validate_section_ints(conf, sec, ints, err_tmpl))
    return errors


def _validate_destinations(conf):
    """Return the list of error messages for the destinations.

    Every key of the ``destinations`` section, except
    ``usability_test_interval``, must be a boolean. For each key set to a
    true value, the section ``destinations.<key>`` must exist and is
    validated too.
    """
    errors = []
    sec = "destinations"
    section = conf[sec]
    err_tmpl = Template("$sec/$key ($val): $e")
    dest_sections = []
    for key in section.keys():
        value = section[key]
        if key == "usability_test_interval":
            valid, error_msg = _validate_int(section, key, minimum=1)
            if not valid:
                errors.append(
                    err_tmpl.substitute(
                        sec=sec, key=key, val=value, e=error_msg
                    )
                )
            continue
        valid, error_msg = _validate_boolean(section, key)
        if not valid:
            errors.append(
                err_tmpl.substitute(sec=sec, key=key, val=value, e=error_msg)
            )
            continue
        if section.getboolean(key):
            dest_sections.append("{}.{}".format(sec, key))
    urls = {
        "url": {},
    }
    all_valid_keys = list(urls.keys()) + [
        "verify",
        "country",
        "max_num_failures",
    ]
    for dest_sec in dest_sections:
        if dest_sec not in conf:
            errors.append(
                "{} is an enabled destination but is not a "
                "section in the config".format(dest_sec)
            )
            continue
        errors.extend(
            _validate_section_keys(
                conf,
                dest_sec,
                all_valid_keys,
                err_tmpl,
                allow_missing=["verify", "max_num_failures"],
            )
        )
        errors.extend(_validate_section_urls(conf, dest_sec, urls, err_tmpl))
        errors.extend(_validate_country(conf, dest_sec, "country", err_tmpl))
    return errors


def _validate_section_keys(conf, sec, keys, tmpl, allow_missing=None):
    """Report unknown and missing keys of the section ``sec``.

    :param keys: all the keys that are known for the section.
    :param tmpl: ``Template`` with ``$sec``, ``$key``, ``$val`` and ``$e``.
    :param allow_missing: known keys that are allowed to be absent.
    :returns: list of error messages, empty when there is nothing to report.
    """
    if allow_missing is None:
        allow_missing = []
    errors = []
    section = conf[sec]
    # Find keys that exist in the user's config that are not known
    for key in section:
        if key not in keys:
            errors.append(
                tmpl.substitute(
                    sec=sec, key=key, val=section[key], e="Unknown key"
                )
            )
    # Find keys that don't exist in the user's config that should
    for key in keys:
        if key not in section and key not in allow_missing:
            errors.append(
                tmpl.substitute(
                    sec=sec, key=key, val="[NOT SET]", e="Missing key"
                )
            )
    return errors


def _validate_section_ints(conf, sec, ints, tmpl):
    """Validate the integer options ``ints`` of the section ``sec``.

    ``ints`` maps each key to a dict with ``minimum`` and ``maximum``.
    """
    errors = []
    section = conf[sec]
    for key in ints:
        valid, error = _validate_int(
            section,
            key,
            minimum=ints[key]["minimum"],
            maximum=ints[key]["maximum"],
        )
        if not valid:
            errors.append(
                tmpl.substitute(sec=sec, key=key, val=section[key], e=error)
            )
    return errors


def _validate_section_floats(conf, sec, floats, tmpl):
    """Validate the float options ``floats`` of the section ``sec``.

    ``floats`` maps each key to a dict with ``minimum`` and ``maximum``.
    """
    errors = []
    section = conf[sec]
    for key in floats:
        valid, error = _validate_float(
            section,
            key,
            minimum=floats[key]["minimum"],
            maximum=floats[key]["maximum"],
        )
        if not valid:
            errors.append(
                tmpl.substitute(sec=sec, key=key, val=section[key], e=error)
            )
    return errors


def _validate_section_hosts(conf, sec, hosts, tmpl):
    """Validate the host options ``hosts`` of the section ``sec``."""
    errors = []
    section = conf[sec]
    for key in hosts:
        valid, error = _validate_host(section, key)
        if not valid:
            errors.append(
                tmpl.substitute(sec=sec, key=key, val=section[key], e=error)
            )
    return errors


def _validate_section_ports(conf, sec, ports, tmpl):
    """Validate the port options ``ports`` of the section ``sec``.

    A port must be an integer between 1 and 65535, both included.
    """
    errors = []
    section = conf[sec]
    for key in ports:
        # The highest port number is 2 ** 16 - 1; 2 ** 16 itself is not valid.
        valid, error = _validate_int(
            section, key, minimum=1, maximum=2 ** 16 - 1
        )
        if not valid:
            errors.append(
                tmpl.substitute(
                    sec=sec,
                    key=key,
                    val=section[key],
                    e="Not a valid port ({})".format(error),
                )
            )
    return errors


def _validate_section_bools(conf, sec, bools, tmpl):
    """Validate the boolean options ``bools`` of the section ``sec``."""
    errors = []
    section = conf[sec]
    for key in bools:
        valid, error = _validate_boolean(section, key)
        if not valid:
            errors.append(
                tmpl.substitute(
                    sec=sec,
                    key=key,
                    val=section[key],
                    e="Not a valid boolean string ({})".format(error),
                )
            )
    return errors


def _validate_section_fingerprints(conf, sec, fps, tmpl):
    """Validate the fingerprint options ``fps`` of the section ``sec``."""
    errors = []
    section = conf[sec]
    for key in fps:
        valid, error = _validate_fingerprint(section, key)
        if not valid:
            errors.append(
                tmpl.substitute(
                    sec=sec,
                    key=key,
                    val=section[key],
                    e="Not a valid fingerprint ({})".format(error),
                )
            )
    return errors


def _validate_section_urls(conf, sec, urls, tmpl):
    """Validate the URL options ``urls`` of the section ``sec``."""
    errors = []
    section = conf[sec]
    for key in urls:
        valid, error = _validate_url(section, key)
        if not valid:
            errors.append(
                tmpl.substitute(
                    sec=sec,
                    key=key,
                    val=section[key],
                    e="Not a valid url ({})".format(error),
                )
            )
    return errors


def _validate_section_enums(conf, sec, enums, tmpl):
    """Validate the enumerated options ``enums`` of the section ``sec``.

    ``enums`` maps each key to a dict with the allowed ``choices``.
    """
    errors = []
    section = conf[sec]
    for key in enums:
        choices = enums[key]["choices"]
        valid, _ = _validate_enum(section, key, choices)
        if not valid:
            errors.append(
                tmpl.substitute(
                    sec=sec,
                    key=key,
                    val=section[key],
                    e="Not a valid enum choice ({})".format(
                        ", ".join(choices)
                    ),
                )
            )
    return errors


def _validate_enum(section, key, choices):
    """Return ``(True, "")`` when ``section[key]`` is one of ``choices``.

    Otherwise return ``False`` and an error message.
    """
    value = section[key]
    if value not in choices:
        return False, "{} not in allowed choices: {}".format(
            value, ", ".join(choices)
        )
    return True, ""


def _validate_url(section, key):
    """Check that ``section[key]`` is an HTTPS URL with a hostname.

    Plain HTTP is only accepted when the host is exactly ``127.0.0.1``
    (any port), which is what the test server uses.

    :returns: ``(True, "")`` when valid, else ``False`` and the reason.
    """
    value = section[key]
    url = urlparse(value)
    if not url.netloc:
        return False, "Does not appear to contain a hostname"
    # It should be possible to have an URL that starts by http:// that uses
    # TLS, but python requests is just checking the scheme starts by https
    # when verifying certificate:
    # https://github.com/requests/requests/blob/master/requests/adapters.py#L215  # noqa
    # When the scheme is https but the protocol is not TLS, requests will
    # hang.
    # Compare the parsed hostname instead of the netloc prefix, so that hosts
    # like 127.0.0.1.example.com are not taken for the test server.
    is_test_server = url.hostname == "127.0.0.1"
    if url.scheme != "https" and not is_test_server:
        return False, "URL scheme must be HTTPS (except for the test server)"
    return True, ""


def _validate_int(section, key, minimum=None, maximum=None):
    """Check that ``section[key]`` is an integer within the given bounds.

    :param minimum: lowest accepted value, or ``None`` for no lower bound.
    :param maximum: highest accepted value, or ``None`` for no upper bound.
    :returns: ``(True, "")`` when valid, else ``False`` and the reason (the
        ``ValueError`` itself when the value is not an integer).
    """
    try:
        value = section.getint(key)
    except ValueError as e:
        return False, e
    if minimum is not None:
        assert isinstance(minimum, int)
        if value < minimum:
            return False, "Cannot be less than {}".format(minimum)
    if maximum is not None:
        assert isinstance(maximum, int)
        if value > maximum:
            return False, "Cannot be greater than {}".format(maximum)
    return True, ""


def _validate_boolean(section, key):
    """Check that ``section[key]`` can be read as a boolean.

    :returns: ``(True, "")`` when valid, else ``False`` and the
        ``ValueError`` that was raised.
    """
    try:
        section.getboolean(key)
    except ValueError as e:
        return False, e
    return True, ""


def _validate_float(section, key, minimum=None, maximum=None):
    """Check that ``section[key]`` is a float within the given bounds.

    :param minimum: lowest accepted value, or ``None`` for no lower bound.
    :param maximum: highest accepted value, or ``None`` for no upper bound.
    :returns: ``(True, "")`` when valid, else ``False`` and the reason (the
        ``ValueError`` itself when the value is not a float).
    """
    try:
        value = section.getfloat(key)
    except ValueError as e:
        return False, e
    if minimum is not None:
        assert isinstance(minimum, float)
        if value < minimum:
            return False, "Cannot be less than {}".format(minimum)
    if maximum is not None:
        assert isinstance(maximum, float)
        if value > maximum:
            return False, "Cannot be greater than {}".format(maximum)
    return True, ""


def _validate_host(section, key):
    """Placeholder: every host is currently reported as valid."""
    # XXX: Implement this
    return True, ""


def _validate_fingerprint(section, key):
    """Check that ``section[key]`` is 40 uppercase hexadecimal digits."""
    alphabet = _HEX
    length = 40
    return _validate_string(
        section, key, min_len=length, max_len=length, alphabet=alphabet
    )


def _validate_nickname(section, key):
    """Check that ``section[key]`` is a valid nickname.

    It must have between 1 and 32 characters, all of them alphanumeric or
    one of the symbols in ``_SYMBOLS_NO_QUOTES``.
    """
    alphabet = _ALPHANUM + _SYMBOLS_NO_QUOTES
    min_len = 1
    max_len = 32
    return _validate_string(
        section, key, min_len=min_len, max_len=max_len, alphabet=alphabet
    )


def _validate_string(
    section, key, min_len=None, max_len=None, alphabet=None, starts_with=None
):
    """Check the string ``section[key]`` against the given constraints.

    Every constraint is optional and skipped when it is ``None``. They are
    checked in this order and the first failure is reported.

    :param min_len: minimum allowed length.
    :param max_len: maximum allowed length.
    :param alphabet: container of the allowed characters.
    :param starts_with: prefix the string has to start with.
    :returns: ``(True, "")`` when valid, else ``False`` and the reason.
    """
    s = section[key]
    if min_len is not None and len(s) < min_len:
        return False, "{} is below minimum allowed length {}".format(
            len(s), min_len
        )
    if max_len is not None and len(s) > max_len:
        return False, "{} is above maximum allowed length {}".format(
            len(s), max_len
        )
    if alphabet is not None:
        for i, c in enumerate(s):
            if c not in alphabet:
                return (
                    False,
                    "Letter {} at position {} is not in allowed "
                    'characters "{}"'.format(c, i, alphabet),
                )
    if starts_with is not None:
        if not s.startswith(starts_with):
            return False, "{} does not start with {}".format(s, starts_with)
    return True, ""