1 | # -*- coding: utf-8 -*- |
||
2 | # |
||
3 | # Copyright 2011 Sybren A. Stüvel <[email protected]> |
||
4 | # |
||
5 | # Licensed under the Apache License, Version 2.0 (the "License"); |
||
6 | # you may not use this file except in compliance with the License. |
||
7 | # You may obtain a copy of the License at |
||
8 | # |
||
9 | # http://www.apache.org/licenses/LICENSE-2.0 |
||
10 | # |
||
11 | # Unless required by applicable law or agreed to in writing, software |
||
12 | # distributed under the License is distributed on an "AS IS" BASIS, |
||
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||
14 | # See the License for the specific language governing permissions and |
||
15 | # limitations under the License. |
||
16 | |||
17 | '''RSA key generation code. |
||
18 | |||
19 | Create new keys with the newkeys() function. It will give you a PublicKey and a |
||
20 | PrivateKey object. |
||
21 | |||
22 | Loading and saving keys requires the pyasn1 module. This module is imported as |
||
23 | late as possible, such that other functionality will remain working in absence |
||
24 | of pyasn1. |
||
25 | |||
26 | ''' |
||
27 | |||
28 | import logging |
||
29 | |||
30 | import rsa.prime |
||
31 | import rsa.pem |
||
32 | import rsa.common |
||
33 | |||
# Module-level logger; find_p_q() reports its prime search through it at
# DEBUG level.
log = logging.getLogger(__name__)
||
35 | |||
class AbstractKey(object):
    '''Abstract superclass for private and public keys.

    Subclasses provide the format-specific hooks ``_load_pkcs1_pem`` and
    ``_load_pkcs1_der`` (classmethods) plus ``_save_pkcs1_pem`` and
    ``_save_pkcs1_der``; this class only dispatches to them based on the
    requested format name.
    '''

    @staticmethod
    def _method_for_format(file_format, methods):
        '''Returns the callable registered for ``file_format``.

        :param file_format: the requested format name, e.g. 'PEM' or 'DER'.
        :param methods: dict mapping format names to callables.
        :raises ValueError: when ``file_format`` is not a key of ``methods``.
        '''

        if file_format not in methods:
            formats = ', '.join(sorted(methods.keys()))
            raise ValueError('Unsupported format: %r, try one of %s' % (file_format,
                                                                      formats))
        return methods[file_format]

    @classmethod
    def load_pkcs1(cls, keyfile, format='PEM'):
        r'''Loads a key in PKCS#1 DER or PEM format.

        :param keyfile: contents of a DER- or PEM-encoded file that contains
            the key.
        :param format: the format of the file to load; 'PEM' or 'DER'

        :return: an instance of the class this is called on (a PublicKey or
            a PrivateKey object)
        :raises ValueError: when ``format`` is not 'PEM' or 'DER'.

        '''

        methods = {
            'PEM': cls._load_pkcs1_pem,
            'DER': cls._load_pkcs1_der,
        }

        method = cls._method_for_format(format, methods)
        return method(keyfile)

    def save_pkcs1(self, format='PEM'):
        '''Saves the key in PKCS#1 DER or PEM format.

        :param format: the format to save; 'PEM' or 'DER'
        :returns: the DER- or PEM-encoded key.
        :raises ValueError: when ``format`` is not 'PEM' or 'DER'.

        '''

        methods = {
            'PEM': self._save_pkcs1_pem,
            'DER': self._save_pkcs1_der,
        }

        method = self._method_for_format(format, methods)
        return method()
