Issues (229)

server/lib/rsa/key.py (1 issue)

1
# -*- coding: utf-8 -*-
2
#
3
#  Copyright 2011 Sybren A. Stüvel <[email protected]>
4
#
5
#  Licensed under the Apache License, Version 2.0 (the "License");
6
#  you may not use this file except in compliance with the License.
7
#  You may obtain a copy of the License at
8
#
9
#      http://www.apache.org/licenses/LICENSE-2.0
10
#
11
#  Unless required by applicable law or agreed to in writing, software
12
#  distributed under the License is distributed on an "AS IS" BASIS,
13
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
#  See the License for the specific language governing permissions and
15
#  limitations under the License.
16
17
'''RSA key generation code.
18
19
Create new keys with the newkeys() function. It will give you a PublicKey and a
20
PrivateKey object.
21
22
Loading and saving keys requires the pyasn1 module. This module is imported as
23
late as possible, such that other functionality will remain working in absence
24
of pyasn1.
25
26
'''
27
28
import logging
29
30
import rsa.prime
31
import rsa.pem
32
import rsa.common
33
34
# Module-level logger; find_p_q() uses it for debug messages while it searches
# for suitable primes.
log = logging.getLogger(__name__)
35
36
class AbstractKey(object):
    '''Abstract superclass for private and public keys.

    Subclasses provide the format-specific hooks: ``_load_pkcs1_pem`` and
    ``_load_pkcs1_der`` (classmethods taking the file contents), and
    ``_save_pkcs1_pem`` and ``_save_pkcs1_der`` (instance methods returning
    the encoded key). This class only dispatches to them by format name.
    '''

    @staticmethod
    def _get_format_method(file_format, methods):
        '''Returns ``methods[file_format]``, validating the format name first.

        :param file_format: the requested format name, e.g. 'PEM' or 'DER'.
        :param methods: dict mapping each supported format name to a callable.
        :raises ValueError: if ``file_format`` is not a key of ``methods``;
            the message lists the supported formats.
        '''

        if file_format not in methods:
            formats = ', '.join(sorted(methods.keys()))
            raise ValueError('Unsupported format: %r, try one of %s' % (
                file_format, formats))

        return methods[file_format]

    @classmethod
    def load_pkcs1(cls, keyfile, format='PEM'):
        r'''Loads a key in PKCS#1 DER or PEM format.

        :param keyfile: contents of a DER- or PEM-encoded file that contains
            the key (public or private, depending on the class this is called
            on).
        :param format: the format of the file to load; 'PEM' or 'DER'

        :return: whatever the subclass's ``_load_pkcs1_pem`` /
            ``_load_pkcs1_der`` returns; for the PublicKey and PrivateKey
            classes that is an instance of the class this is called on.
        :raises ValueError: if ``format`` is not 'PEM' or 'DER'.
        '''

        methods = {
            'PEM': cls._load_pkcs1_pem,
            'DER': cls._load_pkcs1_der,
        }

        method = cls._get_format_method(format, methods)
        return method(keyfile)

    def save_pkcs1(self, format='PEM'):
        '''Saves the key in PKCS#1 DER or PEM format.

        :param format: the format to save; 'PEM' or 'DER'
        :returns: the DER- or PEM-encoded key, as produced by the subclass's
            ``_save_pkcs1_pem`` / ``_save_pkcs1_der``.
        :raises ValueError: if ``format`` is not 'PEM' or 'DER'.
        '''

        methods = {
            'PEM': self._save_pkcs1_pem,
            'DER': self._save_pkcs1_der,
        }

        method = self._get_format_method(format, methods)
        return method()
