mirror of
https://github.com/thegeeklab/ansible-later.git
synced 2024-11-23 05:10:40 +00:00
309 lines
10 KiB
Python
309 lines
10 KiB
Python
# Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com>
|
|
#
|
|
# This file is part of paramiko.
|
|
#
|
|
# Paramiko is free software; you can redistribute it and/or modify it under the
|
|
# terms of the GNU Lesser General Public License as published by the Free
|
|
# Software Foundation; either version 2.1 of the License, or (at your option)
|
|
# any later version.
|
|
#
|
|
# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
|
|
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
|
|
# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
|
|
# details.
|
|
#
|
|
# You should have received a copy of the GNU Lesser General Public License
|
|
# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
|
|
# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
|
|
|
|
"""
|
|
ECDSA keys
|
|
"""
|
|
|
|
from cryptography.exceptions import InvalidSignature, UnsupportedAlgorithm
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.asymmetric.utils import (
    decode_dss_signature,
    encode_dss_signature,
)

from paramiko.common import four_byte
from paramiko.message import Message
from paramiko.pkey import PKey
from paramiko.ssh_exception import SSHException
from paramiko.util import deflate_long
|
|
|
|
|
|
class _ECDSACurve(object):
    """
    Describes a single ECDSA curve (nistp256, nistp384, etc).

    Keeps the curve class together with everything derived from it: the NIST
    name, the key size in bits, the SSH key format identifier and the hash
    class that goes with keys of that size.
    """

    def __init__(self, curve_class, nist_name):
        self.nist_name = nist_name
        self.key_length = curve_class.key_size

        # Defined in RFC 5656 6.2
        self.key_format_identifier = "ecdsa-sha2-" + self.nist_name

        # Defined in RFC 5656 6.2.1: the hash gets wider as the key does.
        bits = self.key_length
        if bits > 384:
            digest = hashes.SHA512
        elif bits > 256:
            digest = hashes.SHA384
        else:
            digest = hashes.SHA256
        # Stored as the hash *class*; users instantiate it when needed.
        self.hash_object = digest

        self.curve_class = curve_class
|
|
|
|
|
|
class _ECDSACurveSet(object):
    """
    Holds the known ECDSA curves and answers lookups against them.

    A curve can be looked up by its curve class, by its key format identifier
    or by its key length. Every lookup scans the curves in order and returns
    the first match, or ``None`` when nothing matches.
    """

    def __init__(self, ecdsa_curves):
        self.ecdsa_curves = ecdsa_curves

    def get_key_format_identifier_list(self):
        """Return the key format identifier of every curve, in order."""
        return [entry.key_format_identifier for entry in self.ecdsa_curves]

    def get_by_curve_class(self, curve_class):
        """Return the first curve whose ``curve_class`` equals the argument."""
        matches = (
            entry
            for entry in self.ecdsa_curves
            if entry.curve_class == curve_class
        )
        return next(matches, None)

    def get_by_key_format_identifier(self, key_format_identifier):
        """Return the first curve with the given key format identifier."""
        matches = (
            entry
            for entry in self.ecdsa_curves
            if entry.key_format_identifier == key_format_identifier
        )
        return next(matches, None)

    def get_by_key_length(self, key_length):
        """Return the first curve whose ``key_length`` equals the argument."""
        matches = (
            entry
            for entry in self.ecdsa_curves
            if entry.key_length == key_length
        )
        return next(matches, None)