84 | |||
class PublicKey(AbstractKey):
    '''Represents a public RSA key.

    This key is also known as the 'encryption key'. It contains the 'n' and 'e'
    values.

    Supports attributes as well as dictionary-like access. Attribute access is
    faster, though.

    >>> PublicKey(5, 3)
    PublicKey(5, 3)

    >>> key = PublicKey(5, 3)
    >>> key.n
    5
    >>> key['n']
    5
    >>> key.e
    3
    >>> key['e']
    3

    Equal keys hash equally, so they can be used in sets and as dict keys:

    >>> len(set([PublicKey(5, 3), PublicKey(5, 3)]))
    1

    '''

    __slots__ = ('n', 'e')

    def __init__(self, n, e):
        self.n = n
        self.e = e

    def __getitem__(self, key):
        return getattr(self, key)

    def __repr__(self):
        return u'PublicKey(%i, %i)' % (self.n, self.e)

    def __eq__(self, other):
        if other is None:
            return False

        if not isinstance(other, PublicKey):
            return False

        return self.n == other.n and self.e == other.e

    def __ne__(self, other):
        return not (self == other)

    def __hash__(self):
        # Must agree with __eq__: keys that compare equal hash equally.
        return hash((self.n, self.e))

    @classmethod
    def _load_pkcs1_der(cls, keyfile):
        r'''Loads a key in PKCS#1 DER format.

        @param keyfile: contents of a DER-encoded file that contains the public
            key.
        @return: a PublicKey object

        First let's construct a DER encoded key:

        >>> import base64
        >>> b64der = 'MAwCBQCNGmYtAgMBAAE='
        >>> der = base64.b64decode(b64der)

        This loads the file:

        >>> PublicKey._load_pkcs1_der(der)
        PublicKey(2367317549, 65537)

        '''

        from pyasn1.codec.der import decoder
        (pub, _) = decoder.decode(keyfile)

        # ASN.1 contents of DER encoded public key:
        #
        # RSAPublicKey ::= SEQUENCE {
        #     modulus           INTEGER,  -- n
        #     publicExponent    INTEGER,  -- e
        # }

        as_ints = tuple(int(x) for x in pub)
        return cls(*as_ints)

    def _save_pkcs1_der(self):
        '''Saves the public key in PKCS#1 DER format.

        @returns: the DER-encoded public key.
        '''

        from pyasn1.type import univ, namedtype
        from pyasn1.codec.der import encoder

        class AsnPubKey(univ.Sequence):
            componentType = namedtype.NamedTypes(
                namedtype.NamedType('modulus', univ.Integer()),
                namedtype.NamedType('publicExponent', univ.Integer()),
            )

        # Create the ASN object
        asn_key = AsnPubKey()
        asn_key.setComponentByName('modulus', self.n)
        asn_key.setComponentByName('publicExponent', self.e)

        return encoder.encode(asn_key)

    @classmethod
    def _load_pkcs1_pem(cls, keyfile):
        '''Loads a PKCS#1 PEM-encoded public key file.

        The contents of the file before the "-----BEGIN RSA PUBLIC KEY-----" and
        after the "-----END RSA PUBLIC KEY-----" lines are ignored.

        @param keyfile: contents of a PEM-encoded file that contains the public
            key.
        @return: a PublicKey object
        '''

        der = rsa.pem.load_pem(keyfile, 'RSA PUBLIC KEY')
        return cls._load_pkcs1_der(der)

    def _save_pkcs1_pem(self):
        '''Saves a PKCS#1 PEM-encoded public key file.

        @return: contents of a PEM-encoded file that contains the public key.
        '''

        der = self._save_pkcs1_der()
        return rsa.pem.save_pem(der, 'RSA PUBLIC KEY')