84
85
class PublicKey(AbstractKey):
    '''Represents a public RSA key.

    This key is also known as the 'encryption key'. It contains the 'n' and 'e'
    values.

    Supports attributes as well as dictionary-like access. Attribute access is
    faster, though.

    >>> PublicKey(5, 3)
    PublicKey(5, 3)

    >>> key = PublicKey(5, 3)
    >>> key.n
    5
    >>> key['n']
    5
    >>> key.e
    3
    >>> key['e']
    3

    Equal keys also hash equally, so they work as dict keys and set members:

    >>> PublicKey(5, 3) == PublicKey(5, 3)
    True
    >>> hash(PublicKey(5, 3)) == hash(PublicKey(5, 3))
    True

    '''

    __slots__ = ('n', 'e')

    def __init__(self, n, e):
        self.n = n
        self.e = e

    def __getitem__(self, key):
        # Dictionary-style access: key['n'] is the same as key.n.
        return getattr(self, key)

    def __repr__(self):
        return u'PublicKey(%i, %i)' % (self.n, self.e)

    def __eq__(self, other):
        # isinstance() is False for None as well, so that case is covered.
        if not isinstance(other, PublicKey):
            return False

        return self.n == other.n and self.e == other.e

    def __ne__(self, other):
        return not (self == other)

    def __hash__(self):
        # Hash exactly the fields __eq__ compares so equal keys hash equally.
        # Without this, Python 3 makes the class unhashable (it defines
        # __eq__), and Python 2 falls back to identity hashing.
        return hash((self.n, self.e))

    @classmethod
    def _load_pkcs1_der(cls, keyfile):
        r'''Loads a key in PKCS#1 DER format.

        @param keyfile: contents of a DER-encoded file that contains the public
            key.
        @return: a PublicKey object

        First let's construct a DER encoded key:

        >>> import base64
        >>> b64der = 'MAwCBQCNGmYtAgMBAAE='
        >>> der = base64.b64decode(b64der)

        This loads the file:

        >>> PublicKey._load_pkcs1_der(der)
        PublicKey(2367317549, 65537)

        '''

        from pyasn1.codec.der import decoder
        (asn_key, _) = decoder.decode(keyfile)

        # ASN.1 contents of DER encoded public key:
        #
        # RSAPublicKey ::= SEQUENCE {
        #     modulus           INTEGER,  -- n
        #     publicExponent    INTEGER,  -- e
        # }

        as_ints = tuple(int(x) for x in asn_key)
        return cls(*as_ints)

    def _save_pkcs1_der(self):
        '''Saves the public key in PKCS#1 DER format.

        @returns: the DER-encoded public key.
        '''

        from pyasn1.type import univ, namedtype
        from pyasn1.codec.der import encoder

        class AsnPubKey(univ.Sequence):
            componentType = namedtype.NamedTypes(
                namedtype.NamedType('modulus', univ.Integer()),
                namedtype.NamedType('publicExponent', univ.Integer()),
            )

        # Create the ASN object
        asn_key = AsnPubKey()
        asn_key.setComponentByName('modulus', self.n)
        asn_key.setComponentByName('publicExponent', self.e)

        return encoder.encode(asn_key)

    @classmethod
    def _load_pkcs1_pem(cls, keyfile):
        '''Loads a PKCS#1 PEM-encoded public key file.

        The contents of the file before the "-----BEGIN RSA PUBLIC KEY-----" and
        after the "-----END RSA PUBLIC KEY-----" lines is ignored.

        @param keyfile: contents of a PEM-encoded file that contains the public
            key.
        @return: a PublicKey object
        '''

        der = rsa.pem.load_pem(keyfile, 'RSA PUBLIC KEY')
        return cls._load_pkcs1_der(der)

    def _save_pkcs1_pem(self):
        '''Saves a PKCS#1 PEM-encoded public key file.

        @return: contents of a PEM-encoded file that contains the public key.
        '''

        der = self._save_pkcs1_der()
        return rsa.pem.save_pem(der, 'RSA PUBLIC KEY')
