Passed
Push — main ( e15b4a...1c3d2f )
by Sat CFDI
05:36
created

satcfdi.models.signer.Signer.load_pkcs12()   A

Complexity

Conditions 3

Size

Total Lines 13
Code Lines 11

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 7
CRAP Score 3.0175

Importance

Changes 0
Metric Value
cc 3
eloc 11
nop 3
dl 0
loc 13
ccs 7
cts 8
cp 0.875
crap 3.0175
rs 9.85
c 0
b 0
f 0
1
# -*- coding: utf-8 -*-
2 1
import base64
3 1
from typing import Literal
4
5 1
from OpenSSL import crypto
6 1
from OpenSSL.crypto import X509
7 1
from cryptography import x509
8 1
from cryptography.hazmat.primitives import hashes
9 1
from cryptography.hazmat.primitives.asymmetric import padding, rsa
10 1
from cryptography.hazmat.primitives.serialization import Encoding, pkcs12, PrivateFormat, BestAvailableEncryption, NoEncryption, PublicFormat, load_der_private_key
11
12 1
from .certificate import Certificate
13 1
from ..exceptions import CFDIError
14
15
16 1
class Signer(Certificate):
    """A certificate together with its private key.

    Extends ``Certificate`` with the operations that need the private key:
    signing, decrypting and exporting the key material.
    """

    def __init__(self, certificate: X509, key: rsa.RSAPrivateKey, check=True):
        """Bind a private key to a certificate.

        Args:
            certificate: The pyOpenSSL X509 certificate, handed to the
                ``Certificate`` base initializer.
            key: The private key that belongs to ``certificate``.
            check: When true, verify that the public half of ``key`` equals
                the public key reported by the certificate.

        Raises:
            CFDIError: If ``check`` is true and the keys do not match.
        """
        super().__init__(certificate)
        self.key = key

        if check:
            res = _compare_public_keys(self.key.public_key(), self.public_key())
            if not res:
                raise CFDIError("Private Key does not match certificate")

    @classmethod
    def load(cls, certificate: bytes, key: bytes, password: str | bytes | None = None, check=True) -> 'Signer':
        """Build a signer from a DER certificate and a DER private key.

        Args:
            certificate: The certificate in DER (ASN.1) encoding.
            key: The private key in DER encoding.
            password: Password protecting ``key``; a ``str`` is encoded to
                bytes before use. ``None`` means no password is supplied.
            check: Forwarded to the constructor to verify that the key
                matches the certificate.

        Returns:
            Signer: The loaded signer.
        """
        if isinstance(password, str):
            password = password.encode()

        return cls(
            certificate=crypto.load_certificate(crypto.FILETYPE_ASN1, certificate),
            key=load_der_private_key(
                data=key,
                password=password
            ),
            check=check
        )

    @classmethod
    def load_pkcs12(cls, data: bytes, password: str | bytes | None = None) -> 'Signer':
        """Build a signer from a PKCS#12 bundle.

        Args:
            data: The serialized PKCS#12 bundle.
            password: Password protecting the bundle; a ``str`` is encoded to
                bytes before use.

        Returns:
            Signer: The loaded signer. Additional certificates contained in
            the bundle are ignored.

        Raises:
            CFDIError: If the bundle holds no certificate or no private key.
        """
        if isinstance(password, str):
            password = password.encode()

        key, certificate, _ = pkcs12.load_key_and_certificates(data=data, password=password)
        if certificate is None:
            raise CFDIError("Certificate is missing")
        if key is None:
            # Fail here instead of with an AttributeError on first use of self.key.
            raise CFDIError("Private key is missing")

        return cls(
            certificate=crypto.X509.from_cryptography(certificate),
            key=key,
            check=False  # the PKCS#12 loader is relied upon to pair key and certificate
        )

    def _sign(self, data, algorithm) -> str:
        """Sign ``data`` with PKCS#1 v1.5 padding and return the signature as base64 text.

        Args:
            data: The bytes to sign.
            algorithm: The hash algorithm instance passed to the key's ``sign``.

        Returns:
            str: The base64-encoded signature.
        """
        signature = self.key.sign(
            data=data,
            padding=padding.PKCS1v15(),
            algorithm=algorithm
        )

        return base64.b64encode(
            signature
        ).decode()

    def sign_sha1(self, data) -> str:
        """Return the base64 PKCS#1 v1.5 signature of ``data`` using SHA-1."""
        return self._sign(
            data=data,
            algorithm=hashes.SHA1()
        )

    def sign_sha256(self, data) -> str:
        """Return the base64 PKCS#1 v1.5 signature of ``data`` using SHA-256."""
        return self._sign(
            data=data,
            algorithm=hashes.SHA256()
        )

    def key_bytes(
        self, password: str | bytes | None = None, encoding: Encoding = Encoding.DER
    ) -> bytes:
        """Serialize the private key in PKCS#8 format.

        Args:
            password (str | bytes, optional): Password used to encrypt the
                serialized key. A ``str`` is encoded to bytes. When it is
                ``None`` or empty the key is written unencrypted.
            encoding (cryptography.hazmat.primitives.serialization.Encoding, optional):
                The output encoding. Defaults to ``Encoding.DER``.

        Returns:
            bytes: The serialized private key.

        Note:
            No validation of ``encoding`` happens here; any error for an
            unsupported value comes from the key's ``private_bytes``.
        """
        if isinstance(password, str):
            password = password.encode()

        return self.key.private_bytes(
            encoding=encoding,
            format=PrivateFormat.PKCS8,
            encryption_algorithm=(
                BestAvailableEncryption(password) if password else NoEncryption()
            ),
        )

    def pcks12_bytes(self, password: str | bytes | None = None) -> bytes:
        """Serialize the key and certificate as a PKCS#12 bundle.

        The method name keeps its historical spelling ("pcks12") so existing
        callers continue to work.

        Args:
            password: Password used to encrypt the bundle. A ``str`` is
                encoded to bytes. When it is ``None`` or empty the bundle is
                written unencrypted.

        Returns:
            bytes: The serialized bundle, named after ``self.rfc`` and
            containing no additional CA certificates.
        """
        if isinstance(password, str):
            password = password.encode()

        return pkcs12.serialize_key_and_certificates(
            name=self.rfc.encode(),
            key=self.key,
            cert=self.certificate.to_cryptography(),
            cas=None,
            encryption_algorithm=BestAvailableEncryption(password) if password else NoEncryption()
        )

    def decrypt(self, data: bytes):
        """Decrypt ``data`` with the private key using PKCS#1 v1.5 padding.

        Args:
            data: The ciphertext to decrypt.

        Returns:
            The value returned by the key's ``decrypt`` (the plaintext).
        """
        return self.key.decrypt(
            ciphertext=data,
            padding=padding.PKCS1v15()
        )
122
123
124 1
def _compare_public_keys(public_key_a, public_key_b):
    """Tell whether two public keys are the same key.

    Both keys are serialized as DER-encoded SubjectPublicKeyInfo and the
    resulting byte strings are compared for equality.

    Args:
        public_key_a: First public key; must provide ``public_bytes``.
        public_key_b: Second public key; must provide ``public_bytes``.

    Returns:
        bool: True when both serializations are byte-for-byte identical.
    """
    der_a, der_b = (
        candidate.public_bytes(
            encoding=Encoding.DER,
            format=PublicFormat.SubjectPublicKeyInfo
        )
        for candidate in (public_key_a, public_key_b)
    )
    return der_a == der_b
132