||
211 | |||
class PrivateKey(AbstractKey):
    '''Represents a private RSA key.

    This key is also known as the 'decryption key'. It contains the 'n', 'e',
    'd', 'p', 'q' and other values.

    Supports attributes as well as dictionary-like access. Attribute access is
    faster, though.

    >>> PrivateKey(3247, 65537, 833, 191, 17)
    PrivateKey(3247, 65537, 833, 191, 17)

    exp1, exp2 and coef don't have to be given, they will be calculated:

    >>> pk = PrivateKey(3727264081, 65537, 3349121513, 65063, 57287)
    >>> pk.exp1
    55063
    >>> pk.exp2
    10095
    >>> pk.coef
    50797

    If you give exp1, exp2 or coef, they will be used as-is:

    >>> pk = PrivateKey(1, 2, 3, 4, 5, 6, 7, 8)
    >>> pk.exp1
    6
    >>> pk.exp2
    7
    >>> pk.coef
    8

    Each value is handled independently, so giving only some of them still
    calculates the rest:

    >>> pk = PrivateKey(3727264081, 65537, 3349121513, 65063, 57287, exp1=6)
    >>> pk.exp1
    6
    >>> pk.exp2
    10095

    '''

    __slots__ = ('n', 'e', 'd', 'p', 'q', 'exp1', 'exp2', 'coef')

    def __init__(self, n, e, d, p, q, exp1=None, exp2=None, coef=None):
        self.n = n
        self.e = e
        self.d = d
        self.p = p
        self.q = q

        # Calculate the other values if they aren't supplied. Each one is
        # checked against its own argument, so a supplied value is never
        # overwritten and a missing one is never left as None.
        if exp1 is None:
            self.exp1 = int(d % (p - 1))
        else:
            self.exp1 = exp1

        if exp2 is None:
            self.exp2 = int(d % (q - 1))
        else:
            self.exp2 = exp2

        if coef is None:
            # coefficient = (inverse of q) mod p, see the ASN.1 layout below.
            (_, self.coef, _) = extended_gcd(q, p)
        else:
            self.coef = coef

    def __getitem__(self, key):
        return getattr(self, key)

    def __repr__(self):
        # '%' with a mapping uses __getitem__ above to look up each field.
        return u'PrivateKey(%(n)i, %(e)i, %(d)i, %(p)i, %(q)i)' % self

    def __eq__(self, other):
        if other is None:
            return False

        if not isinstance(other, PrivateKey):
            return False

        return (self.n == other.n and
                self.e == other.e and
                self.d == other.d and
                self.p == other.p and
                self.q == other.q and
                self.exp1 == other.exp1 and
                self.exp2 == other.exp2 and
                self.coef == other.coef)

    def __ne__(self, other):
        return not (self == other)

    def __hash__(self):
        # Must agree with __eq__: hash exactly the fields compared there.
        return hash((self.n, self.e, self.d, self.p, self.q,
                     self.exp1, self.exp2, self.coef))

    @classmethod
    def _load_pkcs1_der(cls, keyfile):
        r'''Loads a key in PKCS#1 DER format.

        @param keyfile: contents of a DER-encoded file that contains the private
            key.
        @return: a PrivateKey object
        @raise ValueError: when the encoded version field is not 0.

        First let's construct a DER encoded key:

        >>> import base64
        >>> b64der = 'MC4CAQACBQDeKYlRAgMBAAECBQDHn4npAgMA/icCAwDfxwIDANcXAgInbwIDAMZt'
        >>> der = base64.b64decode(b64der)

        This loads the file:

        >>> PrivateKey._load_pkcs1_der(der)
        PrivateKey(3727264081, 65537, 3349121513, 65063, 57287)

        '''

        from pyasn1.codec.der import decoder
        (priv, _) = decoder.decode(keyfile)

        # ASN.1 contents of DER encoded private key:
        #
        # RSAPrivateKey ::= SEQUENCE {
        #     version           Version,
        #     modulus           INTEGER,  -- n
        #     publicExponent    INTEGER,  -- e
        #     privateExponent   INTEGER,  -- d
        #     prime1            INTEGER,  -- p
        #     prime2            INTEGER,  -- q
        #     exponent1         INTEGER,  -- d mod (p-1)
        #     exponent2         INTEGER,  -- d mod (q-1)
        #     coefficient       INTEGER,  -- (inverse of q) mod p
        #     otherPrimeInfos   OtherPrimeInfos OPTIONAL
        # }

        if priv[0] != 0:
            raise ValueError('Unable to read this file, version %s != 0' % priv[0])

        # Fields 1..8 map one-to-one onto the constructor arguments.
        as_ints = tuple(int(x) for x in priv[1:9])
        return cls(*as_ints)

    def _save_pkcs1_der(self):
        '''Saves the private key in PKCS#1 DER format.

        @returns: the DER-encoded private key.
        '''

        from pyasn1.type import univ, namedtype
        from pyasn1.codec.der import encoder

        class AsnPrivKey(univ.Sequence):
            componentType = namedtype.NamedTypes(
                namedtype.NamedType('version', univ.Integer()),
                namedtype.NamedType('modulus', univ.Integer()),
                namedtype.NamedType('publicExponent', univ.Integer()),
                namedtype.NamedType('privateExponent', univ.Integer()),
                namedtype.NamedType('prime1', univ.Integer()),
                namedtype.NamedType('prime2', univ.Integer()),
                namedtype.NamedType('exponent1', univ.Integer()),
                namedtype.NamedType('exponent2', univ.Integer()),
                namedtype.NamedType('coefficient', univ.Integer()),
            )

        # Create the ASN object
        asn_key = AsnPrivKey()
        asn_key.setComponentByName('version', 0)
        asn_key.setComponentByName('modulus', self.n)
        asn_key.setComponentByName('publicExponent', self.e)
        asn_key.setComponentByName('privateExponent', self.d)
        asn_key.setComponentByName('prime1', self.p)
        asn_key.setComponentByName('prime2', self.q)
        asn_key.setComponentByName('exponent1', self.exp1)
        asn_key.setComponentByName('exponent2', self.exp2)
        asn_key.setComponentByName('coefficient', self.coef)

        return encoder.encode(asn_key)

    @classmethod
    def _load_pkcs1_pem(cls, keyfile):
        '''Loads a PKCS#1 PEM-encoded private key file.

        The contents of the file before the "-----BEGIN RSA PRIVATE KEY-----" and
        after the "-----END RSA PRIVATE KEY-----" lines are ignored.

        @param keyfile: contents of a PEM-encoded file that contains the private
            key.
        @return: a PrivateKey object
        '''

        der = rsa.pem.load_pem(keyfile, 'RSA PRIVATE KEY')
        return cls._load_pkcs1_der(der)

    def _save_pkcs1_pem(self):
        '''Saves a PKCS#1 PEM-encoded private key file.

        @return: contents of a PEM-encoded file that contains the private key.
        '''

        der = self._save_pkcs1_der()
        return rsa.pem.save_pem(der, 'RSA PRIVATE KEY')