211
212
class PrivateKey(AbstractKey):
    '''Represents a private RSA key.

    This key is also known as the 'decryption key'. It contains the 'n', 'e',
    'd', 'p', 'q' and other values.

    Supports attributes as well as dictionary-like access. Attribute access is
    faster, though.

    >>> PrivateKey(3247, 65537, 833, 191, 17)
    PrivateKey(3247, 65537, 833, 191, 17)

    exp1, exp2 and coef don't have to be given, they will be calculated:

    >>> pk = PrivateKey(3727264081, 65537, 3349121513, 65063, 57287)
    >>> pk.exp1
    55063
    >>> pk.exp2
    10095
    >>> pk.coef
    50797

    If you give exp1, exp2 or coef, they will be used as-is:

    >>> pk = PrivateKey(1, 2, 3, 4, 5, 6, 7, 8)
    >>> pk.exp1
    6
    >>> pk.exp2
    7
    >>> pk.coef
    8

    Each of them is handled independently, so giving only some of them still
    calculates the others:

    >>> pk = PrivateKey(3727264081, 65537, 3349121513, 65063, 57287, 55063)
    >>> pk.exp2
    10095

    '''

    __slots__ = ('n', 'e', 'd', 'p', 'q', 'exp1', 'exp2', 'coef')

    def __init__(self, n, e, d, p, q, exp1=None, exp2=None, coef=None):
        self.n = n
        self.e = e
        self.d = d
        self.p = p
        self.q = q

        # Calculate exp1, exp2 and coef if they weren't supplied. Each value
        # is checked against its own argument, so any subset may be given.
        if exp1 is None:
            self.exp1 = int(d % (p - 1))
        else:
            self.exp1 = exp1

        if exp2 is None:
            self.exp2 = int(d % (q - 1))
        else:
            self.exp2 = exp2

        if coef is None:
            # The second element of extended_gcd(q, p) is the inverse of q
            # modulo p (when gcd(q, p) == 1): the PKCS#1 "coefficient".
            (_, self.coef, _) = extended_gcd(q, p)
        else:
            self.coef = coef

    def __getitem__(self, key):
        # Dictionary-style access: key['d'] is the same as key.d. This is
        # also what makes the mapping-style format in __repr__ work.
        return getattr(self, key)

    def __repr__(self):
        return u'PrivateKey(%(n)i, %(e)i, %(d)i, %(p)i, %(q)i)' % self

    def __eq__(self, other):
        # isinstance() is False for None as well, so that case is covered.
        if not isinstance(other, PrivateKey):
            return False

        return (self.n == other.n and
            self.e == other.e and
            self.d == other.d and
            self.p == other.p and
            self.q == other.q and
            self.exp1 == other.exp1 and
            self.exp2 == other.exp2 and
            self.coef == other.coef)

    def __ne__(self, other):
        return not (self == other)

    def __hash__(self):
        # Hash exactly the fields __eq__ compares so equal keys hash equally.
        # Without this, Python 3 makes the class unhashable (it defines
        # __eq__), and Python 2 falls back to identity hashing.
        return hash((self.n, self.e, self.d, self.p, self.q,
                     self.exp1, self.exp2, self.coef))

    @classmethod
    def _load_pkcs1_der(cls, keyfile):
        r'''Loads a key in PKCS#1 DER format.

        @param keyfile: contents of a DER-encoded file that contains the private
            key.
        @return: a PrivateKey object
        @raise ValueError: if the key's version field is not 0.

        First let's construct a DER encoded key:

        >>> import base64
        >>> b64der = 'MC4CAQACBQDeKYlRAgMBAAECBQDHn4npAgMA/icCAwDfxwIDANcXAgInbwIDAMZt'
        >>> der = base64.b64decode(b64der)

        This loads the file:

        >>> PrivateKey._load_pkcs1_der(der)
        PrivateKey(3727264081, 65537, 3349121513, 65063, 57287)

        '''

        from pyasn1.codec.der import decoder
        (priv, _) = decoder.decode(keyfile)

        # ASN.1 contents of DER encoded private key:
        #
        # RSAPrivateKey ::= SEQUENCE {
        #     version           Version,
        #     modulus           INTEGER,  -- n
        #     publicExponent    INTEGER,  -- e
        #     privateExponent   INTEGER,  -- d
        #     prime1            INTEGER,  -- p
        #     prime2            INTEGER,  -- q
        #     exponent1         INTEGER,  -- d mod (p-1)
        #     exponent2         INTEGER,  -- d mod (q-1)
        #     coefficient       INTEGER,  -- (inverse of q) mod p
        #     otherPrimeInfos   OtherPrimeInfos OPTIONAL
        # }

        if priv[0] != 0:
            raise ValueError('Unable to read this file, version %s != 0' % priv[0])

        # Skip the version; take n, e, d, p, q, exp1, exp2, coef in order.
        as_ints = tuple(int(x) for x in priv[1:9])
        return cls(*as_ints)

    def _save_pkcs1_der(self):
        '''Saves the private key in PKCS#1 DER format.

        @returns: the DER-encoded private key.
        '''

        from pyasn1.type import univ, namedtype
        from pyasn1.codec.der import encoder

        class AsnPrivKey(univ.Sequence):
            componentType = namedtype.NamedTypes(
                namedtype.NamedType('version', univ.Integer()),
                namedtype.NamedType('modulus', univ.Integer()),
                namedtype.NamedType('publicExponent', univ.Integer()),
                namedtype.NamedType('privateExponent', univ.Integer()),
                namedtype.NamedType('prime1', univ.Integer()),
                namedtype.NamedType('prime2', univ.Integer()),
                namedtype.NamedType('exponent1', univ.Integer()),
                namedtype.NamedType('exponent2', univ.Integer()),
                namedtype.NamedType('coefficient', univ.Integer()),
            )

        # Create the ASN object
        asn_key = AsnPrivKey()
        asn_key.setComponentByName('version', 0)
        asn_key.setComponentByName('modulus', self.n)
        asn_key.setComponentByName('publicExponent', self.e)
        asn_key.setComponentByName('privateExponent', self.d)
        asn_key.setComponentByName('prime1', self.p)
        asn_key.setComponentByName('prime2', self.q)
        asn_key.setComponentByName('exponent1', self.exp1)
        asn_key.setComponentByName('exponent2', self.exp2)
        asn_key.setComponentByName('coefficient', self.coef)

        return encoder.encode(asn_key)

    @classmethod
    def _load_pkcs1_pem(cls, keyfile):
        '''Loads a PKCS#1 PEM-encoded private key file.

        The contents of the file before the "-----BEGIN RSA PRIVATE KEY-----" and
        after the "-----END RSA PRIVATE KEY-----" lines is ignored.

        @param keyfile: contents of a PEM-encoded file that contains the private
            key.
        @return: a PrivateKey object
        '''

        der = rsa.pem.load_pem(keyfile, 'RSA PRIVATE KEY')
        return cls._load_pkcs1_der(der)

    def _save_pkcs1_pem(self):
        '''Saves a PKCS#1 PEM-encoded private key file.

        @return: contents of a PEM-encoded file that contains the private key.
        '''

        der = self._save_pkcs1_der()
        return rsa.pem.save_pem(der, 'RSA PRIVATE KEY')