|
|
|
|
|
|
class ECDSAKey(PKey):
    """
    Representation of an ECDSA key which can be used to sign and verify SSH2
    data.
    """

    # The curves this class knows how to handle; lookups scan them in order.
    _ECDSA_CURVES = _ECDSACurveSet(
        [
            _ECDSACurve(ec.SECP256R1, "nistp256"),
            _ECDSACurve(ec.SECP384R1, "nistp384"),
            _ECDSACurve(ec.SECP521R1, "nistp521"),
        ]
    )

    def __init__(
        self,
        msg=None,
        data=None,
        filename=None,
        password=None,
        vals=None,
        file_obj=None,
        validate_point=True,
    ):
        """
        Create a key from one of several sources.

        The sources are checked in this order and the first one given wins:
        ``file_obj``, ``filename``, ``vals``, then ``msg`` (``data`` is
        wrapped in a `.Message` when ``msg`` itself is not given).

        :param msg: a `.Message` positioned at the start of a public key blob
        :param data: raw public key bytes, used only when ``msg`` is ``None``
        :param filename: path of a private key file to load
        :param password: password for the private key, if any
        :param vals: a ``(signing_key, verifying_key)`` pair to adopt as-is
        :param file_obj: file-like object holding a private key
        :param validate_point: not referenced by this implementation

        :raises: `.SSHException` -- if no key material was supplied, if the
            curve name in the blob does not match the key type, or if the
            encoded public point cannot be decoded
        """
        self.verifying_key = None
        self.signing_key = None
        self.public_blob = None
        if file_obj is not None:
            self._from_private_key(file_obj, password)
            return
        if filename is not None:
            self._from_private_key_file(filename, password)
            return
        if (msg is None) and (data is not None):
            msg = Message(data)
        if vals is not None:
            self.signing_key, self.verifying_key = vals
            c_class = self.signing_key.curve.__class__
            # NOTE: this is None for a curve outside _ECDSA_CURVES.
            self.ecdsa_curve = self._ECDSA_CURVES.get_by_curve_class(c_class)
        else:
            # Nothing to parse: fail with the same exception type the rest of
            # this constructor uses instead of an AttributeError on None.
            if msg is None:
                raise SSHException("Key object may not be empty")
            # Must set ecdsa_curve first; subroutines called herein may need to
            # spit out our get_name(), which relies on this.
            key_type = msg.get_text()
            # But this also means we need to hand it a real key/curve
            # identifier, so strip out any cert business. (NOTE: could push
            # that into _ECDSACurveSet.get_by_key_format_identifier(), but it
            # feels more correct to do it here?)
            suffix = "-cert-v01@openssh.com"
            if key_type.endswith(suffix):
                key_type = key_type[: -len(suffix)]
            self.ecdsa_curve = self._ECDSA_CURVES.get_by_key_format_identifier(
                key_type
            )
            key_types = self._ECDSA_CURVES.get_key_format_identifier_list()
            cert_types = [
                "{}-cert-v01@openssh.com".format(x) for x in key_types
            ]
            self._check_type_and_load_cert(
                msg=msg, key_type=key_types, cert_type=cert_types
            )
            # The blob repeats the curve name; it has to agree with the type.
            curvename = msg.get_text()
            if curvename != self.ecdsa_curve.nist_name:
                raise SSHException(
                    "Can't handle curve of type {}".format(curvename)
                )

            pointinfo = msg.get_binary()
            try:
                numbers = ec.EllipticCurvePublicNumbers.from_encoded_point(
                    self.ecdsa_curve.curve_class(), pointinfo
                )
            except ValueError:
                raise SSHException("Invalid public key")
            self.verifying_key = numbers.public_key(backend=default_backend())

    @classmethod
    def supported_key_format_identifiers(cls):
        """Return the key format identifiers of all supported curves."""
        return cls._ECDSA_CURVES.get_key_format_identifier_list()

    def asbytes(self):
        """
        Serialize the public key.

        The result is a `.Message` rendered to bytes, holding the key format
        identifier, the NIST curve name and the encoded public point.
        """
        key = self.verifying_key
        m = Message()
        m.add_string(self.ecdsa_curve.key_format_identifier)
        m.add_string(self.ecdsa_curve.nist_name)

        numbers = key.public_numbers()

        # Bytes needed for one coordinate, rounding the bit size up.
        key_size_bytes = (key.curve.key_size + 7) // 8

        # Left-pad each coordinate with zero bytes to the full width.
        x_bytes = deflate_long(numbers.x, add_sign_padding=False)
        x_bytes = b"\x00" * (key_size_bytes - len(x_bytes)) + x_bytes

        y_bytes = deflate_long(numbers.y, add_sign_padding=False)
        y_bytes = b"\x00" * (key_size_bytes - len(y_bytes)) + y_bytes

        # four_byte is presumably the uncompressed-point marker -- confirm
        # against paramiko.common.
        point_str = four_byte + x_bytes + y_bytes
        m.add_string(point_str)
        return m.asbytes()

    def __str__(self):
        # NOTE(review): asbytes() returns m.asbytes(); if that is a bytes
        # object, str(key) raises TypeError on Python 3, which requires
        # __str__ to return str. Left unchanged because the right replacement
        # depends on what PKey provides -- confirm and fix there.
        return self.asbytes()

    def __hash__(self):
        # Fetch the public numbers once and hash both coordinates from them.
        numbers = self.verifying_key.public_numbers()
        return hash((self.get_name(), numbers.x, numbers.y))

    def get_name(self):
        """Return the key format identifier of this key's curve."""
        return self.ecdsa_curve.key_format_identifier

    def get_bits(self):
        """Return the key length, in bits, of this key's curve."""
        return self.ecdsa_curve.key_length

    def can_sign(self):
        """Return ``True`` if a private (signing) key is present."""
        return self.signing_key is not None

    def sign_ssh_data(self, data):
        """
        Sign ``data`` with the private key.

        :param data: the data to sign
        :returns:
            a `.Message` holding the key format identifier followed by the
            signature, encoded as the two integers ``r`` and ``s``
        """
        ecdsa = ec.ECDSA(self.ecdsa_curve.hash_object())
        sig = self.signing_key.sign(data, ecdsa)
        # Split the signature into (r, s) so it can be re-encoded for SSH.
        r, s = decode_dss_signature(sig)

        m = Message()
        m.add_string(self.ecdsa_curve.key_format_identifier)
        m.add_string(self._sigencode(r, s))
        return m

    def verify_ssh_sig(self, data, msg):
        """
        Check a signature over ``data``.

        :param data: the data that was signed
        :param msg: a `.Message` holding the signature to check
        :returns:
            ``True`` if the signature verifies; ``False`` if it does not, or
            if its type does not match this key's format identifier
        """
        if msg.get_text() != self.ecdsa_curve.key_format_identifier:
            return False
        sig = msg.get_binary()
        sigR, sigS = self._sigdecode(sig)
        signature = encode_dss_signature(sigR, sigS)

        try:
            self.verifying_key.verify(
                signature, data, ec.ECDSA(self.ecdsa_curve.hash_object())
            )
        except InvalidSignature:
            return False
        else:
            return True

    def write_private_key_file(self, filename, password=None):
        """
        Write the private key to ``filename`` in the TraditionalOpenSSL
        format, passing ``password`` through to the writer.
        """
        self._write_private_key_file(
            filename,
            self.signing_key,
            serialization.PrivateFormat.TraditionalOpenSSL,
            password=password,
        )

    def write_private_key(self, file_obj, password=None):
        """
        Write the private key to ``file_obj`` in the TraditionalOpenSSL
        format, passing ``password`` through to the writer.
        """
        self._write_private_key(
            file_obj,
            self.signing_key,
            serialization.PrivateFormat.TraditionalOpenSSL,
            password=password,
        )

    @classmethod
    def generate(cls, curve=ec.SECP256R1(), progress_func=None, bits=None):
        """
        Generate a new private ECDSA key.  This factory function can be used to
        generate a new host key or authentication key.

        :param curve:
            curve instance to generate the key on; ignored when ``bits`` is
            given
        :param progress_func: Not used for this type of key.
        :param bits:
            key length selecting one of the supported curves; overrides
            ``curve`` when not ``None``
        :returns: A new private key (`.ECDSAKey`) object

        :raises: ``ValueError`` -- if no supported curve has ``bits`` length
        """
        if bits is not None:
            curve = cls._ECDSA_CURVES.get_by_key_length(bits)
            if curve is None:
                raise ValueError("Unsupported key length: {:d}".format(bits))
            curve = curve.curve_class()

        private_key = ec.generate_private_key(curve, backend=default_backend())
        return ECDSAKey(vals=(private_key, private_key.public_key()))

    # ...internals...

    def _from_private_key_file(self, filename, password):
        # Read the "EC" private key from the file, then decode it.
        data = self._read_private_key_file("EC", filename, password)
        self._decode_key(data)

    def _from_private_key(self, file_obj, password):
        # Read the "EC" private key from the file object, then decode it.
        data = self._read_private_key("EC", file_obj, password)
        self._decode_key(data)

    def _decode_key(self, data):
        """
        Load a DER-encoded private key and adopt it as this object's key.

        :raises: `.SSHException` -- if the key data cannot be loaded
        """
        try:
            key = serialization.load_der_private_key(
                data, password=None, backend=default_backend()
            )
        except (
            ValueError,
            AssertionError,
            # Raised when the key is encrypted but no password is passed.
            TypeError,
            # Raised when the key type is not supported by the backend.
            UnsupportedAlgorithm,
        ) as e:
            raise SSHException(str(e))

        self.signing_key = key
        self.verifying_key = key.public_key()
        curve_class = key.curve.__class__
        # NOTE: this is None for a curve outside _ECDSA_CURVES.
        self.ecdsa_curve = self._ECDSA_CURVES.get_by_curve_class(curve_class)

    def _sigencode(self, r, s):
        # Pack the signature integers as two mpints, r first.
        msg = Message()
        msg.add_mpint(r)
        msg.add_mpint(s)
        return msg.asbytes()

    def _sigdecode(self, sig):
        # Inverse of _sigencode: read back two mpints as (r, s).
        msg = Message(sig)
        r = msg.get_mpint()
        s = msg.get_mpint()
        return r, s
|