||
400 | |||
401 | |||
def extended_gcd(a, b):
    """Returns a tuple (r, i, j) with r = gcd(a, b) and i, j derived from the
    Bezout coefficients of a and b.

    The extended Euclidean algorithm first yields coefficients with
    ``r == i*a + j*b``. A negative ``i`` is then increased once by the
    original ``b``, and a negative ``j`` once by the original ``a``. After that
    shift the equation above no longer holds in general. Only the congruences
    ``i*a == r (mod b)`` and ``j*b == r (mod a)`` remain. When ``r == 1``, ``i``
    is therefore the multiplicative inverse of ``a`` modulo ``b``, and ``j``
    the inverse of ``b`` modulo ``a``.

    Intended for non-negative ``a`` and ``b``.

    >>> extended_gcd(3, 5)
    (1, 2, 2)
    >>> extended_gcd(12, 18)
    (6, 17, 1)
    """
    # Iterative version: faster than the recursive one and uses constant
    # stack space.
    x, last_x = 0, 1
    y, last_y = 1, 0
    # Remember the original a and b to make negative coefficients positive.
    orig_a, orig_b = a, b
    while b != 0:
        quotient = a // b
        a, b = b, a % b
        x, last_x = last_x - quotient * x, x
        y, last_y = last_y - quotient * y, y
    if last_x < 0:
        last_x += orig_b  # wrap modulo the original b
    if last_y < 0:
        last_y += orig_a  # wrap modulo the original a
    return (a, last_x, last_y)
||
423 | |||
def find_p_q(nbits, accurate=True):
    '''Returns a tuple of two different primes of approximately nbits bits each.

    To keep p and q from being too close together, one prime is requested with
    ``nbits + nbits // 16`` bits and the other with ``nbits - nbits // 16``
    bits. The returned p and q are never equal. In accurate mode primes are
    re-drawn until p * q has exactly 2 * nbits bits.

    @param nbits: the nominal number of bits of each of p and q.
    @param accurate: whether to enable accurate mode or not.
    @returns (p, q), where p > q

    >>> (p, q) = find_p_q(128)
    >>> from rsa import common
    >>> common.bit_size(p * q)
    256

    When not in accurate mode, the number of bits can be slightly less

    >>> (p, q) = find_p_q(128, accurate=False)
    >>> from rsa import common
    >>> common.bit_size(p * q) <= 256
    True
    >>> common.bit_size(p * q) > 240
    True

    '''

    total_bits = nbits * 2

    # Make sure that p and q aren't too close or the factoring programs can
    # factor n.
    shift = nbits // 16
    pbits = nbits + shift
    qbits = nbits - shift

    # Choose the two initial primes
    log.debug('find_p_q(%i): Finding p', nbits)
    p = rsa.prime.getprime(pbits)
    log.debug('find_p_q(%i): Finding q', nbits)
    q = rsa.prime.getprime(qbits)

    def is_acceptable(p, q):
        '''Returns True iff p and q are acceptable:

        - p and q differ
        - (p * q) has the right nr of bits (when accurate=True)
        '''

        if p == q:
            return False

        if not accurate:
            return True

        # Make sure we have just the right amount of bits
        found_size = rsa.common.bit_size(p * q)
        return total_bits == found_size

    # Keep choosing other primes until they match our requirements,
    # alternating between replacing p and replacing q.
    change_p = False
    while not is_acceptable(p, q):
        if change_p:
            log.debug('    find another p')
            p = rsa.prime.getprime(pbits)
        else:
            log.debug('    find another q')
            q = rsa.prime.getprime(qbits)

        change_p = not change_p

    # We want p > q as described on
    # http://www.di-mgt.com.au/rsa_alg.html#crt
    return (max(p, q), min(p, q))