400
401
402 View Code Duplication
def extended_gcd(a, b):
    """Returns a tuple (r, i, j) where r = gcd(a, b).

    ``i`` and ``j`` start out as the Bezout coefficients (r = i*a + j*b)
    produced by the iterative extended Euclidean algorithm. A negative ``i``
    is then shifted up by the original ``b``, and a negative ``j`` by the
    original ``a``, so that both returned values are non-negative.

    After that shift, r = i*a + j*b generally no longer holds. What does hold
    (for nonzero a and b) is (i * a) % b == r % b and (j * b) % a == r % a.
    In particular, when r == 1, ``i`` is the multiplicative inverse of a
    modulo b, and ``j`` is the multiplicative inverse of b modulo a.

    >>> extended_gcd(3, 7)
    (1, 5, 1)
    >>> extended_gcd(12, 8)
    (4, 1, 11)
    """
    # Remember the original inputs; they are needed at the end to wrap
    # negative coefficients into the non-negative range.
    orig_a, orig_b = a, b

    # (x, last_x) and (y, last_y) track the current and previous coefficient
    # of a and b respectively. The iterative form avoids recursion depth limits.
    x, last_x = 0, 1
    y, last_y = 1, 0
    while b != 0:
        quotient = a // b
        a, b = b, a % b
        x, last_x = last_x - quotient * x, x
        y, last_y = last_y - quotient * y, y

    if last_x < 0:
        last_x += orig_b
    if last_y < 0:
        last_y += orig_a

    return (a, last_x, last_y)
423
424
def find_p_q(nbits, accurate=True):
    '''Returns a tuple of two different primes whose product has ~2*nbits bits.

    To keep p and q from being too close together (which would make n easier
    to factor), one prime is requested with ``nbits + nbits // 16`` bits and
    the other with ``nbits - nbits // 16`` bits. The returned p and q are
    never equal. In accurate mode, primes are re-drawn until p * q has
    exactly ``2 * nbits`` bits.

    @param nbits: the average number of bits in p and q.
    @param accurate: whether to enable accurate mode or not.
    @returns (p, q), where p > q

    >>> (p, q) = find_p_q(128)
    >>> from rsa import common
    >>> common.bit_size(p * q)
    256

    When not in accurate mode, the number of bits can be slightly less

    >>> (p, q) = find_p_q(128, accurate=False)
    >>> from rsa import common
    >>> common.bit_size(p * q) <= 256
    True
    >>> common.bit_size(p * q) > 240
    True

    '''

    total_bits = nbits * 2

    # Make sure that p and q aren't too close or the factoring programs can
    # factor n.
    shift = nbits // 16
    pbits = nbits + shift
    qbits = nbits - shift

    # Choose the two initial primes
    log.debug('find_p_q(%i): Finding p', nbits)
    p = rsa.prime.getprime(pbits)
    log.debug('find_p_q(%i): Finding q', nbits)
    q = rsa.prime.getprime(qbits)

    def is_acceptable(p, q):
        '''Returns True iff p and q are acceptable:

            - p and q differ
            - (p * q) has the right nr of bits (when accurate=True)
        '''

        if p == q:
            return False

        if not accurate:
            return True

        # Make sure we have just the right amount of bits
        found_size = rsa.common.bit_size(p * q)
        return total_bits == found_size

    # Keep choosing other primes until they match our requirements,
    # alternating which of the two gets replaced (q first, then p, ...).
    change_p = False
    while not is_acceptable(p, q):
        if change_p:
            log.debug('   find another p')
            p = rsa.prime.getprime(pbits)
        else:
            log.debug('   find another q')
            q = rsa.prime.getprime(qbits)

        change_p = not change_p

    # We want p > q as described on
    # http://www.di-mgt.com.au/rsa_alg.html#crt
    return (max(p, q), min(p, q))
