From 8f735b6fce710486d9ae5f97643013b588dab847 Mon Sep 17 00:00:00 2001 From: Robert Kaussow Date: Sat, 29 Jul 2023 22:27:34 +0200 Subject: [PATCH] feat: add hashivault_unseal module --- plugins/module_utils/__init__.py | 0 plugins/module_utils/hashivault.py | 407 +++++++++++++++++++++++++++ plugins/modules/hashivault_unseal.py | 72 +++++ pyproject.toml | 1 + 4 files changed, 480 insertions(+) create mode 100644 plugins/module_utils/__init__.py create mode 100644 plugins/module_utils/hashivault.py create mode 100644 plugins/modules/hashivault_unseal.py diff --git a/plugins/module_utils/__init__.py b/plugins/module_utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/plugins/module_utils/hashivault.py b/plugins/module_utils/hashivault.py new file mode 100644 index 0000000..80b1ebc --- /dev/null +++ b/plugins/module_utils/hashivault.py @@ -0,0 +1,407 @@ +"""Provide helper functions for Hashivault module.""" + +from __future__ import (absolute_import, division, print_function) + +__metaclass__ = type + +import os +import traceback + +from ansible.module_utils.basic import missing_required_lib +from ansible.module_utils.basic import AnsibleModule +from ansible.module_utils.basic import env_fallback + +HVAC_IMP_ERR = None +try: + import hvac + from hvac.exceptions import InvalidPath + HAS_HVAC = True +except ImportError: + HAS_HVAC = False + HVAC_IMP_ERR = traceback.format_exc() + + +def hashivault_argspec(): + return dict( + url=dict(required=False, default=os.environ.get("VAULT_ADDR", ""), type="str"), + ca_cert=dict(required=False, default=os.environ.get("VAULT_CACERT", ""), type="str"), + ca_path=dict(required=False, default=os.environ.get("VAULT_CAPATH", ""), type="str"), + client_cert=dict( + required=False, default=os.environ.get("VAULT_CLIENT_CERT", ""), type="str" + ), + client_key=dict( + required=False, + default=os.environ.get("VAULT_CLIENT_KEY", ""), + type="str", + no_log=True + ), + verify=dict( + required=False, + default=(not os.environ.get("VAULT_SKIP_VERIFY", "False")), + type="bool" + ), + authtype=dict( + required=False, + default=os.environ.get("VAULT_AUTHTYPE", "token"), + type="str", + choices=["token", "userpass", "github", "ldap", "approle"] + ), + login_mount_point=dict( + required=False, default=os.environ.get("VAULT_LOGIN_MOUNT_POINT", None), type="str" + ), + token=dict( + required=False, + fallback=(hashivault_default_token, ["VAULT_TOKEN"]), + type="str", + no_log=True + ), + username=dict(required=False, default=os.environ.get("VAULT_USER", ""), type="str"), + password=dict( + required=False, fallback=(env_fallback, ["VAULT_PASSWORD"]), type="str", no_log=True + ), + role_id=dict( + required=False, fallback=(env_fallback, ["VAULT_ROLE_ID"]), type="str", no_log=True + ), + secret_id=dict( + required=False, fallback=(env_fallback, ["VAULT_SECRET_ID"]), type="str", no_log=True + ), + aws_header=dict( + required=False, fallback=(env_fallback, ["VAULT_AWS_HEADER"]), type="str", no_log=True + ), + namespace=dict( + required=False, default=os.environ.get("VAULT_NAMESPACE", None), type="str" + ) + ) + + +def hashivault_init( + argument_spec, + supports_check_mode=False, + required_if=None, + required_together=None, + required_one_of=None, + mutually_exclusive=None +): + module = AnsibleModule( + argument_spec=argument_spec, + supports_check_mode=supports_check_mode, + required_if=required_if, + required_together=required_together, + required_one_of=required_one_of, + mutually_exclusive=mutually_exclusive + ) + + if not HAS_HVAC: + module.fail_json(msg=missing_required_lib("hvac"), exception=HVAC_IMP_ERR) + + module.no_log_values.discard("0") + module.no_log_values.discard(0) + module.no_log_values.discard("1") + module.no_log_values.discard(1) + module.no_log_values.discard(True) + module.no_log_values.discard(False) + module.no_log_values.discard("ttl") + + return module + + +def hashivault_client(params): + url = params.get("url") + ca_cert = params.get("ca_cert") + ca_path = params.get("ca_path") + client_cert = params.get("client_cert") + client_key = params.get("client_key") + cert = (client_cert, client_key) + check_verify = params.get("verify") + namespace = params.get("namespace", None) + + if check_verify == "" or check_verify: + if ca_cert: + verify = ca_cert + elif ca_path: + verify = ca_path + else: + verify = check_verify + else: + verify = check_verify + + return hvac.Client(url=url, cert=cert, verify=verify, namespace=namespace) + + +def hashivault_auth(client, params): + token = params.get("token") + authtype = params.get("authtype") + login_mount_point = params.get("login_mount_point", authtype) + if not login_mount_point: + login_mount_point = authtype + username = params.get("username") + password = params.get("password") + secret_id = params.get("secret_id") + role_id = params.get("role_id") + + if authtype == "github": + client.auth.github.login(token, mount_point=login_mount_point) + elif authtype == "userpass": + client.auth_userpass(username, password, mount_point=login_mount_point) + elif authtype == "ldap": + client.auth.ldap.login(username, password, mount_point=login_mount_point) + elif authtype == "approle": + client = AppRoleClient(client, role_id, secret_id, mount_point=login_mount_point) + elif authtype == "tls": + client.auth_tls() + else: + client.token = token + return client + + +def hashivault_auth_client(params): + client = hashivault_client(params) + return hashivault_auth(client, params) + + +def hashiwrapper(function): + + def wrapper(*args, **kwargs): + result = {"changed": False, "rc": 0} + result.update(function(*args, **kwargs)) + return result + + return wrapper + + +def hashivault_default_token(env): + """Get a default Vault token from an environment variable or a file.""" + envvar = env[0] + if envvar in os.environ: + return os.environ[envvar] + token_file = os.path.expanduser("~/.vault-token") + if os.path.exists(token_file): + with open(token_file) as f: + return f.read().strip() + return "" + + +@hashiwrapper +def hashivault_read(params): + result = {"changed": False, "rc": 0} + client = hashivault_auth_client(params) + version = params.get("version") + mount_point = params.get("mount_point") + secret = params.get("secret") + secret_version = params.get("secret_version") + + key = params.get("key") + default = params.get("default") + + if secret.startswith("/"): + secret = secret.lstrip("/") + mount_point = "" + + secret_path = f"{mount_point}/{secret}" if mount_point else secret + + try: + if version == 2: + response = client.secrets.kv.v2.read_secret_version( + secret, mount_point=mount_point, version=secret_version + ) + else: + response = client.secrets.kv.v1.read_secret(secret, mount_point=mount_point) + except InvalidPath: + response = None + except Exception as e: # noqa: BLE001 + result["rc"] = 1 + result["failed"] = True + error_string = f"{e.__class__.__name__}({e})" + result["msg"] = f"Error {error_string} reading {secret_path}" + return result + if not response: + if default is not None: + result["value"] = default + return result + result["rc"] = 1 + result["failed"] = True + result["msg"] = f"Secret {secret_path} is not in vault" + return result + if version == 2: + try: + data = response.get("data", {}) + data = data.get("data", {}) + except Exception: # noqa: BLE001 + data = str(response) + else: + data = response["data"] + lease_duration = response.get("lease_duration", None) + if lease_duration is not None: + result["lease_duration"] = lease_duration + lease_id = response.get("lease_id", None) + if lease_id is not None: + result["lease_id"] = lease_id + renewable = response.get("renewable", None) + if renewable is not None: + result["renewable"] = renewable + wrap_info = response.get("wrap_info", None) + if wrap_info is not None: + result["wrap_info"] = wrap_info + if key and key not in data: + if default is not None: + result["value"] = default + return result + result["rc"] = 1 + result["failed"] = True + result["msg"] = f"Key {key} is not in secret {secret_path}" + return result + value = data[key] if key else data + result["value"] = value + return result + + +class AppRoleClient: + """ + hvac.Client decorator generate and set a new approle token. + + This allows multiple calls to Vault without having to manually + generate and set a token on every Vault call. + """ + + def __init__(self, client, role_id, secret_id, mount_point): + object.__setattr__(self, "client", client) + object.__setattr__(self, "role_id", role_id) + object.__setattr__(self, "secret_id", secret_id) + object.__setattr__(self, "login_mount_point", mount_point) + + def __setattr__(self, name, val): + client = object.__getattribute__(self, "client") + client.__setattr__(name, val) + + def __getattribute__(self, name): + client = object.__getattribute__(self, "client") + attr = client.__getattribute__(name) + + role_id = object.__getattribute__(self, "role_id") + secret_id = object.__getattribute__(self, "secret_id") + login_mount_point = object.__getattribute__(self, "login_mount_point") + resp = client.auth_approle(role_id, secret_id=secret_id, mount_point=login_mount_point) + client.token = str(resp["auth"]["client_token"]) + return attr + + +def _compare_state(desired_state, current_state, ignore=None): + """ + Compare desired state to current state. + + Returns true if objects are equal. + + Recursively walks dict object to compare all keys. + + :param desired_state: The state user desires. + :param current_state: The state that currently exists. + :param ignore: Ignore these keys. + :type ignore: list + + :return: True if the states are the same. + :rtype: bool + """ + + if ignore is None: + ignore = [] + if (type(desired_state) is list): + if ((type(current_state) != list) or (len(desired_state) != len(current_state))): + return False + return set(desired_state) == set(current_state) + + if (type(desired_state) is dict): + if (type(current_state) != dict): + return False + + # iterate over dictionary keys + for key in desired_state: + if key in ignore: + continue + v = desired_state[key] + if ((key not in current_state) or (not _compare_state(v, current_state.get(key)))): + return False + return True + + # Lots of things get handled as strings in ansible that aren"t necessarily strings, + # can extend this list later. + if isinstance(desired_state, str) and isinstance(current_state, int): + current_state = str(current_state) + + return (desired_state == current_state) + + +def _convert_to_seconds(original_value): + try: + value = str(original_value) + seconds = 0 + if "h" in value: + ray = value.split("h") + seconds = int(ray.pop(0)) * 3600 + value = "".join(ray) + if "m" in value: + ray = value.split("m") + seconds += int(ray.pop(0)) * 60 + value = "".join(ray) + if value: + ray = value.split("s") + seconds += int(ray.pop(0)) + return seconds + except Exception: # noqa: BLE001 + pass + return original_value + + +def get_keys_updated(desired_state, current_state, ignore=None): + """ + + Return list of keys that have different values. + + Recursively walks dict object to compare all keys. + + :param desired_state: The state user desires. + :type desired_state: dict + :param current_state: The state that currently exists. + :type current_state: dict + :param ignore: Ignore these keys. + :type ignore: list + + :return: Different items + :rtype: list + """ + + if ignore is None: + ignore = [] + + differences = [] + for key in desired_state: + if key in ignore: + continue + if (key not in current_state): + differences.append(key) + continue + new_value = desired_state[key] + old_value = current_state[key] + if "ttl" in key and (_convert_to_seconds(old_value) != _convert_to_seconds(new_value)): + differences.append(key) + elif not _compare_state(new_value, old_value): + differences.append(key) + return differences + + +def is_state_changed(desired_state, current_state, ignore=None): # noqa: ARG001 + """ + Return list of keys that have different values. + + Recursively walks dict object to compare all keys. + + :param desired_state: The state user desires. + :type desired_state: dict + :param current_state: The state that currently exists. + :type current_state: dict + :param ignore: Ignore these keys. + :type ignore: list + + :return: Different + :rtype: bool + """ + return (len(get_keys_updated(desired_state, current_state)) > 0) diff --git a/plugins/modules/hashivault_unseal.py b/plugins/modules/hashivault_unseal.py new file mode 100644 index 0000000..938d1d2 --- /dev/null +++ b/plugins/modules/hashivault_unseal.py @@ -0,0 +1,72 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) +"""Unseal Hashicorp Vault servers.""" + +from __future__ import (absolute_import, division, print_function) + +__metaclass__ = type + +ANSIBLE_METADATA = {"status": ["stableinterface"], "supported_by": "community", "version": "1.1"} + +DOCUMENTATION = """ +--- +module: hashivault_unseal +short_description: Hashicorp Vault unseal module. +version_added: 1.2.0 +description: + - "Module to unseal Hashicorp Vault." +options: + keys: + description: + - Vault key shard(s). + type: list + elements: str + required: true +author: + - ownCloud GmbH (@owncloud) +extends_documentation_fragment: + - owncloud.general.hashivault +""" + +EXAMPLES = """ +--- +- name: Unseal vault + hashivault_unseal: + keys: + - 26479cc0-54bc-4252-9c34-baca54aa5de7 + - 47f942e3-8525-4b44-ba2f-84a4ae81db7d + - 2ee9c868-4275-4836-8747-4f8fb7611aa0 + url: https://vault.example.com +""" + +from ansible_collections.owncloud.general.plugins.module_utils.hashivault import hashivault_argspec +from ansible_collections.owncloud.general.plugins.module_utils.hashivault import hashivault_client +from ansible_collections.owncloud.general.plugins.module_utils.hashivault import hashivault_init +from ansible_collections.owncloud.general.plugins.module_utils.hashivault import hashiwrapper + + +def main(): + argspec = hashivault_argspec() + argspec["keys"] = dict(required=True, type="list", elements="str", no_log=True) + module = hashivault_init(argspec) + result = hashivault_unseal(module.params) + if result.get("failed"): + module.fail_json(**result) + else: + module.exit_json(**result) + + +@hashiwrapper +def hashivault_unseal(params): + keys = params.get("keys") + client = hashivault_client(params) + if client.sys.is_sealed(): + return {"status": client.sys.submit_unseal_keys(keys), "changed": True} + + return {"changed": False} + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index b0f9067..485ba9d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,6 +80,7 @@ exclude = [ ".cache", ".eggs", "env*", + ".venv", "iptables_raw.py", ] # Explanation of errors