||
499 | |||
def calculate_keys(p, q, nbits, exponent=65537):
    """Calculates an encryption and a decryption key given p and q, and
    returns them as a tuple (e, d).

    :param p: the first prime factor of n.
    :param q: the second prime factor of n.
    :param nbits: not used by the calculation; kept so existing callers keep
        working.
    :param exponent: the public exponent e. Defaults to 65537, a very common
        choice. It must be relatively prime to (p - 1) * (q - 1).
    :returns: a tuple (e, d) where d is the inverse of e modulo
        (p - 1) * (q - 1).
    :raises ValueError: when e and (p - 1) * (q - 1) are not relatively prime,
        or when the computed d fails the sanity checks below.

    """

    phi_n = (p - 1) * (q - 1)
    e = exponent

    (divider, d, _) = extended_gcd(e, phi_n)

    if divider != 1:
        raise ValueError("e (%d) and phi_n (%d) are not relatively prime" %
                         (e, phi_n))
    # Defensive checks: extended_gcd() makes its coefficients non-negative,
    # so these should never trigger.
    if d < 0:
        raise ValueError("extended_gcd shouldn't return negative values, "
                         "please file a bug")
    if (e * d) % phi_n != 1:
        raise ValueError("e (%d) and d (%d) are not mult. inv. modulo "
                         "phi_n (%d)" % (e, d, phi_n))

    return (e, d)
||
524 | |||
def gen_keys(nbits, accurate=True):
    """Generate RSA keys of nbits bits. Returns (p, q, e, d).

    Note: this can take a long time, depending on the key size.

    @param nbits: the total number of bits in ``p`` and ``q`` together. Each
        of them is generated from ``nbits // 2`` (see find_p_q for the exact
        split).
    @param accurate: passed on to find_p_q(); when True, ``p * q`` has exactly
        ``nbits`` bits.
    """

    half_bits = nbits // 2
    p, q = find_p_q(half_bits, accurate)
    e, d = calculate_keys(p, q, half_bits)
    return p, q, e, d
||
538 | |||
def newkeys(nbits, accurate=True):
    """Generates public and private keys, and returns them as (pub, priv).

    The public key is also known as the 'encryption key', and is a
    :py:class:`PublicKey` object. The private key is also known as the
    'decryption key' and is a :py:class:`PrivateKey` object.

    :param nbits: the number of bits required to store ``n = p*q``.
    :param accurate: when True, ``n`` will have exactly the number of bits you
        asked for. However, this makes key generation much slower. When False,
        ``n`` may have slightly fewer bits.

    :returns: a tuple (:py:class:`rsa.PublicKey`, :py:class:`rsa.PrivateKey`)
    :raises ValueError: when ``nbits`` is less than 16.

    """

    if nbits < 16:
        raise ValueError('Key too small')

    # Forward ``accurate`` so that accurate=False actually takes effect.
    (p, q, e, d) = gen_keys(nbits, accurate)

    n = p * q

    return (
        PublicKey(n, e),
        PrivateKey(n, e, d, p, q)
    )
||
566 | |||
# Names exported by a star-import of this module. The helper functions
# (extended_gcd, find_p_q, calculate_keys, gen_keys) are not listed here.
__all__ = ['PublicKey', 'PrivateKey', 'newkeys']
||
568 | |||
if __name__ == '__main__':
    # Run this module's doctests repeatedly. Key generation is randomized, so
    # repetition exercises more cases. Ctrl-C stops early.
    import doctest

    failures = 0
    try:
        for count in range(100):
            (failures, _) = doctest.testmod()
            if failures:
                break

            if (count and count % 10 == 0) or count == 1:
                # Single-argument print() behaves the same on Python 2 and 3.
                print('%i times' % count)
    except KeyboardInterrupt:
        print('Aborted')
    else:
        # Only claim success when the last run actually passed.
        if failures:
            print('Doctests failed')
        else:
            print('Doctests done')
||
584 |