#!/usr/bin/env python3
"""This script is designed to generate self-signed TLS (Transport Layer
Security) certificates for encrypting sensitive information during
transmission.
TLS certificates are essential for ensuring secure communication and
preventing unauthorized access to data transmitted over the network.
"""

import argparse
import datetime
import os
import socket
import sys
from typing import Optional
from uuid import uuid4

from OpenSSL import crypto
from OpenSSL.crypto import PKey, X509

# Validity period of a generated certificate in seconds (10 years of 365 days)
CERT_LIFETIME = 10 * 365 * 24 * 60 * 60
# Size in bits of the RSA key generated for the certificate
KEY_LENGTH = 4096
# Host name of this machine, normalised to lower-case
MY_HOSTNAME = socket.gethostname().lower()
# AVAHI mDNS hostname: first label of the host name with ".local" appended
MY_AVAHINAME = f"{MY_HOSTNAME.partition('.')[0]}.local"
# Target files for the PEM encoded certificate and its private key
PATH_CERT = "/etc/ssl/certs/revpi-self-signed.pem"
PATH_KEY = "/etc/ssl/private/revpi-self-signed.key"

# Command line interface. The module docstring doubles as the description
# shown in the help output.
parser = argparse.ArgumentParser(
    description=__doc__,
    epilog=(
        "Default behaviour: no arguments. Check if certificate and key exist "
        "and if not, create the certificate and key."
    ),
)

# All options are plain switches: argparse derives the attribute name from
# the long option ("name", "time", "force") and "store_true" defaults to False.
parser.add_argument(
    "-n",
    "--name",
    action="store_true",
    help="create TLS-Certificate only when hostname not correct",
)
parser.add_argument(
    "-t",
    "--time",
    action="store_true",
    help="create TLS-Certificate only when date is expired",
)
parser.add_argument(
    "-f",
    "--force",
    action="store_true",
    help="create new TLS-Certificate overwriting any existing certificate",
)

# Parsed once at module level; main() reads the switches from this namespace
args = parser.parse_args()


def check_cert_key() -> bool:
    """
    Check whether the private key in PATH_KEY belongs to the certificate in PATH_CERT.

    Both files are loaded and the public key stored in the certificate is
    compared with the public key derived from the private key. If one of
    the files cannot be loaded, the pair is treated as not matching. The key
    file is only read if the certificate could be loaded.

    :return: True, if certificate and key matches
    """
    certificate = read_certificate(PATH_CERT)
    if certificate is None:
        return False

    private_key = read_privatekey(PATH_KEY)
    if private_key is None:
        return False

    # Compare the PEM dumps of both public keys, they are identical for a
    # matching certificate-key-pair
    pem = crypto.FILETYPE_PEM
    public_key_of_certificate = crypto.dump_publickey(pem, certificate.get_pubkey())
    public_key_of_private_key = crypto.dump_publickey(pem, private_key)
    return public_key_of_certificate == public_key_of_private_key


def create_certificate() -> bool:
    """
    Create and write a self-signed certificate and key into pem files.

    A new RSA key pair with KEY_LENGTH bits is generated and used to sign an
    X.509 v3 certificate for MY_HOSTNAME, which is valid for CERT_LIFETIME
    seconds starting now. The subject alternative name holds MY_HOSTNAME and
    MY_AVAHINAME. The private key is written to PATH_KEY first, afterwards
    the certificate is written to PATH_CERT. If the certificate file cannot
    be written, the key file written in the step before is removed again.

    Progress and error messages are printed to stdout.

    :return: True, if certificate was successfully created and saved
    """
    # Flush explicitly: the line has no newline and the key generation below
    # may take a while, the text would otherwise stay in the output buffer.
    print("Generate certificate... ", end="", flush=True)

    try:
        # Generate new certificate
        cert = crypto.X509()
        # Version is zero-based therefore a value of 2 means v3
        cert.set_version(2)
        cert.get_subject().CN = MY_HOSTNAME
        cert.get_subject().OU = "Revolution Pi"
        cert.set_serial_number(uuid4().int)  # Generate a random unique serial
        cert.gmtime_adj_notBefore(0)  # 0 sets the start date to the current time
        cert.gmtime_adj_notAfter(CERT_LIFETIME)

        # Set ourselves as the issuer - This is self-signed
        cert.set_issuer(cert.get_subject())

        subject_alt_name = f"DNS:{MY_HOSTNAME},DNS:{MY_AVAHINAME}".encode("utf-8")
        cert.add_extensions(
            [
                # Use subject alternative name to bind the DNS and additionally the mDNS hostname for avahi (.local)
                crypto.X509Extension(b"subjectAltName", False, subject_alt_name),
                # An end user certificate must either set CA to FALSE or exclude the extension entirely.
                # Some software may require the inclusion of basicConstraints with CA set to FALSE for end entity certificates.
                crypto.X509Extension(b"basicConstraints", False, b"CA:FALSE"),
            ]
        )

        # Generate a key pair to self-sign certificate
        pkey = crypto.PKey()
        pkey.generate_key(crypto.TYPE_RSA, KEY_LENGTH)
        cert.set_pubkey(pkey)
        cert.sign(pkey, "sha256")

        # Dump private key and certificate to a write buffer before manipulating files
        write_buffer_private_key = crypto.dump_privatekey(crypto.FILETYPE_PEM, pkey)
        write_buffer_certificate = crypto.dump_certificate(crypto.FILETYPE_PEM, cert)
    except crypto.Error as e:
        # Covers the whole build process, e.g. a host name which is rejected
        # as subject or subject alternative name, not only the final dump.
        print(f"\nCould not create certificate-key-pair: {e}")
        return False

    # Save certificate and private key to files
    try:
        # First write key file, the folder usually requires higher permissions.
        # NOTE(review): the key file gets the default mode of the process
        # (umask). Confirm that the permissions of the target directory
        # protect the private key sufficiently.
        with open(PATH_KEY, "wb") as f:
            f.write(write_buffer_private_key)
    except OSError as e:
        print(f"\nCould not write private key file '{PATH_KEY}': {e}")
        return False
    try:
        with open(PATH_CERT, "wb") as f:
            f.write(write_buffer_certificate)
    except OSError as e:
        print(f"\nCould not write certificate file '{PATH_CERT}': {e}")

        # Delete key file, which we created in the step before. A failing
        # cleanup must not hide the original error with a traceback.
        try:
            os.remove(PATH_KEY)
        except OSError as remove_error:
            print(f"Could not remove private key file '{PATH_KEY}': {remove_error}")

        return False

    print(" done.")
    return True


