diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index f87564a..c8874d6 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -30,6 +30,7 @@ jobs: python -m pip install --upgrade pip pip install \ absl-py \ + black \ freezegun \ mypy \ passlib \ @@ -38,6 +39,9 @@ jobs: python-dateutil \ types-passlib \ types-python-dateutil + - name: Check formatting + run: | + black --check --diff . - name: Run mypy run: | mypy diff --git a/conftest.py b/conftest.py index 1fa928a..2093625 100644 --- a/conftest.py +++ b/conftest.py @@ -16,8 +16,8 @@ import pytest -@pytest.fixture(scope='session', autouse=True) +@pytest.fixture(scope="session", autouse=True) def _parse_absl_flags(): # absltest doesn't work well without absl flags being parsed, which pytest # doesn't do. This fixture works around that. - flags.FLAGS(('pytest',)) + flags.FLAGS(("pytest",)) diff --git a/salt/file/accounts/generate_lemonldap_ng_ini.py b/salt/file/accounts/generate_lemonldap_ng_ini.py index 3e02e27..163744a 100644 --- a/salt/file/accounts/generate_lemonldap_ng_ini.py +++ b/salt/file/accounts/generate_lemonldap_ng_ini.py @@ -25,18 +25,19 @@ def _args(): parser = argparse.ArgumentParser( - description='Generate dynamic parts of lemonldap-ng.ini.') + description="Generate dynamic parts of lemonldap-ng.ini." + ) parser.add_argument( - '--input', + "--input", type=pathlib.Path, required=True, - help='Path to the static lemonldap-ng.ini to read.', + help="Path to the static lemonldap-ng.ini to read.", ) parser.add_argument( - '--output', + "--output", type=pathlib.Path, required=True, - help='Path to the dynamic lemonldap-ng.ini to replace.', + help="Path to the dynamic lemonldap-ng.ini to replace.", ) return parser.parse_args() @@ -52,12 +53,12 @@ def _portal_lines() -> Sequence[str]: # algorithms in the static part of the config file to match. private_key = subprocess.run( ( - 'openssl', - 'genpkey', - '-algorithm', - 'RSA', - '-pkeyopt', - 'rsa_keygen_bits:3072', + "openssl", + "genpkey", + "-algorithm", + "RSA", + "-pkeyopt", + "rsa_keygen_bits:3072", ), stdout=subprocess.PIPE, # See https://github.com/openssl/openssl/issues/13177 @@ -66,7 +67,7 @@ def _portal_lines() -> Sequence[str]: text=True, ).stdout public_key = subprocess.run( - ('openssl', 'pkey', '-pubout'), + ("openssl", "pkey", "-pubout"), input=private_key, stdout=subprocess.PIPE, check=True, @@ -74,26 +75,26 @@ def _portal_lines() -> Sequence[str]: ).stdout key_id = str(uuid.uuid4()) return ( - 'oidcServicePrivateKeySig = < None: args = _args() - with args.input.open('rt') as input_file: + with args.input.open("rt") as input_file: config = list(input_file) - portal_index = config.index('[portal]\n') - config[portal_index + 1:portal_index + 1] = _portal_lines() + portal_index = config.index("[portal]\n") + config[portal_index + 1 : portal_index + 1] = _portal_lines() with tempfile.NamedTemporaryFile( - mode='wt', - dir=args.output.parent, - delete=False, + mode="wt", + dir=args.output.parent, + delete=False, ) as output_tempfile: output_tempfile.writelines(config) input_stat = args.input.stat() @@ -102,5 +103,5 @@ def main() -> None: os.replace(output_tempfile.name, args.output) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/salt/file/backup/repo/borg_require_recent_archive.py b/salt/file/backup/repo/borg_require_recent_archive.py index 88518bb..00075c4 100644 --- a/salt/file/backup/repo/borg_require_recent_archive.py +++ b/salt/file/backup/repo/borg_require_recent_archive.py @@ -24,25 +24,26 @@ def _args(): parser = argparse.ArgumentParser( - description='Print a message if the latest archive is too old.') + description="Print a message if the latest archive is too old." + ) parser.add_argument( - '--repository', + "--repository", type=str, required=True, - help='Repository to check.', + help="Repository to check.", ) parser.add_argument( - '--max-age', + "--max-age", default=datetime.timedelta(days=2, hours=12), type=lambda arg: datetime.timedelta(seconds=float(arg)), - help='How old to warn about, in seconds.', + help="How old to warn about, in seconds.", ) parser.add_argument( - 'borg_option', - nargs='*', + "borg_option", + nargs="*", default=[], type=str, - help='Borg common options.', + help="Borg common options.", ) return parser.parse_args() @@ -51,11 +52,11 @@ def main() -> None: args = _args() repository_list_raw = subprocess.run( ( - 'borg', + "borg", *args.borg_option, - 'list', - '--json', - '--last=5', + "list", + "--json", + "--last=5", args.repository, ), stdout=subprocess.PIPE, @@ -63,17 +64,18 @@ def main() -> None: ).stdout now = datetime.datetime.now(tz=datetime.timezone.utc) repository_list = json.loads(repository_list_raw) - archives = repository_list['archives'] + archives = repository_list["archives"] if not archives: - print('No archives.') + print("No archives.") return last_archive_time = datetime.datetime.fromisoformat( - archives[-1]['start']).astimezone(datetime.timezone.utc) + archives[-1]["start"] + ).astimezone(datetime.timezone.utc) if last_archive_time < now - args.max_age: - print(f'Latest archive is older than {args.max_age}. Recent archives:') + print(f"Latest archive is older than {args.max_age}. Recent archives:") pprint.pprint(archives) return -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/salt/file/crypto/x509/boilerplate_certificate.py b/salt/file/crypto/x509/boilerplate_certificate.py index 0dece0c..e1eb2cc 100644 --- a/salt/file/crypto/x509/boilerplate_certificate.py +++ b/salt/file/crypto/x509/boilerplate_certificate.py @@ -62,38 +62,39 @@ def _args(): parser = argparse.ArgumentParser( - description='Create a key pair and EE certificate for TLS.') + description="Create a key pair and EE certificate for TLS." + ) parser.add_argument( - '--name', + "--name", required=True, - help='DNS name for the EE certificate.', + help="DNS name for the EE certificate.", ) parser.add_argument( - '--key', + "--key", type=pathlib.Path, required=True, - help='Path to write the EE certificate\'s private key to.', + help="Path to write the EE certificate's private key to.", ) parser.add_argument( - '--cert', + "--cert", type=pathlib.Path, required=True, - help='Path to write the EE certificate to.', + help="Path to write the EE certificate to.", ) parser.add_argument( - '--key-algorithm', + "--key-algorithm", required=True, - help='See the -algorithm argument to openssl genpkey.', + help="See the -algorithm argument to openssl genpkey.", ) parser.add_argument( - '--key-option', - nargs='*', - help='See the -pkeyopt argument to openssl genpkey.', + "--key-option", + nargs="*", + help="See the -pkeyopt argument to openssl genpkey.", ) parser.add_argument( - '--days', + "--days", required=True, - help='See the -days argument to openssl req and x509.', + help="See the -days argument to openssl req and x509.", ) return parser.parse_args() @@ -101,25 +102,25 @@ def _args(): def _signature_args(args) -> Sequence[str]: # See https://cabforum.org/wp-content/uploads/CA-Browser-Forum-BR-1.8.0.pdf # section 7.1.3.2 for restrictions on the digest based on the key type. - if args.key_algorithm == 'EC': - if args.key_option == ['ec_paramgen_curve:P-384']: - return ('-sha384',) - raise NotImplementedError(f'{args.key_algorithm=}, {args.key_option=}') + if args.key_algorithm == "EC": + if args.key_option == ["ec_paramgen_curve:P-384"]: + return ("-sha384",) + raise NotImplementedError(f"{args.key_algorithm=}, {args.key_option=}") def main() -> None: args = _args() - genpkey_args = ['-algorithm', args.key_algorithm] + genpkey_args = ["-algorithm", args.key_algorithm] for key_option in args.key_option: - genpkey_args.extend(('-pkeyopt', key_option)) + genpkey_args.extend(("-pkeyopt", key_option)) signature_args = _signature_args(args) with tempfile.TemporaryDirectory() as tempdir_name: tempdir = pathlib.Path(tempdir_name) - with tempdir.joinpath('openssl.cnf').open(mode='wt') as openssl_cnf: + with tempdir.joinpath("openssl.cnf").open(mode="wt") as openssl_cnf: # Note that subjectKeyIdentifier comes before authorityKeyIdentifier # here despite the order being different in # https://datatracker.ietf.org/doc/html/rfc5280#section-4.2 because @@ -127,7 +128,8 @@ def main() -> None: # # X509 V3 routines:v2i_AUTHORITY_KEYID:unable to get issuer keyid:../crypto/x509v3/v3_akey.c:143 openssl_cnf.write( - textwrap.dedent(f""" + textwrap.dedent( + f""" [req] string_mask = utf8only prompt = no @@ -149,79 +151,81 @@ def main() -> None: subjectAltName = DNS:{args.name} basicConstraints = critical, CA:FALSE extendedKeyUsage = serverAuth, clientAuth - """)) + """ + ) + ) ca_private_key = subprocess.run( - ('openssl', 'genpkey', *genpkey_args), + ("openssl", "genpkey", *genpkey_args), stdout=subprocess.PIPE, check=True, ).stdout subprocess.run( ( - 'openssl', - 'req', - '-x509', - '-batch', - '-key', - '/dev/stdin', - '-config', - str(tempdir.joinpath('openssl.cnf')), - '-extensions', - 'x509_ca_extensions', - '-subj', - f'/CN=Boilerplate CA for {args.name}', - '-days', + "openssl", + "req", + "-x509", + "-batch", + "-key", + "/dev/stdin", + "-config", + str(tempdir.joinpath("openssl.cnf")), + "-extensions", + "x509_ca_extensions", + "-subj", + f"/CN=Boilerplate CA for {args.name}", + "-days", args.days, *signature_args, - '-out', - str(tempdir.joinpath('ca-cert.pem')), + "-out", + str(tempdir.joinpath("ca-cert.pem")), ), input=ca_private_key, check=True, ) subprocess.run( - ('openssl', 'genpkey', *genpkey_args, '-out', str(args.key)), + ("openssl", "genpkey", *genpkey_args, "-out", str(args.key)), check=True, ) subprocess.run( ( - 'openssl', - 'req', - '-new', - '-batch', - '-key', + "openssl", + "req", + "-new", + "-batch", + "-key", str(args.key), - '-config', - str(tempdir.joinpath('openssl.cnf')), - '-subj', - f'/CN={args.name}', + "-config", + str(tempdir.joinpath("openssl.cnf")), + "-subj", + f"/CN={args.name}", *signature_args, - '-out', - str(tempdir.joinpath('ee-req.pem')), + "-out", + str(tempdir.joinpath("ee-req.pem")), ), check=True, ) subprocess.run( ( - 'openssl', - 'x509', - '-req', - '-in', - str(tempdir.joinpath('ee-req.pem')), - '-CA', - str(tempdir.joinpath('ca-cert.pem')), - '-CAkey', - '/dev/stdin', - '-CAcreateserial', - '-extfile', - str(tempdir.joinpath('openssl.cnf')), - '-extensions', - 'x509_ee_extensions', - '-days', + "openssl", + "x509", + "-req", + "-in", + str(tempdir.joinpath("ee-req.pem")), + "-CA", + str(tempdir.joinpath("ca-cert.pem")), + "-CAkey", + "/dev/stdin", + "-CAcreateserial", + "-extfile", + str(tempdir.joinpath("openssl.cnf")), + "-extensions", + "x509_ee_extensions", + "-days", args.days, *signature_args, - '-out', + "-out", str(args.cert), ), input=ca_private_key, @@ -229,5 +233,5 @@ def main() -> None: ) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/salt/file/disk_usage/disk_usage_at_least.py b/salt/file/disk_usage/disk_usage_at_least.py index 20d5c76..98bb3e1 100644 --- a/salt/file/disk_usage/disk_usage_at_least.py +++ b/salt/file/disk_usage/disk_usage_at_least.py @@ -21,14 +21,16 @@ def _lvm_pool_usage(*, min_percent): - if not shutil.which('lvs'): - return '' + if not shutil.which("lvs"): + return "" lvs = subprocess.run( ( - 'lvs', - '-S', - ('lv_layout=pool,' - f'(data_percent>={min_percent}||metadata_percent>={min_percent})'), + "lvs", + "-S", + ( + "lv_layout=pool," + f"(data_percent>={min_percent}||metadata_percent>={min_percent})" + ), ), stdout=subprocess.PIPE, text=True, @@ -39,7 +41,7 @@ def _lvm_pool_usage(*, min_percent): def _filesystem_usage(*, min_percent): df = subprocess.run( - ('df', '-h'), + ("df", "-h"), stdout=subprocess.PIPE, text=True, check=True, @@ -47,29 +49,31 @@ def _filesystem_usage(*, min_percent): df_lines = df.stdout.splitlines() df_header = df_lines[0] df_lines_to_print = [ - line for line in df_lines[1:] - if float(line.split()[4].rstrip('%')) >= min_percent + line + for line in df_lines[1:] + if float(line.split()[4].rstrip("%")) >= min_percent ] if df_lines_to_print: - return ''.join(line + '\n' for line in (df_header, *df_lines_to_print)) + return "".join(line + "\n" for line in (df_header, *df_lines_to_print)) else: - return '' + return "" def main(): arg_parser = argparse.ArgumentParser( - description='Conditionally print disk usage.') + description="Conditionally print disk usage." + ) arg_parser.add_argument( - '--lvm-pool-threshold', + "--lvm-pool-threshold", type=float, required=True, - help='Minimum usage percent to print for LVM pools.', + help="Minimum usage percent to print for LVM pools.", ) arg_parser.add_argument( - '--fs-threshold', + "--fs-threshold", type=float, required=True, - help='Minimum usage percent to print for filesystems.', + help="Minimum usage percent to print for filesystems.", ) args = arg_parser.parse_args() @@ -77,8 +81,8 @@ def main(): _lvm_pool_usage(min_percent=args.lvm_pool_threshold), _filesystem_usage(min_percent=args.fs_threshold), ) - sys.stdout.write('\n'.join(section for section in sections if section)) + sys.stdout.write("\n".join(section for section in sections if section)) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/salt/file/mail/storage/monitor_subscriptions.py b/salt/file/mail/storage/monitor_subscriptions.py index b3eb8e4..1d7aa5f 100644 --- a/salt/file/mail/storage/monitor_subscriptions.py +++ b/salt/file/mail/storage/monitor_subscriptions.py @@ -36,27 +36,38 @@ def _notify( unsubscribed = mailboxes - subscribed if unsubscribed: - sections.append(''.join(( - 'Unsubscribed mailboxes:\n', - *(f' {mailbox}\n' for mailbox in unsubscribed), - ))) + sections.append( + "".join( + ( + "Unsubscribed mailboxes:\n", + *(f" {mailbox}\n" for mailbox in unsubscribed), + ) + ) + ) nonexistent_subscribed = subscribed - mailboxes if nonexistent_subscribed: - sections.append(''.join(( - 'Nonexistent subscriptions:\n', - *(f' {subscription}\n' for subscription in nonexistent_subscribed), - ))) + sections.append( + "".join( + ( + "Nonexistent subscriptions:\n", + *( + f" {subscription}\n" + for subscription in nonexistent_subscribed + ), + ) + ) + ) notification = email.message.EmailMessage() - notification['To'] = user - notification['Subject'] = 'subscriptions do not match mailboxes' - notification.set_content('\n'.join(sections)) + notification["To"] = user + notification["Subject"] = "subscriptions do not match mailboxes" + notification.set_content("\n".join(sections)) subprocess.run( ( - '/usr/sbin/sendmail', - '-i', - '-t', + "/usr/sbin/sendmail", + "-i", + "-t", ), check=True, input=bytes(notification), @@ -65,7 +76,7 @@ def _notify( def main() -> None: users = subprocess.run( - ('doveadm', 'user', '*'), + ("doveadm", "user", "*"), check=True, stdout=subprocess.PIPE, text=True, @@ -77,31 +88,32 @@ def main() -> None: # mailboxes. mailbox_statuses = subprocess.run( ( - 'doveadm', - '-f', - 'tab', - 'mailbox', - 'status', - '-u', + "doveadm", + "-f", + "tab", + "mailbox", + "status", + "-u", user, - 'guid', - '*', + "guid", + "*", ), check=True, stdout=subprocess.PIPE, text=True, ).stdout.splitlines()[1:] - mailboxes = {status.split('\t')[0] for status in mailbox_statuses} + mailboxes = {status.split("\t")[0] for status in mailbox_statuses} subscribed = set( subprocess.run( - ('doveadm', 'mailbox', 'list', '-u', user, '-s'), + ("doveadm", "mailbox", "list", "-u", user, "-s"), check=True, stdout=subprocess.PIPE, text=True, - ).stdout.splitlines()) + ).stdout.splitlines() + ) if mailboxes != subscribed: _notify(user=user, mailboxes=mailboxes, subscribed=subscribed) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/salt/file/mail/storage/spam_train.py b/salt/file/mail/storage/spam_train.py index 34c2f57..ead7e64 100644 --- a/salt/file/mail/storage/spam_train.py +++ b/salt/file/mail/storage/spam_train.py @@ -28,17 +28,17 @@ import subprocess import tempfile -_SPAM_FOLDERS = ('.Junk',) +_SPAM_FOLDERS = (".Junk",) _FORGET_FOLDERS = ( - '.Archive', - '.Drafts', - '.Sent', - '.Trash', + ".Archive", + ".Drafts", + ".Sent", + ".Trash", ) def _is_inclusive_subfolder(name: str, tests: Collection[str]) -> bool: - return name in tests or name.startswith(tuple(f'{test}.' for test in tests)) + return name in tests or name.startswith(tuple(f"{test}." for test in tests)) def _sa_learn( @@ -50,16 +50,16 @@ def _sa_learn( # TODO(https://bz.apache.org/SpamAssassin/show_bug.cgi?id=8146): Remove this # special handling of empty folders. try: - next((folder / 'cur').iterdir()) + next((folder / "cur").iterdir()) except StopIteration: return sa_learn_result = subprocess.run( ( - 'sa-learn', - '--quiet', - f'--dbpath={dbpath}', + "sa-learn", + "--quiet", + f"--dbpath={dbpath}", type_arg, - '.', + ".", ), stdout=subprocess.PIPE, stderr=subprocess.STDOUT, @@ -70,36 +70,36 @@ def _sa_learn( if any( line for line in sa_learn_result.stdout.splitlines() - if not line.startswith('Bad UTF7 data escape at ') + if not line.startswith("Bad UTF7 data escape at ") ): - print(f'{folder}:\n{sa_learn_result.stdout}') + print(f"{folder}:\n{sa_learn_result.stdout}") def main() -> None: - user_dir = pathlib.Path(os.environ['USER_DIR']) - maildir = pathlib.Path(os.environ['MAILDIR']) + user_dir = pathlib.Path(os.environ["USER_DIR"]) + maildir = pathlib.Path(os.environ["MAILDIR"]) with tempfile.TemporaryDirectory() as temp_dir: temp_path = pathlib.Path(temp_dir) - temp_path.joinpath('spamassassin').symlink_to(user_dir) - dbpath = temp_path.joinpath('spamassassin').joinpath('bayes') - _sa_learn('--ham', maildir, dbpath=dbpath) + temp_path.joinpath("spamassassin").symlink_to(user_dir) + dbpath = temp_path.joinpath("spamassassin").joinpath("bayes") + _sa_learn("--ham", maildir, dbpath=dbpath) for subdir in maildir.iterdir(): if _is_inclusive_subfolder(subdir.name, _SPAM_FOLDERS): - _sa_learn('--spam', subdir, dbpath=dbpath) + _sa_learn("--spam", subdir, dbpath=dbpath) elif _is_inclusive_subfolder(subdir.name, _FORGET_FOLDERS): - _sa_learn('--forget', subdir, dbpath=dbpath) - elif subdir.name.startswith('.'): - _sa_learn('--ham', subdir, dbpath=dbpath) + _sa_learn("--forget", subdir, dbpath=dbpath) + elif subdir.name.startswith("."): + _sa_learn("--ham", subdir, dbpath=dbpath) subprocess.run( ( - 'sa-learn', - '--quiet', - f'--dbpath={dbpath}', - '--force-expire', + "sa-learn", + "--quiet", + f"--dbpath={dbpath}", + "--force-expire", ), check=True, ) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/salt/file/mail/web/generate_dynamic_config.py b/salt/file/mail/web/generate_dynamic_config.py index a805fd5..7a051d5 100644 --- a/salt/file/mail/web/generate_dynamic_config.py +++ b/salt/file/mail/web/generate_dynamic_config.py @@ -25,30 +25,31 @@ def _args(): parser = argparse.ArgumentParser( - description='Generate dynamic parts of Roundcube config.') + description="Generate dynamic parts of Roundcube config." + ) parser.add_argument( - '--key-bits', + "--key-bits", type=int, required=True, help="Number of random bits in $config['des_key']", ) parser.add_argument( - '--cipher-method', + "--cipher-method", type=str, required=True, help="Value for $config['cipher_method']", ) parser.add_argument( - '--group', + "--group", type=str, required=True, - help='Group that should be able to read the dynamic config file.', + help="Group that should be able to read the dynamic config file.", ) parser.add_argument( - '--output', + "--output", type=pathlib.Path, required=True, - help='Path to the dynamic config file to write.', + help="Path to the dynamic config file to write.", ) return parser.parse_args() @@ -56,22 +57,25 @@ def _args(): def main() -> None: args = _args() key = secrets.token_bytes(nbytes=args.key_bits // 8) - key_php_escaped = ''.join(f'\\x{byte:02x}' for byte in key) + key_php_escaped = "".join(f"\\x{byte:02x}" for byte in key) with tempfile.NamedTemporaryFile( - mode='wt', - dir=args.output.parent, - delete=False, + mode="wt", + dir=args.output.parent, + delete=False, ) as output_tempfile: output_tempfile.write( - textwrap.dedent(f"""\ + textwrap.dedent( + f"""\ now: raise RuntimeError( - f'last_sent {self.last_sent_parsed} is in the future (after ' - f'{now}).') + f"last_sent {self.last_sent_parsed} is in the future (after " + f"{now})." + ) def set_last_sent(self, value: datetime.datetime) -> None: if value.tzinfo is not datetime.timezone.utc: - raise ValueError('last_sent must be UTC') - self.last_sent = value.strftime('%Y%m%dT%H%M%SZ') + raise ValueError("last_sent must be UTC") + self.last_sent = value.strftime("%Y%m%dT%H%M%SZ") self.last_sent_parsed = value def _parse_args(args: Sequence[str]) -> argparse.Namespace: - parser = argparse.ArgumentParser(description='Send scheduled TODO emails.') + parser = argparse.ArgumentParser(description="Send scheduled TODO emails.") parser.add_argument( - '--config', + "--config", type=pathlib.Path, required=True, - help='Path to config file.', + help="Path to config file.", ) parser.add_argument( - '--state', + "--state", type=pathlib.Path, required=True, - help='Path to state file.', + help="Path to state file.", ) parser.add_argument( - '--max-occurrences', + "--max-occurrences", type=int, default=10, - help='Maximum number of occurrences to show at once.', + help="Maximum number of occurrences to show at once.", ) return parser.parse_args(args) def _parse_config(config_filename: pathlib.Path) -> Mapping[str, _TodoConfig]: - with open(config_filename, mode='rb') as config_file: + with open(config_filename, mode="rb") as config_file: raw_config = json.load(config_file) config = {} for group_id, group_config in raw_config.items(): - defaults = group_config.pop('defaults', {}) - todos = group_config.pop('todos') + defaults = group_config.pop("defaults", {}) + todos = group_config.pop("todos") if group_config: raise ValueError( - f'Unexpected group config keys: {list(group_config)!r}') + f"Unexpected group config keys: {list(group_config)!r}" + ) for todo_id, todo_config in todos.items(): - config[f'{group_id}.{todo_id}'] = ( # - _TodoConfig(**(defaults | todo_config))) + config[f"{group_id}.{todo_id}"] = _TodoConfig( # + **(defaults | todo_config) + ) return config @@ -169,14 +185,16 @@ def _parse_state( now: datetime.datetime, ) -> collections.defaultdict[str, _TodoState]: try: - with open(state_filename, mode='rb') as state_file: + with open(state_filename, mode="rb") as state_file: raw_state = json.load(state_file) except FileNotFoundError: raw_state = {} return collections.defaultdict( lambda: _TodoState(now=now), - ((todo_id, _TodoState(**todo_state, now=now)) - for todo_id, todo_state in raw_state.items()), + ( + (todo_id, _TodoState(**todo_state, now=now)) + for todo_id, todo_state in raw_state.items() + ), ) @@ -192,9 +210,9 @@ def _save_state( if field.init } with tempfile.NamedTemporaryFile( - mode='wt', - dir=state_filename.parent, - delete=False, + mode="wt", + dir=state_filename.parent, + delete=False, ) as state_file_new: json.dump(raw_state, state_file_new) os.replace(state_file_new.name, state_filename) @@ -211,24 +229,25 @@ def _send_email( message = email.message.EmailMessage() for header, value in config.email_headers.items(): message[header] = value - message['Subject'] = config.summary + ('' if comment is None else - f' ({comment})') - message['Todo-Id'] = todo_id - message['Todo-Summary'] = config.summary - message['Todo-Timezone'] = config.timezone - message['Todo-Start'] = config.start + message["Subject"] = config.summary + ( + "" if comment is None else f" ({comment})" + ) + message["Todo-Id"] = todo_id + message["Todo-Summary"] = config.summary + message["Todo-Timezone"] = config.timezone + message["Todo-Start"] = config.start if config.recurrence_rule is not None: - message['Todo-Recurrence-Rule'] = config.recurrence_rule + message["Todo-Recurrence-Rule"] = config.recurrence_rule if config.description is not None: - message.add_attachment(config.description, disposition='inline') + message.add_attachment(config.description, disposition="inline") if extra: message.add_attachment( - '\n\n'.join('\n'.join(section) for section in extra), - disposition='inline', - filename='extra-information', + "\n\n".join("\n".join(section) for section in extra), + disposition="inline", + filename="extra-information", ) subprocess_run( - ('/usr/sbin/sendmail', '-i', '-t'), + ("/usr/sbin/sendmail", "-i", "-t"), check=True, input=bytes(message), ) @@ -247,8 +266,10 @@ def _handle_todo( if now < config.start_parsed: return # Not ready to send yet. if config.recurrence_rule_parsed is None: - if (state.last_sent_parsed is not None and - state.last_sent_parsed >= config.start_parsed): + if ( + state.last_sent_parsed is not None + and state.last_sent_parsed >= config.start_parsed + ): return # Already sent. comment = None else: @@ -258,35 +279,49 @@ def _handle_todo( itertools.takewhile( lambda occurrence: occurrence <= now, config.recurrence_rule_parsed.xafter( - (config.start_parsed if state.last_sent_parsed is None else - state.last_sent_parsed), + ( + config.start_parsed + if state.last_sent_parsed is None + else state.last_sent_parsed + ), count=max_occurrences + 1, inc=(state.last_sent_parsed is None), - ))) + ), + ) + ) if not included_occurrences: return elif len(included_occurrences) == 1: comment = None elif len(included_occurrences) > max_occurrences: - comment = f'x{max_occurrences}+' + comment = f"x{max_occurrences}+" else: - comment = f'x{len(included_occurrences)}' - extra.append([ - 'Occurrences included in this email:', - *(str(occurrence) if i < max_occurrences else '...' - for i, occurrence in enumerate(included_occurrences)), - ]) + comment = f"x{len(included_occurrences)}" + extra.append( + [ + "Occurrences included in this email:", + *( + str(occurrence) if i < max_occurrences else "..." + for i, occurrence in enumerate(included_occurrences) + ), + ] + ) next_occurrences = tuple( config.recurrence_rule_parsed.xafter( now, count=max_occurrences + 1, inc=False, - )) - extra.append([ - 'Next occurrences:', - *(str(occurrence) if i < max_occurrences else '...' - for i, occurrence in enumerate(next_occurrences)), - ]) + ) + ) + extra.append( + [ + "Next occurrences:", + *( + str(occurrence) if i < max_occurrences else "..." + for i, occurrence in enumerate(next_occurrences) + ), + ] + ) _send_email( todo_id=todo_id, config=config, @@ -318,5 +353,5 @@ def main( _save_state(args_parsed.state, state) -if __name__ == '__main__': +if __name__ == "__main__": main(sys.argv[1:]) diff --git a/salt/file/todo/todo_test.py b/salt/file/todo/todo_test.py index ab7bc65..047cbc0 100644 --- a/salt/file/todo/todo_test.py +++ b/salt/file/todo/todo_test.py @@ -45,10 +45,12 @@ class TodoTest(parameterized.TestCase): def setUp(self): super().setUp() - self._subprocess_run = mock.create_autospec(subprocess.run, - spec_set=True) + self._subprocess_run = mock.create_autospec( + subprocess.run, spec_set=True + ) self._email_parser = email.parser.BytesParser( - policy=email.policy.default) + policy=email.policy.default + ) def _main( self, @@ -56,35 +58,39 @@ def _main( config: Any, state: Any = None, max_occurrences: int = 10, - ) ->Any: + ) -> Any: tempdir = self.create_tempdir() if config is not None: - tempdir.create_file('config', json.dumps(config)) + tempdir.create_file("config", json.dumps(config)) if state is not None: - tempdir.create_file('state', json.dumps(state)) + tempdir.create_file("state", json.dumps(state)) todo.main( ( - f'--config={tempdir.full_path}/config', - f'--state={tempdir.full_path}/state', - f'--max-occurrences={max_occurrences}', + f"--config={tempdir.full_path}/config", + f"--state={tempdir.full_path}/state", + f"--max-occurrences={max_occurrences}", ), subprocess_run=self._subprocess_run, ) - with open(f'{tempdir.full_path}/state', mode='rb') as state_file: + with open(f"{tempdir.full_path}/state", mode="rb") as state_file: return json.load(state_file) def _assert_messages_sent(self, *expected_messages: _Message): self.assertLen(self._subprocess_run.mock_calls, len(expected_messages)) - for run_call, expected_message in zip(self._subprocess_run.mock_calls, - expected_messages): + for run_call, expected_message in zip( + self._subprocess_run.mock_calls, expected_messages + ): self.assertEqual( - mock.call(('/usr/sbin/sendmail', '-i', '-t'), - check=True, - input=mock.ANY), + mock.call( + ("/usr/sbin/sendmail", "-i", "-t"), + check=True, + input=mock.ANY, + ), run_call, ) actual_message = self._email_parser.parsebytes( - run_call.kwargs['input']) + run_call.kwargs["input"] + ) self.assertEqual( { header: tuple(values) @@ -96,211 +102,281 @@ def _assert_messages_sent(self, *expected_messages: _Message): }, ) if expected_message.parts: - self.assertEqual('multipart/mixed', - actual_message.get_content_type()) + self.assertEqual( + "multipart/mixed", actual_message.get_content_type() + ) actual_parts = tuple(actual_message.iter_parts()) self.assertLen(actual_parts, len(expected_message.parts)) - for actual_part, expected_part in zip(actual_parts, - expected_message.parts): - self.assertEqual('text/plain', - actual_part.get_content_type()) - self.assertEqual('inline', - actual_part.get_content_disposition()) - self.assertEqual(expected_part.filename, - actual_part.get_filename()) - self.assertEqual(expected_part.content, - actual_part.get_content()) + for actual_part, expected_part in zip( + actual_parts, expected_message.parts + ): + self.assertEqual( + "text/plain", actual_part.get_content_type() + ) + self.assertEqual( + "inline", actual_part.get_content_disposition() + ) + self.assertEqual( + expected_part.filename, actual_part.get_filename() + ) + self.assertEqual( + expected_part.content, actual_part.get_content() + ) else: self.assertEmpty(actual_message.get_content()) @parameterized.named_parameters( dict( - testcase_name='config_missing', + testcase_name="config_missing", config=None, error_class=FileNotFoundError, ), dict( - testcase_name='config_unexpected_group_key', + testcase_name="config_unexpected_group_key", config=dict(some_group=dict(todos={}, kumquat={})), error_class=ValueError, - error_regex='Unexpected group config keys:.*kumquat', + error_regex="Unexpected group config keys:.*kumquat", ), dict( - testcase_name='config_missing_required_fields', + testcase_name="config_missing_required_fields", config=dict(some_group=dict(todos=dict(some_todo={}))), error_class=TypeError, - error_regex='summary', + error_regex="summary", ), dict( - testcase_name='config_unexpected_key', - config=dict(some_group=dict(todos=dict(some_todo=dict( - email_headers={}, - summary='foo', - kumquat='', - )))), + testcase_name="config_unexpected_key", + config=dict( + some_group=dict( + todos=dict( + some_todo=dict( + email_headers={}, + summary="foo", + kumquat="", + ) + ) + ) + ), error_class=TypeError, - error_regex='kumquat', + error_regex="kumquat", ), dict( - testcase_name='config_invalid_timezone', - config=dict(some_group=dict(todos=dict(some_todo=dict( - email_headers={}, - summary='foo', - timezone='invalid timezone', - start='20010101T000000Z', - )))), + testcase_name="config_invalid_timezone", + config=dict( + some_group=dict( + todos=dict( + some_todo=dict( + email_headers={}, + summary="foo", + timezone="invalid timezone", + start="20010101T000000Z", + ) + ) + ) + ), error_class=ValueError, - error_regex='Invalid timezone', + error_regex="Invalid timezone", ), dict( - testcase_name='config_invalid_start', - config=dict(some_group=dict(todos=dict(some_todo=dict( - email_headers={}, - summary='foo', - start='invalid datetime', - )))), + testcase_name="config_invalid_start", + config=dict( + some_group=dict( + todos=dict( + some_todo=dict( + email_headers={}, + summary="foo", + start="invalid datetime", + ) + ) + ) + ), error_class=ValueError, - error_regex='Invalid start', + error_regex="Invalid start", ), dict( - testcase_name='config_invalid_recurrence_rule', - config=dict(some_group=dict(todos=dict(some_todo=dict( - email_headers={}, - summary='foo', - start='20010101T000000Z', - recurrence_rule='invalid recurrence rule', - )))), + testcase_name="config_invalid_recurrence_rule", + config=dict( + some_group=dict( + todos=dict( + some_todo=dict( + email_headers={}, + summary="foo", + start="20010101T000000Z", + recurrence_rule="invalid recurrence rule", + ) + ) + ) + ), error_class=ValueError, - error_regex='Invalid recurrence_rule', + error_regex="Invalid recurrence_rule", ), dict( - testcase_name='config_invalid_recurrence_rule_is_rruleset', - config=dict(some_group=dict(todos=dict(some_todo=dict( - email_headers={}, - summary='foo', - start='20010101T000000Z', - recurrence_rule='RRULE:FREQ=DAILY\nRRULE:FREQ=DAILY', - )))), + testcase_name="config_invalid_recurrence_rule_is_rruleset", + config=dict( + some_group=dict( + todos=dict( + some_todo=dict( + email_headers={}, + summary="foo", + start="20010101T000000Z", + recurrence_rule="RRULE:FREQ=DAILY\nRRULE:FREQ=DAILY", + ) + ) + ) + ), error_class=ValueError, - error_regex='not an rruleset', + error_regex="not an rruleset", ), dict( - testcase_name='state_unexpected_key', + testcase_name="state_unexpected_key", config={}, - state=dict(some_todo=dict(kumquat='')), + state=dict(some_todo=dict(kumquat="")), error_class=TypeError, - error_regex='kumquat', + error_regex="kumquat", ), dict( - testcase_name='state_invalid_last_sent', + testcase_name="state_invalid_last_sent", config={}, - state={'some_group.some_todo': dict(last_sent='invalid datetime')}, + state={"some_group.some_todo": dict(last_sent="invalid datetime")}, error_class=ValueError, - error_regex='Invalid last_sent', + error_regex="Invalid last_sent", ), dict( - testcase_name='state_last_sent_in_future', + testcase_name="state_last_sent_in_future", config={}, - state={'some_group.some_todo': dict(last_sent='20010101T000000Z')}, + state={"some_group.some_todo": dict(last_sent="20010101T000000Z")}, error_class=RuntimeError, - error_regex='in the future', + error_regex="in the future", ), ) - @freezegun.freeze_time('2000-01-01') + @freezegun.freeze_time("2000-01-01") def test_error( self, *, config: Any, state: Any = None, error_class: Type[Exception], - error_regex: str = '', + error_regex: str = "", ): with self.assertRaisesRegex(error_class, error_regex): self._main(config=config, state=state) @parameterized.named_parameters( dict( - testcase_name='uses_default', - group_extra=dict(defaults=dict(email_headers=dict( - To='alice@example.com'))), + testcase_name="uses_default", + group_extra=dict( + defaults=dict(email_headers=dict(To="alice@example.com")) + ), todo_extra={}, ), dict( - testcase_name='overrides_default', - group_extra=dict(defaults=dict(email_headers=dict( - To='bob@example.com'))), - todo_extra=dict(email_headers=dict(To='alice@example.com')), + testcase_name="overrides_default", + group_extra=dict( + defaults=dict(email_headers=dict(To="bob@example.com")) + ), + todo_extra=dict(email_headers=dict(To="alice@example.com")), ), ) - @freezegun.freeze_time('2000-01-01') + @freezegun.freeze_time("2000-01-01") def test_config_defaults(self, group_extra: Any, todo_extra: Any): - self._main(config=dict(some_group=dict( - **group_extra, - todos=dict(some_todo=dict( - **todo_extra, - summary='apple', - start='20000101T000000Z', - )), - ))) + self._main( + config=dict( + some_group=dict( + **group_extra, + todos=dict( + some_todo=dict( + **todo_extra, + summary="apple", + start="20000101T000000Z", + ) + ), + ) + ) + ) self._assert_messages_sent( - _Message(headers={'To': ('alice@example.com',)}, parts=())) + _Message(headers={"To": ("alice@example.com",)}, parts=()) + ) @parameterized.named_parameters( dict( - testcase_name='empty_config_no_state', + testcase_name="empty_config_no_state", initial_state=None, config={}, ), dict( - testcase_name='irrelevant_state', - initial_state={'unknown-todo': dict(last_sent='19990203T010203Z')}, + testcase_name="irrelevant_state", + initial_state={"unknown-todo": dict(last_sent="19990203T010203Z")}, config={}, ), dict( - testcase_name='start_in_future', - initial_state={'some_group.some_todo': dict(last_sent=None)}, - config=dict(some_group=dict(todos=dict(some_todo=dict( - email_headers={}, - summary='foo', - start='20010101T000000Z', - )))), + testcase_name="start_in_future", + initial_state={"some_group.some_todo": dict(last_sent=None)}, + config=dict( + some_group=dict( + todos=dict( + some_todo=dict( + email_headers={}, + summary="foo", + start="20010101T000000Z", + ) + ) + ) + ), ), dict( - testcase_name='start_in_future_but_previously_sent', + testcase_name="start_in_future_but_previously_sent", initial_state={ - 'some_group.some_todo': dict(last_sent='19990101T000000Z'), + "some_group.some_todo": dict(last_sent="19990101T000000Z"), }, - config=dict(some_group=dict(todos=dict(some_todo=dict( - email_headers={}, - summary='foo', - start='20010101T000000Z', - )))), + config=dict( + some_group=dict( + todos=dict( + some_todo=dict( + email_headers={}, + summary="foo", + start="20010101T000000Z", + ) + ) + ) + ), ), dict( - testcase_name='one_time_todo_already_sent', + testcase_name="one_time_todo_already_sent", initial_state={ - 'some_group.some_todo': dict(last_sent='19990101T000000Z'), + "some_group.some_todo": dict(last_sent="19990101T000000Z"), }, - config=dict(some_group=dict(todos=dict(some_todo=dict( - email_headers={}, - summary='foo', - start='19990101T000000Z', - )))), + config=dict( + some_group=dict( + todos=dict( + some_todo=dict( + email_headers={}, + summary="foo", + start="19990101T000000Z", + ) + ) + ) + ), ), dict( - testcase_name='between_occurrences', + testcase_name="between_occurrences", initial_state={ - 'some_group.some_todo': dict(last_sent='19991231T120000Z'), + "some_group.some_todo": dict(last_sent="19991231T120000Z"), }, - config=dict(some_group=dict(todos=dict(some_todo=dict( - email_headers={}, - summary='apple', - start='19990101T120000Z', - recurrence_rule='FREQ=DAILY', - )))), - )) - @freezegun.freeze_time('2000-01-01') + config=dict( + some_group=dict( + todos=dict( + some_todo=dict( + email_headers={}, + summary="apple", + start="19990101T120000Z", + recurrence_rule="FREQ=DAILY", + ) + ) + ) + ), + ), + ) + @freezegun.freeze_time("2000-01-01") def test_nothing_to_send( self, initial_state: Any, @@ -309,25 +385,27 @@ def test_nothing_to_send( new_state = self._main(config=config, state=initial_state) self._subprocess_run.assert_not_called() - self.assertEqual({} if initial_state is None else initial_state, - new_state) + self.assertEqual( + {} if initial_state is None else initial_state, new_state + ) @parameterized.product( ( dict(description=None, expected_parts=()), dict( - description='orange', - expected_parts=(_MessagePart(filename=None, - content='orange\n'),), + description="orange", + expected_parts=( + _MessagePart(filename=None, content="orange\n"), + ), ), ), ( - dict(timezone='UTC', start='20000101T000000Z'), - dict(timezone='America/New_York', start='19991231T000000'), + dict(timezone="UTC", start="20000101T000000Z"), + dict(timezone="America/New_York", start="19991231T000000"), ), - last_sent=(None, '19990101T000000Z'), + last_sent=(None, "19990101T000000Z"), ) - @freezegun.freeze_time('2000-01-01') + @freezegun.freeze_time("2000-01-01") def test_sends_one_time_todo( self, last_sent: Optional[str], @@ -337,41 +415,49 @@ def test_sends_one_time_todo( expected_parts: Sequence[_MessagePart], ): new_state = self._main( - config=dict(some_group=dict(todos=dict(some_todo=dict( - email_headers=dict(To='alice@example.com'), - summary='apple', - description=description, - timezone=timezone, - start=start, - )))), - state={'some_group.some_todo': dict(last_sent=last_sent)}, + config=dict( + some_group=dict( + todos=dict( + some_todo=dict( + email_headers=dict(To="alice@example.com"), + summary="apple", + description=description, + timezone=timezone, + start=start, + ) + ) + ) + ), + state={"some_group.some_todo": dict(last_sent=last_sent)}, ) self._assert_messages_sent( _Message( headers={ - 'To': ('alice@example.com',), - 'Subject': ('apple',), - 'Todo-Id': ('some_group.some_todo',), - 'Todo-Summary': ('apple',), - 'Todo-Timezone': (timezone,), - 'Todo-Start': (start,), - 'Todo-Recurrence-Rule': (), + "To": ("alice@example.com",), + "Subject": ("apple",), + "Todo-Id": ("some_group.some_todo",), + "Todo-Summary": ("apple",), + "Todo-Timezone": (timezone,), + "Todo-Start": (start,), + "Todo-Recurrence-Rule": (), }, parts=expected_parts, - )) + ) + ) self.assertEqual( - {'some_group.some_todo': dict(last_sent='20000101T000000Z')}, + {"some_group.some_todo": dict(last_sent="20000101T000000Z")}, new_state, ) @parameterized.named_parameters( dict( - testcase_name='one_at_start', - start='20000101T000000', + testcase_name="one_at_start", + start="20000101T000000", last_sent=None, - expected_subject='apple', - expected_extra_info=textwrap.dedent("""\ + expected_subject="apple", + expected_extra_info=textwrap.dedent( + """\ Occurrences included in this email: 2000-01-01 00:00:00-05:00 @@ -380,14 +466,16 @@ def test_sends_one_time_todo( 2000-01-03 00:00:00-05:00 2000-01-04 00:00:00-05:00 ... - """), + """ + ), ), dict( - testcase_name='one_after_last_sent', - start='19990101T000000', - last_sent='19991231T120000Z', - expected_subject='apple', - expected_extra_info=textwrap.dedent("""\ + testcase_name="one_after_last_sent", + start="19990101T000000", + last_sent="19991231T120000Z", + expected_subject="apple", + expected_extra_info=textwrap.dedent( + """\ Occurrences included in this email: 2000-01-01 00:00:00-05:00 @@ -396,14 +484,16 @@ def test_sends_one_time_todo( 2000-01-03 00:00:00-05:00 2000-01-04 00:00:00-05:00 ... - """), + """ + ), ), dict( - testcase_name='more_than_max', - start='19990101T000000', + testcase_name="more_than_max", + start="19990101T000000", last_sent=None, - expected_subject='apple (x3+)', - expected_extra_info=textwrap.dedent("""\ + expected_subject="apple (x3+)", + expected_extra_info=textwrap.dedent( + """\ Occurrences included in this email: 1999-01-01 00:00:00-05:00 1999-01-02 00:00:00-05:00 @@ -415,14 +505,16 @@ def test_sends_one_time_todo( 2000-01-03 00:00:00-05:00 2000-01-04 00:00:00-05:00 ... - """), + """ + ), ), dict( - testcase_name='max', - start='19991230T000000', + testcase_name="max", + start="19991230T000000", last_sent=None, - expected_subject='apple (x3)', - expected_extra_info=textwrap.dedent("""\ + expected_subject="apple (x3)", + expected_extra_info=textwrap.dedent( + """\ Occurrences included in this email: 1999-12-30 00:00:00-05:00 1999-12-31 00:00:00-05:00 @@ -433,10 +525,11 @@ def test_sends_one_time_todo( 2000-01-03 00:00:00-05:00 2000-01-04 00:00:00-05:00 ... - """), + """ + ), ), ) - @freezegun.freeze_time('2000-01-01 12:00:00') + @freezegun.freeze_time("2000-01-01 12:00:00") def test_sends_recurring_todo( self, start: str, @@ -445,38 +538,47 @@ def test_sends_recurring_todo( expected_extra_info: str, ): new_state = self._main( - config=dict(some_group=dict(todos=dict(some_todo=dict( - email_headers=dict(To='alice@example.com'), - summary='apple', - timezone='America/New_York', - start=start, - recurrence_rule='FREQ=DAILY', - )))), - state={'some_group.some_todo': dict(last_sent=last_sent)}, + config=dict( + some_group=dict( + todos=dict( + some_todo=dict( + email_headers=dict(To="alice@example.com"), + summary="apple", + timezone="America/New_York", + start=start, + recurrence_rule="FREQ=DAILY", + ) + ) + ) + ), + state={"some_group.some_todo": dict(last_sent=last_sent)}, max_occurrences=3, ) self._assert_messages_sent( _Message( headers={ - 'To': ('alice@example.com',), - 'Subject': (expected_subject,), - 'Todo-Id': ('some_group.some_todo',), - 'Todo-Summary': ('apple',), - 'Todo-Timezone': ('America/New_York',), - 'Todo-Start': (start,), - 'Todo-Recurrence-Rule': ('FREQ=DAILY',), + "To": ("alice@example.com",), + "Subject": (expected_subject,), + "Todo-Id": ("some_group.some_todo",), + "Todo-Summary": ("apple",), + "Todo-Timezone": ("America/New_York",), + "Todo-Start": (start,), + "Todo-Recurrence-Rule": ("FREQ=DAILY",), }, - parts=(_MessagePart( - filename='extra-information', - content=expected_extra_info, - ),), - )) + parts=( + _MessagePart( + filename="extra-information", + content=expected_extra_info, + ), + ), + ) + ) self.assertEqual( - {'some_group.some_todo': dict(last_sent='20000101T120000Z')}, + {"some_group.some_todo": dict(last_sent="20000101T120000Z")}, new_state, ) -if __name__ == '__main__': +if __name__ == "__main__": absltest.main() diff --git a/salt/file/uptime/uptime_warning.py b/salt/file/uptime/uptime_warning.py index 2210268..f91769c 100644 --- a/salt/file/uptime/uptime_warning.py +++ b/salt/file/uptime/uptime_warning.py @@ -21,12 +21,12 @@ def main(): # https://en.wikipedia.org/wiki/Uptime#Using_/proc/uptime - with open('/proc/uptime') as uptime_file: + with open("/proc/uptime") as uptime_file: uptime_seconds, _ = uptime_file.read().split() uptime = datetime.timedelta(seconds=float(uptime_seconds)) if uptime > _WARN_AFTER: - print(f'Uptime: {uptime}') + print(f"Uptime: {uptime}") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/salt/file/xmpp/dump.py b/salt/file/xmpp/dump.py index d7396de..a50dcfb 100644 --- a/salt/file/xmpp/dump.py +++ b/salt/file/xmpp/dump.py @@ -29,18 +29,19 @@ def _wait_and_check(popen: subprocess.Popen[Any]) -> None: if popen.wait() != 0: - raise subprocess.CalledProcessError(returncode=popen.returncode, - cmd=popen.args) + raise subprocess.CalledProcessError( + returncode=popen.returncode, cmd=popen.args + ) def _service_property(name: str) -> str: return subprocess.run( ( - 'systemctl', - 'show', - f'--property={name}', - '--value', - 'ejabberd.service', + "systemctl", + "show", + f"--property={name}", + "--value", + "ejabberd.service", ), stdout=subprocess.PIPE, check=True, @@ -49,18 +50,18 @@ def _service_property(name: str) -> str: def _nsenter() -> Sequence[str]: - pid = _service_property('MainPID') - uid = _service_property('UID') - gid = _service_property('GID') - if pid == '0' or uid == '[not set]' or gid == '[not set]': - raise RuntimeError('ejabberd is not running') + pid = _service_property("MainPID") + uid = _service_property("UID") + gid = _service_property("GID") + if pid == "0" or uid == "[not set]" or gid == "[not set]": + raise RuntimeError("ejabberd is not running") return ( - 'nsenter', - f'--target={pid}', - '--mount', - f'--setuid={uid}', - f'--setgid={gid}', - '--', + "nsenter", + f"--target={pid}", + "--mount", + f"--setuid={uid}", + f"--setgid={gid}", + "--", ) @@ -72,22 +73,22 @@ def _ejabberd_tempfile( ) -> Generator[str, None, None]: with contextlib.ExitStack() as exit_stack: tempfile_ = subprocess.run( - (*nsenter, 'mktemp'), + (*nsenter, "mktemp"), stdout=subprocess.PIPE, check=True, text=True, ).stdout.rstrip() exit_stack.callback( subprocess.run, - (*nsenter, 'rm', tempfile_), + (*nsenter, "rm", tempfile_), check=True, ) yield tempfile_ cat = subprocess.Popen( - (*nsenter, 'cat', tempfile_), + (*nsenter, "cat", tempfile_), stdout=subprocess.PIPE, ) - with open(copy_to, mode='xb') as copy_to_file: + with open(copy_to, mode="xb") as copy_to_file: shutil.copyfileobj(cast(IO[bytes], cat.stdout), copy_to_file) _wait_and_check(cat) @@ -100,25 +101,25 @@ def _ejabberd_tempdir( ) -> Generator[str, None, None]: with contextlib.ExitStack() as exit_stack: tempdir = subprocess.run( - (*nsenter, 'mktemp', '-d'), + (*nsenter, "mktemp", "-d"), stdout=subprocess.PIPE, check=True, text=True, ).stdout.rstrip() exit_stack.callback( subprocess.run, - (*nsenter, 'rm', '-rf', tempdir), + (*nsenter, "rm", "-rf", tempdir), check=True, ) yield tempdir tar_create = subprocess.Popen( ( *nsenter, - 'tar', - '--create', - '--file=-', - f'--directory={tempdir}', - '.', + "tar", + "--create", + "--file=-", + f"--directory={tempdir}", + ".", ), stdout=subprocess.PIPE, ) @@ -126,33 +127,37 @@ def _ejabberd_tempdir( os.mkdir(copy_to) tar_extract = subprocess.Popen( ( - 'tar', - '--extract', - '--file=-', - f'--directory={copy_to}', + "tar", + "--extract", + "--file=-", + f"--directory={copy_to}", ), stdin=subprocess.PIPE, ) # TODO(https://github.com/python/mypy/issues/15031): Remove type ignore. shutil.copyfileobj( # type: ignore cast(IO[bytes], tar_create.stdout), - cast(IO[bytes], tar_extract.stdin)) + cast(IO[bytes], tar_extract.stdin), + ) _wait_and_check(tar_create) _wait_and_check(tar_extract) def main() -> None: nsenter = _nsenter() - with _ejabberd_tempfile(nsenter=nsenter, - copy_to='ejabberd.dump') as dump_filename: - subprocess.run(('ejabberdctl', 'dump', dump_filename), check=True) + with _ejabberd_tempfile( + nsenter=nsenter, copy_to="ejabberd.dump" + ) as dump_filename: + subprocess.run(("ejabberdctl", "dump", dump_filename), check=True) # NOTE: This doesn't actually generate useful data, see # https://github.com/processone/ejabberd/issues/3705 - with _ejabberd_tempdir(nsenter=nsenter, - copy_to='ejabberd.piefxis') as dump_dirname: - subprocess.run(('ejabberdctl', 'export_piefxis', dump_dirname), - check=True) + with _ejabberd_tempdir( + nsenter=nsenter, copy_to="ejabberd.piefxis" + ) as dump_dirname: + subprocess.run( + ("ejabberdctl", "export_piefxis", dump_dirname), check=True + ) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/salt/file/xmpp/ejabberd_authentication.py b/salt/file/xmpp/ejabberd_authentication.py index 4a236d2..4aa0711 100644 --- a/salt/file/xmpp/ejabberd_authentication.py +++ b/salt/file/xmpp/ejabberd_authentication.py @@ -63,18 +63,19 @@ def _args(): parser = argparse.ArgumentParser( - description='External authentication helper for ejabberd.') + description="External authentication helper for ejabberd." + ) parser.add_argument( - '--config', + "--config", type=pathlib.Path, required=True, - help='Absolute path to the authentication config file.', + help="Absolute path to the authentication config file.", ) parser.add_argument( - '--max-passwords-per-user', + "--max-passwords-per-user", default=25, type=int, - help='Max number of passwords per user.', + help="Max number of passwords per user.", ) return parser.parse_args() @@ -85,12 +86,12 @@ def _config( max_passwords_per_user: int, ) -> _Config: raw_config = collections.defaultdict(list) - with config_path.open(mode='rb') as config_file: + with config_path.open(mode="rb") as config_file: for line in config_file: - if not line.strip() or line.lstrip().startswith(b'#'): + if not line.strip() or line.lstrip().startswith(b"#"): continue - user, server, crypted_password = line.rstrip(b'\n').split(b':') - raw_config[(user, server)].append(crypted_password.decode('utf-8')) + user, server, crypted_password = line.rstrip(b"\n").split(b":") + raw_config[(user, server)].append(crypted_password.decode("utf-8")) # collections.defaultdict makes it easy to accidentally add new keys, which # is useful above, but a potential security risk after parsing is done. # E.g., if some code did `config[(user, server)]` with untrusted input, that @@ -100,12 +101,14 @@ def _config( config = {} for key, crypted_passwords in raw_config.items(): if len(crypted_passwords) > max_passwords_per_user: - raise ValueError(f'{key!r} has too many passwords') + raise ValueError(f"{key!r} has too many passwords") # Ensure exactly max_passwords_per_user entries by repeating the # entries. This makes it harder to figure out how many passwords a user # has by measuring how long it takes to test a password. - config[key] = tuple(crypted_passwords[i % len(crypted_passwords)] - for i in range(max_passwords_per_user)) + config[key] = tuple( + crypted_passwords[i % len(crypted_passwords)] + for i in range(max_passwords_per_user) + ) return config @@ -120,24 +123,24 @@ def _read_operations() -> Iterable[tuple[bytes, bytes]]: if not length_bytes: return elif len(length_bytes) != 2: - raise ValueError(f'Expected 2 bytes, got {len(length_bytes)}') - (length,) = struct.unpack('!H', length_bytes) + raise ValueError(f"Expected 2 bytes, got {len(length_bytes)}") + (length,) = struct.unpack("!H", length_bytes) value_bytes = sys.stdin.buffer.read(length) if len(value_bytes) != length: - raise ValueError(f'Expected {length} bytes, got {len(value_bytes)}') - operation, _, args = value_bytes.partition(b':') + raise ValueError(f"Expected {length} bytes, got {len(value_bytes)}") + operation, _, args = value_bytes.partition(b":") yield operation, args def _respond(response: _Response) -> None: - sys.stdout.buffer.write(struct.pack('!HH', 2, response)) + sys.stdout.buffer.write(struct.pack("!HH", 2, response)) sys.stdout.buffer.flush() def _auth(operation_args: bytes, *, config: _Config) -> _Response: - user, server, password = operation_args.split(b':', maxsplit=2) + user, server, password = operation_args.split(b":", maxsplit=2) try: - password_str = password.decode('utf-8') + password_str = password.decode("utf-8") except UnicodeDecodeError: return _Response.FAILURE crypted_passwords = config.get((user, server)) @@ -164,7 +167,7 @@ def _auth(operation_args: bytes, *, config: _Config) -> _Response: def _isuser(operation_args: bytes, *, config: _Config) -> _Response: - user, server = operation_args.split(b':') + user, server = operation_args.split(b":") return _Response.SUCCESS if (user, server) in config else _Response.FAILURE @@ -175,13 +178,13 @@ def main() -> None: max_passwords_per_user=args.max_passwords_per_user, ) for operation, operation_args in _read_operations(): - if operation == b'auth': + if operation == b"auth": _respond(_auth(operation_args, config=config)) - elif operation == b'isuser': + elif operation == b"isuser": _respond(_isuser(operation_args, config=config)) else: _respond(_Response.FAILURE) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/salt/file/xmpp/ejabberd_authentication_test.py b/salt/file/xmpp/ejabberd_authentication_test.py index 5cf9a4d..5d61a71 100644 --- a/salt/file/xmpp/ejabberd_authentication_test.py +++ b/salt/file/xmpp/ejabberd_authentication_test.py @@ -34,11 +34,11 @@ class EjabberdAuthenticationTest(unittest.TestCase): def _main( self, *, - config: str = '', + config: str = "", max_passwords_per_user: int = 25, ) -> Generator[tuple[IO[bytes], IO[bytes]], None, None]: """Runs the main program, yielding its (stdin, stdout).""" - with tempfile.NamedTemporaryFile(mode='w+t') as config_file: + with tempfile.NamedTemporaryFile(mode="w+t") as config_file: config_file.write(config) config_file.flush() main = subprocess.Popen( @@ -46,9 +46,11 @@ def _main( sys.executable, str( pathlib.Path(__file__).parent.joinpath( - 'ejabberd_authentication.py')), - f'--config={config_file.name}', - f'--max-passwords-per-user={max_passwords_per_user}', + "ejabberd_authentication.py" + ) + ), + f"--config={config_file.name}", + f"--max-passwords-per-user={max_passwords_per_user}", ), stdin=subprocess.PIPE, stdout=subprocess.PIPE, @@ -65,122 +67,133 @@ def _main( stderr_io.close() if main.returncode != 0 or stdout or stderr: raise RuntimeError( - f'Main returned {main.returncode} with unread stdout ' - f'{stdout!r} and stderr:\n{stderr.decode()}') + f"Main returned {main.returncode} with unread stdout " + f"{stdout!r} and stderr:\n{stderr.decode()}" + ) def _assert_failure(self, stdout: IO[bytes]) -> None: - self.assertEqual(b'\x00\x02\x00\x00', stdout.read(4)) + self.assertEqual(b"\x00\x02\x00\x00", stdout.read(4)) def _assert_success(self, stdout: IO[bytes]) -> None: - self.assertEqual(b'\x00\x02\x00\x01', stdout.read(4)) + self.assertEqual(b"\x00\x02\x00\x01", stdout.read(4)) def test_ignores_whitespace_and_comment_lines(self): - with self._main(config=('\n' - ' \t\n' - '# this is a comment\n' - ' \t# also a comment\n' - 'alice:example.com:!\n')) as (stdin, stdout): - stdin.write(b'\x00\x18isuser:alice:example.com') + with self._main( + config=( + "\n" + " \t\n" + "# this is a comment\n" + " \t# also a comment\n" + "alice:example.com:!\n" + ) + ) as (stdin, stdout): + stdin.write(b"\x00\x18isuser:alice:example.com") stdin.close() self._assert_success(stdout) def test_too_many_passwords_error(self): - with self.assertRaisesRegex(RuntimeError, 'has too many passwords'): + with self.assertRaisesRegex(RuntimeError, "has too many passwords"): with self._main( - config='alice:example.com:!\n' * 3, - max_passwords_per_user=2, + config="alice:example.com:!\n" * 3, + max_passwords_per_user=2, ) as (stdin, _): stdin.close() def test_incomplete_length_error(self): - with self.assertRaisesRegex(RuntimeError, 'Expected 2 bytes, got 1'): + with self.assertRaisesRegex(RuntimeError, "Expected 2 bytes, got 1"): with self._main() as (stdin, _): - stdin.write(b'\x00') + stdin.write(b"\x00") stdin.close() def test_incomplete_value_error(self): - with self.assertRaisesRegex(RuntimeError, 'Expected 3 bytes, got 2'): + with self.assertRaisesRegex(RuntimeError, "Expected 3 bytes, got 2"): with self._main() as (stdin, _): - stdin.write(b'\x00\x03fo') + stdin.write(b"\x00\x03fo") stdin.close() def test_auth_wrong_arg_count_error(self): - with self.assertRaisesRegex(RuntimeError, - r'values to unpack \(expected 3, got 1\)'): + with self.assertRaisesRegex( + RuntimeError, r"values to unpack \(expected 3, got 1\)" + ): with self._main() as (stdin, _): - stdin.write(b'\x00\x08auth:foo') + stdin.write(b"\x00\x08auth:foo") stdin.close() def test_auth_not_unicode_failure(self): - with self._main(config='alice:example.com:!\n') as (stdin, stdout): - stdin.write(b'\x00\x18auth:alice:example.com:\xff') + with self._main(config="alice:example.com:!\n") as (stdin, stdout): + stdin.write(b"\x00\x18auth:alice:example.com:\xff") stdin.close() self._assert_failure(stdout) def test_auth_not_a_user_failure(self): with self._main() as (stdin, stdout): - stdin.write(b'\x00\x1aauth:alice:example.com:foo') + stdin.write(b"\x00\x1aauth:alice:example.com:foo") stdin.close() self._assert_failure(stdout) def test_auth_wrong_password_failure(self): with self._main( - config=(f'alice:example.com:{_crypt("foo")}\n' - f'bob:example.com:{_crypt("bar")}\n'), # + config=( + f'alice:example.com:{_crypt("foo")}\n' + f'bob:example.com:{_crypt("bar")}\n' + ), # ) as (stdin, stdout): - stdin.write(b'\x00\x18auth:bob:example.com:foo') + stdin.write(b"\x00\x18auth:bob:example.com:foo") stdin.close() self._assert_failure(stdout) def test_auth_success(self): with self._main( - config=(f'alice:example.com:{_crypt("bar")}\n' - f'alice:example.com:{_crypt("foo")}\n'), # + config=( + f'alice:example.com:{_crypt("bar")}\n' + f'alice:example.com:{_crypt("foo")}\n' + ), # ) as (stdin, stdout): - stdin.write(b'\x00\x1aauth:alice:example.com:foo') + stdin.write(b"\x00\x1aauth:alice:example.com:foo") stdin.close() self._assert_success(stdout) def test_auth_success_colon_in_password(self): # https://github.com/processone/ejabberd/issues/3677 with self._main( - config=f'alice:example.com:{_crypt("foo:bar")}\n', # + config=f'alice:example.com:{_crypt("foo:bar")}\n', # ) as (stdin, stdout): - stdin.write(b'\x00\x1eauth:alice:example.com:foo:bar') + stdin.write(b"\x00\x1eauth:alice:example.com:foo:bar") stdin.close() self._assert_success(stdout) def test_isuser_wrong_arg_count_error(self): - with self.assertRaisesRegex(RuntimeError, - r'values to unpack \(expected 2, got 1\)'): + with self.assertRaisesRegex( + RuntimeError, r"values to unpack \(expected 2, got 1\)" + ): with self._main() as (stdin, _): - stdin.write(b'\x00\x0aisuser:foo') + stdin.write(b"\x00\x0aisuser:foo") stdin.close() def test_isuser_success(self): - with self._main(config='alice:example.com:!\n') as (stdin, stdout): - stdin.write(b'\x00\x18isuser:alice:example.com') + with self._main(config="alice:example.com:!\n") as (stdin, stdout): + stdin.write(b"\x00\x18isuser:alice:example.com") stdin.close() self._assert_success(stdout) def test_isuser_failure(self): with self._main() as (stdin, stdout): - stdin.write(b'\x00\x18isuser:alice:example.com') + stdin.write(b"\x00\x18isuser:alice:example.com") stdin.close() self._assert_failure(stdout) def test_empty_value_failure(self): with self._main() as (stdin, stdout): - stdin.write(b'\x00\x00') + stdin.write(b"\x00\x00") stdin.close() self._assert_failure(stdout) def test_unknown_operation_failure(self): with self._main() as (stdin, stdout): - stdin.write(b'\x00\x03foo') + stdin.write(b"\x00\x03foo") stdin.close() self._assert_failure(stdout) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main()