xoxys.general/plugins/module_utils/hashivault.py
Robert Kaussow 9226ab6209
All checks were successful
continuous-integration/drone/push Build is passing
feat: add hashivault_unseal module (#5)
Reviewed-on: #5
Co-authored-by: Robert Kaussow <mail@thegeeklab.de>
Co-committed-by: Robert Kaussow <mail@thegeeklab.de>
2023-07-30 12:43:36 +02:00

408 lines
13 KiB
Python

"""Provide helper functions for Hashivault module."""
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
import os
import traceback
from ansible.module_utils.basic import missing_required_lib
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.basic import env_fallback
HVAC_IMP_ERR = None
try:
import hvac
from hvac.exceptions import InvalidPath
HAS_HVAC = True
except ImportError:
HAS_HVAC = False
HVAC_IMP_ERR = traceback.format_exc()
def hashivault_argspec():
return dict(
url=dict(required=False, default=os.environ.get("VAULT_ADDR", ""), type="str"),
ca_cert=dict(required=False, default=os.environ.get("VAULT_CACERT", ""), type="str"),
ca_path=dict(required=False, default=os.environ.get("VAULT_CAPATH", ""), type="str"),
client_cert=dict(
required=False, default=os.environ.get("VAULT_CLIENT_CERT", ""), type="str"
),
client_key=dict(
required=False,
default=os.environ.get("VAULT_CLIENT_KEY", ""),
type="str",
no_log=True
),
verify=dict(
required=False,
default=(not os.environ.get("VAULT_SKIP_VERIFY", "False")),
type="bool"
),
authtype=dict(
required=False,
default=os.environ.get("VAULT_AUTHTYPE", "token"),
type="str",
choices=["token", "userpass", "github", "ldap", "approle"]
),
login_mount_point=dict(
required=False, default=os.environ.get("VAULT_LOGIN_MOUNT_POINT", None), type="str"
),
token=dict(
required=False,
fallback=(hashivault_default_token, ["VAULT_TOKEN"]),
type="str",
no_log=True
),
username=dict(required=False, default=os.environ.get("VAULT_USER", ""), type="str"),
password=dict(
required=False, fallback=(env_fallback, ["VAULT_PASSWORD"]), type="str", no_log=True
),
role_id=dict(
required=False, fallback=(env_fallback, ["VAULT_ROLE_ID"]), type="str", no_log=True
),
secret_id=dict(
required=False, fallback=(env_fallback, ["VAULT_SECRET_ID"]), type="str", no_log=True
),
aws_header=dict(
required=False, fallback=(env_fallback, ["VAULT_AWS_HEADER"]), type="str", no_log=True
),
namespace=dict(
required=False, default=os.environ.get("VAULT_NAMESPACE", None), type="str"
)
)
def hashivault_init(
    argument_spec,
    supports_check_mode=False,
    required_if=None,
    required_together=None,
    required_one_of=None,
    mutually_exclusive=None
):
    """
    Create the AnsibleModule used by a Hashivault module.

    Every argument is forwarded unchanged to AnsibleModule. If the hvac
    library could not be imported, the module's fail_json is called with the
    captured import traceback.

    :return: the configured AnsibleModule instance
    """
    module_kwargs = {
        "argument_spec": argument_spec,
        "supports_check_mode": supports_check_mode,
        "required_if": required_if,
        "required_together": required_together,
        "required_one_of": required_one_of,
        "mutually_exclusive": mutually_exclusive,
    }
    module = AnsibleModule(**module_kwargs)
    if not HAS_HVAC:
        module.fail_json(msg=missing_required_lib("hvac"), exception=HVAC_IMP_ERR)
    # Remove short, very common values from no_log_values. Presumably this
    # stops Ansible from masking them wherever they appear in module output.
    for common_value in ("0", 0, "1", 1, True, False, "ttl"):
        module.no_log_values.discard(common_value)
    return module
def hashivault_client(params):
url = params.get("url")
ca_cert = params.get("ca_cert")
ca_path = params.get("ca_path")
client_cert = params.get("client_cert")
client_key = params.get("client_key")
cert = (client_cert, client_key)
check_verify = params.get("verify")
namespace = params.get("namespace", None)
if check_verify == "" or check_verify:
if ca_cert:
verify = ca_cert
elif ca_path:
verify = ca_path
else:
verify = check_verify
else:
verify = check_verify
return hvac.Client(url=url, cert=cert, verify=verify, namespace=namespace)
def hashivault_auth(client, params):
    """
    Authenticate ``client`` according to ``params["authtype"]``.

    ``login_mount_point`` defaults to the auth type name when missing or empty.

    - ``github``: logs in with ``token``.
    - ``userpass`` / ``ldap``: log in with ``username`` and ``password``.
    - ``approle``: returns an AppRoleClient wrapper that logs in with
      ``role_id``/``secret_id`` on attribute access.
    - ``tls``: calls ``client.auth_tls()``.
    - anything else: sets ``client.token`` to ``token``.

    :param client: an hvac.Client (e.g. from hashivault_client)
    :param params: module params
    :return: the authenticated client, or the AppRoleClient wrapper for approle
    """
    token = params.get("token")
    authtype = params.get("authtype")
    login_mount_point = params.get("login_mount_point") or authtype
    username = params.get("username")
    password = params.get("password")
    secret_id = params.get("secret_id")
    role_id = params.get("role_id")

    if authtype == "github":
        client.auth.github.login(token, mount_point=login_mount_point)
    elif authtype == "userpass":
        # Use the auth-method API, like github/ldap, instead of the legacy
        # Client.auth_userpass shortcut.
        client.auth.userpass.login(username, password, mount_point=login_mount_point)
    elif authtype == "ldap":
        client.auth.ldap.login(username, password, mount_point=login_mount_point)
    elif authtype == "approle":
        client = AppRoleClient(client, role_id, secret_id, mount_point=login_mount_point)
    elif authtype == "tls":
        # NOTE(review): "tls" is not among the authtype choices in
        # hashivault_argspec, so this branch is only reachable by callers
        # that build their own params. auth_tls is a legacy Client shortcut
        # that may not exist in newer hvac releases -- verify before relying on it.
        client.auth_tls()
    else:
        client.token = token
    return client
def hashivault_auth_client(params):
    """
    Build a client from ``params`` and authenticate it.

    :param params: module params, used for both construction and login
    :return: whatever hashivault_auth returns (the client, or its AppRole wrapper)
    """
    unauthenticated = hashivault_client(params)
    authenticated = hashivault_auth(unauthenticated, params)
    return authenticated
def hashiwrapper(function):
    """
    Decorate a Hashivault action so its result always has ``changed`` and ``rc``.

    The wrapped function must return a mapping. Its items are merged over the
    defaults ``{"changed": False, "rc": 0}``, so values it sets take precedence.
    ``functools.wraps`` keeps the wrapped function's name and docstring.
    """
    from functools import wraps

    @wraps(function)
    def wrapper(*args, **kwargs):
        result = {"changed": False, "rc": 0}
        result.update(function(*args, **kwargs))
        return result

    return wrapper
def hashivault_default_token(env):
    """
    Get a default Vault token from an environment variable or a file.

    :param env: sequence whose first item names the environment variable to read
    :return: the variable's value if it is set (even when empty); otherwise the
        whitespace-stripped contents of ``~/.vault-token`` if that file exists;
        otherwise an empty string.
    """
    variable_name = env[0]
    from_env = os.environ.get(variable_name)
    if from_env is not None:
        return from_env
    token_path = os.path.expanduser("~/.vault-token")
    if not os.path.exists(token_path):
        return ""
    with open(token_path) as handle:
        return handle.read().strip()
@hashiwrapper
def hashivault_read(params):
    """
    Read a secret, or a single key of it, from a KV v1 or v2 secrets engine.

    Params used: ``version`` (2 selects KV v2, anything else KV v1),
    ``mount_point``, ``secret``, ``secret_version`` (KV v2 only), ``key`` and
    ``default``. If ``secret`` starts with "/", the leading slashes are
    stripped and ``mount_point`` is cleared.

    On success, ``value`` holds either the whole secret data or, when ``key``
    is set, that key's value. ``lease_duration``, ``lease_id``, ``renewable``
    and ``wrap_info`` are copied from the response when present.

    A missing secret (InvalidPath or an empty response) or a missing key
    yields ``default`` as ``value`` when ``default`` is not None. Otherwise,
    and for any other exception raised by the read, the result gets
    ``rc=1``, ``failed=True`` and a ``msg``.
    """
    result = {"changed": False, "rc": 0}

    def _failure(message):
        # Mark the result as failed and hand it back for an early return.
        result["rc"] = 1
        result["failed"] = True
        result["msg"] = message
        return result

    client = hashivault_auth_client(params)
    kv_version = params.get("version")
    mount = params.get("mount_point")
    secret = params.get("secret")
    requested_version = params.get("secret_version")
    key = params.get("key")
    default = params.get("default")

    if secret.startswith("/"):
        # NOTE(review): an absolute secret path is passed to hvac with
        # mount_point="" -- confirm hvac builds the intended request path.
        secret = secret.lstrip("/")
        mount = ""
    secret_path = secret if not mount else f"{mount}/{secret}"

    try:
        if kv_version == 2:
            response = client.secrets.kv.v2.read_secret_version(
                secret, mount_point=mount, version=requested_version
            )
        else:
            response = client.secrets.kv.v1.read_secret(secret, mount_point=mount)
    except InvalidPath:
        response = None
    except Exception as e:  # noqa: BLE001
        return _failure(f"Error {e.__class__.__name__}({e}) reading {secret_path}")

    if not response:
        if default is None:
            return _failure(f"Secret {secret_path} is not in vault")
        result["value"] = default
        return result

    if kv_version == 2:
        try:
            data = response.get("data", {}).get("data", {})
        except Exception:  # noqa: BLE001
            data = str(response)
    else:
        data = response["data"]

    # Copy lease/wrapping metadata only when the response provides it.
    for field in ("lease_duration", "lease_id", "renewable", "wrap_info"):
        field_value = response.get(field, None)
        if field_value is not None:
            result[field] = field_value

    if key and key not in data:
        if default is None:
            return _failure(f"Key {key} is not in secret {secret_path}")
        result["value"] = default
        return result

    result["value"] = data[key] if key else data
    return result
class AppRoleClient:
    """
    hvac.Client decorator that generates and sets a new approle token.

    Every attribute read through the wrapper first fetches the attribute from
    the wrapped client. It then performs an AppRole login and stores the
    resulting token on the wrapped client. This allows multiple calls to Vault
    without manually generating and setting a token before each one.
    Attribute writes are forwarded to the wrapped client.
    """

    def __init__(self, client, role_id, secret_id, mount_point):
        # Bypass our own __setattr__, which forwards writes to the wrapped client.
        object.__setattr__(self, "client", client)
        object.__setattr__(self, "role_id", role_id)
        object.__setattr__(self, "secret_id", secret_id)
        object.__setattr__(self, "login_mount_point", mount_point)

    def __setattr__(self, name, val):
        client = object.__getattribute__(self, "client")
        client.__setattr__(name, val)

    def __getattribute__(self, name):
        # Use object.__getattribute__ for our own fields to avoid recursing
        # into this method.
        client = object.__getattribute__(self, "client")
        attr = client.__getattribute__(name)
        role_id = object.__getattribute__(self, "role_id")
        secret_id = object.__getattribute__(self, "secret_id")
        login_mount_point = object.__getattribute__(self, "login_mount_point")
        # Log in through the auth-method API, consistent with the other auth
        # types, rather than the legacy Client.auth_approle shortcut.
        resp = client.auth.approle.login(
            role_id, secret_id=secret_id, mount_point=login_mount_point
        )
        client.token = str(resp["auth"]["client_token"])
        return attr
def _compare_state(desired_state, current_state, ignore=None):
"""
Compare desired state to current state.
Returns true if objects are equal.
Recursively walks dict object to compare all keys.
:param desired_state: The state user desires.
:param current_state: The state that currently exists.
:param ignore: Ignore these keys.
:type ignore: list
:return: True if the states are the same.
:rtype: bool
"""
if ignore is None:
ignore = []
if (type(desired_state) is list):
if ((type(current_state) != list) or (len(desired_state) != len(current_state))):
return False
return set(desired_state) == set(current_state)
if (type(desired_state) is dict):
if (type(current_state) != dict):
return False
# iterate over dictionary keys
for key in desired_state:
if key in ignore:
continue
v = desired_state[key]
if ((key not in current_state) or (not _compare_state(v, current_state.get(key)))):
return False
return True
# Lots of things get handled as strings in ansible that aren"t necessarily strings,
# can extend this list later.
if isinstance(desired_state, str) and isinstance(current_state, int):
current_state = str(current_state)
return (desired_state == current_state)
def _convert_to_seconds(original_value):
try:
value = str(original_value)
seconds = 0
if "h" in value:
ray = value.split("h")
seconds = int(ray.pop(0)) * 3600
value = "".join(ray)
if "m" in value:
ray = value.split("m")
seconds += int(ray.pop(0)) * 60
value = "".join(ray)
if value:
ray = value.split("s")
seconds += int(ray.pop(0))
return seconds
except Exception: # noqa: BLE001
pass
return original_value
def get_keys_updated(desired_state, current_state, ignore=None):
    """
    Return list of keys that have different values.

    Recursively walks dict object to compare all keys. Keys whose name
    contains ``ttl`` are also considered unchanged when both values convert to
    the same number of seconds (e.g. ``"1h"`` and ``3600``).

    :param desired_state: The state user desires.
    :type desired_state: dict
    :param current_state: The state that currently exists.
    :type current_state: dict
    :param ignore: Ignore these keys.
    :type ignore: list
    :return: Different items
    :rtype: list
    """
    if ignore is None:
        ignore = []
    differences = []
    for key in desired_state:
        if key in ignore:
            continue
        if key not in current_state:
            differences.append(key)
            continue
        new_value = desired_state[key]
        old_value = current_state[key]
        if "ttl" in key:
            # Durations may be written differently ("1h" vs 3600). Equal
            # seconds must be enough; falling through to the generic
            # comparison would flag them as changed.
            unchanged = (
                _convert_to_seconds(old_value) == _convert_to_seconds(new_value)
                or _compare_state(new_value, old_value)
            )
        else:
            unchanged = _compare_state(new_value, old_value)
        if not unchanged:
            differences.append(key)
    return differences
def is_state_changed(desired_state, current_state, ignore=None):
    """
    Return whether any key of ``desired_state`` differs from ``current_state``.

    Uses get_keys_updated, which recursively walks dict objects to compare all
    keys.

    :param desired_state: The state user desires.
    :type desired_state: dict
    :param current_state: The state that currently exists.
    :type current_state: dict
    :param ignore: Ignore these keys (forwarded to get_keys_updated).
    :type ignore: list
    :return: True if at least one non-ignored key differs.
    :rtype: bool
    """
    return len(get_keys_updated(desired_state, current_state, ignore=ignore)) > 0