def read_certificate(path: str) -> Optional[X509]:
    """
    Load a PEM encoded certificate from a file.

    Errors are reported on stdout instead of being raised.

    :param path: Path to PEM file
    :return: Certificate Object or None, if the file could not be read or parsed
    """
    try:
        with open(path, "rb") as pem_file:
            return crypto.load_certificate(crypto.FILETYPE_PEM, pem_file.read())
    except OSError as read_error:
        print(f"Read error. Unable to read certificate file '{path}': {read_error}")
    except crypto.Error as load_error:
        print(f"Load error. Unable to load certificate from file: {load_error}")

    # Only reached after one of the errors above was reported
    return None


def read_privatekey(path: str) -> Optional[PKey]:
    """
    Load a PEM encoded private key from a file.

    Errors are reported on stdout instead of being raised.

    :param path: Path to PEM file
    :return: Private key Object or None, if the file could not be read or parsed
    """
    try:
        with open(path, "rb") as pem_file:
            return crypto.load_privatekey(crypto.FILETYPE_PEM, pem_file.read())
    except OSError as read_error:
        print(f"Read error. Unable to read private key file '{path}': {read_error}")
    except crypto.Error as load_error:
        print(f"Load error. Unable to load private key from file: {load_error}")

    # Only reached after one of the errors above was reported
    return None


def _asn1_time_to_datetime(asn1_time: bytes) -> datetime.datetime:
    """Convert an ASN.1 time (``YYYYMMDDhhmmssZ``) into a naive datetime in UTC."""
    return datetime.datetime.strptime(asn1_time.decode("ascii"), "%Y%m%d%H%M%SZ")


def _missing_san_entries(cert: X509, expected_san: list) -> list:
    """Return the expected DNS names not listed in the subjectAltName of the certificate."""
    san_found = False
    san_values = set()

    for index in range(cert.get_extension_count()):
        extension = cert.get_extension(index)
        if extension.get_short_name() != b"subjectAltName":
            continue

        san_found = True
        # A certificates SAN can have multiple entries, so lets split them.
        # Lower-case and strip every entry in order to simplify validation.
        san_values.update(
            value.strip() for value in str(extension).lower().split(",")
        )

    if not san_found:
        # Without any subjectAltName extension none of the names is bound
        return list(expected_san)

    return [san for san in expected_san if f"DNS:{san}".lower() not in san_values]


def main() -> int:
    """
    Check the existing certificate-key-pair and create a new one if required.

    A new certificate is always created, if the certificate or key file does
    not exist or if they do not fit together. With an intact pair the command
    line switches decide: ``--force`` always creates a new certificate,
    ``--time`` creates one if the certificate expires within the next 8 days
    and ``--name`` creates one if the host names are not part of the subject
    alternative name.

    :return: Exit code, 0 on success and 1 on errors
    """
    if not os.path.exists(PATH_CERT):
        print("Certificate did not exist...")
        return int(not create_certificate())

    if not os.path.exists(PATH_KEY):
        print("Key did not exist...")
        return int(not create_certificate())

    if not check_cert_key():
        print("Keyfile does not fit certificate file.")
        return int(not create_certificate())

    if args.force:
        return int(not create_certificate())

    if args.time:
        cert = read_certificate(PATH_CERT)
        if cert is None:
            # Could not read certificate file
            return 1

        not_before_opt = cert.get_notBefore()
        not_after_opt = cert.get_notAfter()
        if not_before_opt is None or not_after_opt is None:
            return 1
        not_before = _asn1_time_to_datetime(not_before_opt)
        not_after = _asn1_time_to_datetime(not_after_opt)

        # The validity period of the certificate is given in UTC, so it must
        # be compared with the current UTC time and not with the local time.
        # The time zone is dropped again to compare with the naive values.
        datetime_now = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)

        # Renew the certificate already 8 days before it expires
        if datetime_now > not_after - datetime.timedelta(days=8):
            print(f"The certificate provided expired on {not_after}.")
            return int(not create_certificate())

        print("Information:")
        print(f"The certificate provided is valid from {not_before}.")
        print(f"The certificate provided expires on {not_after}.\n")

    if args.name:
        cert = read_certificate(PATH_CERT)
        if cert is None:
            # Could not read certificate file
            return 1

        missing_san = _missing_san_entries(cert, [MY_HOSTNAME, MY_AVAHINAME])
        if missing_san:
            print(
                "Generating new certificate: Missing SAN entries: "
                + ", ".join(missing_san)
            )
            return int(not create_certificate())

    print(f"Certificate not modified. To get more info {parser.prog} -h")

    return 0


# Only run when executed as a script; the return value of main() is used as
# exit code of the process.
if __name__ == "__main__":
    exit_code = main()
    sys.exit(exit_code)