499
500
def calculate_keys(p, q, nbits, exponent=65537):
    """Calculates an encryption and a decryption key given p and q, and
    returns them as a tuple (e, d)

    @param p: the first prime.
    @param q: the second prime.
    @param nbits: not used by this function; kept for backward compatibility.
    @param exponent: the public exponent ``e`` to use. Defaults to 65537, a
        very common choice. It must be relatively prime to (p - 1) * (q - 1).
    @returns (e, d) such that (e * d) % ((p - 1) * (q - 1)) == 1
    @raise ValueError: if ``exponent`` is not relatively prime to
        (p - 1) * (q - 1), or if the computed d fails the sanity checks.
    """

    phi_n = (p - 1) * (q - 1)

    e = exponent

    # With gcd(e, phi_n) == 1, the second element is e's inverse mod phi_n.
    (divider, d, _) = extended_gcd(e, phi_n)

    if divider != 1:
        raise ValueError("e (%d) and phi_n (%d) are not relatively prime" %
                (e, phi_n))
    if d < 0:
        raise ValueError("extended_gcd shouldn't return negative values, "
                "please file a bug")
    if (e * d) % phi_n != 1:
        raise ValueError("e (%d) and d (%d) are not mult. inv. modulo "
                "phi_n (%d)" % (e, d, phi_n))

    return (e, d)
524
525
def gen_keys(nbits, accurate=True):
    """Generate RSA keys of nbits bits. Returns (p, q, e, d).

    Note: this can take a long time, depending on the key size.

    @param nbits: the total number of bits in ``p`` and ``q``; each of them
        gets roughly ``nbits // 2`` bits (see find_p_q for the exact split).
    @param accurate: forwarded to find_p_q; when True, ``p * q`` is required
        to have exactly ``2 * (nbits // 2)`` bits.
    @returns (p, q, e, d)
    """

    half_bits = nbits // 2
    p, q = find_p_q(half_bits, accurate)
    e, d = calculate_keys(p, q, half_bits)
    return p, q, e, d
538
539
def newkeys(nbits, accurate=True):
    """Generates public and private keys, and returns them as (pub, priv).

    The public key is also known as the 'encryption key', and is a
    :py:class:`PublicKey` object. The private key is also known as the
    'decryption key' and is a :py:class:`PrivateKey` object.

    :param nbits: the number of bits required to store ``n = p*q``.
    :param accurate: when True, ``n`` will have exactly the number of bits you
        asked for. However, this makes key generation much slower. When False,
        ``n`` may have slightly less bits.

    :returns: a tuple (:py:class:`rsa.PublicKey`, :py:class:`rsa.PrivateKey`)
    :raises ValueError: if ``nbits`` is less than 16.

    """

    if nbits < 16:
        raise ValueError('Key too small')

    # Forward the caller's ``accurate`` choice; it decides whether the size
    # of n is checked exactly while searching for p and q.
    (p, q, e, d) = gen_keys(nbits, accurate)

    n = p * q

    return (
        PublicKey(n, e),
        PrivateKey(n, e, d, p, q)
    )
566
567
__all__ = ['PublicKey', 'PrivateKey', 'newkeys']

if __name__ == '__main__':
    import doctest

    # Run the doctests repeatedly, stopping at the first failing round.
    # Presumably repeated because some doctests (e.g. find_p_q's) depend on
    # freshly generated primes and could pass or fail by chance.
    # Single-argument print(...) behaves the same on Python 2 and Python 3.
    try:
        for count in range(100):
            (failures, _) = doctest.testmod()
            if failures:
                break

            if (count and count % 10 == 0) or count == 1:
                print('%i times' % count)
    except KeyboardInterrupt:
        print('Aborted')
    else:
        print('Doctests done')
584