Skip to content

Commit ddff0b0

Browse files
committed
Cache pub/pri keys on retrieval
Pyca rightfully performs consistency checks when importing keys and these operations are rather expensive. So cache keys once generated so that repeated uses of the same JWK do not incur undue cost of reloading the keys from scratch for each subsequent operation. with a simple test by hand: $ python >>> from jwcrypto import jwk >>> def test(): ... key = jwk.JWK.generate(kty='RSA', size=2048) ... for i in range(1000): ... k = key._get_private_key() ... >>> import timeit Before the patch: >>> print(timeit.timeit("test()", setup="from __main__ import test", number=10)) 35.80328264506534 After the patch: >>> print(timeit.timeit("test()", setup="from __main__ import test", number=10)) 0.9109518649056554 Resolves #243 Signed-off-by: Simo Sorce <[email protected]>
1 parent 3ba7408 commit ddff0b0

File tree

1 file changed

+73
-24
lines changed

1 file changed

+73
-24
lines changed

jwcrypto/jwk.py

Lines changed: 73 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,8 @@ def __init__(self, **kwargs):
301301
are provided.
302302
"""
303303
super(JWK, self).__init__()
304+
self._cache_pub_k = None
305+
self._cache_pri_k = None
304306

305307
if 'generate' in kwargs:
306308
self.generate_key(**kwargs)
@@ -485,6 +487,8 @@ def _import_pyca_pub_okp(self, key, **params):
485487
def import_key(self, **kwargs):
486488
newkey = {}
487489
key_vals = 0
490+
self._cache_pub_k = None
491+
self._cache_pri_k = None
488492

489493
names = list(kwargs.keys())
490494

@@ -730,57 +734,93 @@ def _check_constraints(self, usage, operation):
730734
def _decode_int(self, n):
731735
return int(hexlify(base64url_decode(n)), 16)
732736

733-
def _rsa_pub(self):
737+
def _rsa_pub_n(self):
734738
e = self._decode_int(self.get('e'))
735739
n = self._decode_int(self.get('n'))
736740
return rsa.RSAPublicNumbers(e, n)
737741

738-
def _rsa_pri(self):
742+
def _rsa_pri_n(self):
739743
p = self._decode_int(self.get('p'))
740744
q = self._decode_int(self.get('q'))
741745
d = self._decode_int(self.get('d'))
742746
dp = self._decode_int(self.get('dp'))
743747
dq = self._decode_int(self.get('dq'))
744748
qi = self._decode_int(self.get('qi'))
745-
return rsa.RSAPrivateNumbers(p, q, d, dp, dq, qi, self._rsa_pub())
749+
return rsa.RSAPrivateNumbers(p, q, d, dp, dq, qi, self._rsa_pub_n())
746750

747-
def _ec_pub(self, curve):
751+
def _rsa_pub(self):
752+
k = self._cache_pub_k
753+
if k is None:
754+
k = self._rsa_pub_n().public_key(default_backend())
755+
self._cache_pub_k = k
756+
return k
757+
758+
def _rsa_pri(self):
759+
k = self._cache_pri_k
760+
if k is None:
761+
k = self._rsa_pri_n().private_key(default_backend())
762+
self._cache_pri_k = k
763+
return k
764+
765+
def _ec_pub_n(self, curve):
748766
x = self._decode_int(self.get('x'))
749767
y = self._decode_int(self.get('y'))
750768
return ec.EllipticCurvePublicNumbers(x, y, self.get_curve(curve))
751769

752-
def _ec_pri(self, curve):
770+
def _ec_pri_n(self, curve):
753771
d = self._decode_int(self.get('d'))
754-
return ec.EllipticCurvePrivateNumbers(d, self._ec_pub(curve))
772+
return ec.EllipticCurvePrivateNumbers(d, self._ec_pub_n(curve))
773+
774+
def _ec_pub(self, curve):
775+
k = self._cache_pub_k
776+
if k is None:
777+
k = self._ec_pub_n(curve).public_key(default_backend())
778+
self._cache_pub_k = k
779+
return k
780+
781+
def _ec_pri(self, curve):
782+
k = self._cache_pri_k
783+
if k is None:
784+
k = self._ec_pri_n(curve).private_key(default_backend())
785+
self._cache_pri_k = k
786+
return k
755787

756788
def _okp_pub(self):
757-
crv = self.get('crv')
758-
try:
759-
pubkey = _OKP_CURVES_TABLE[crv].pubkey
760-
except KeyError as e:
761-
raise InvalidJWKValue('Unknown curve "%s"' % crv) from e
789+
k = self._cache_pub_k
790+
if k is None:
791+
crv = self.get('crv')
792+
try:
793+
pubkey = _OKP_CURVES_TABLE[crv].pubkey
794+
except KeyError as e:
795+
raise InvalidJWKValue('Unknown curve "%s"' % crv) from e
762796

763-
x = base64url_decode(self.get('x'))
764-
return pubkey.from_public_bytes(x)
797+
x = base64url_decode(self.get('x'))
798+
k = pubkey.from_public_bytes(x)
799+
self._cache_pub_k = k
800+
return k
765801

766802
def _okp_pri(self):
767-
crv = self.get('crv')
768-
try:
769-
privkey = _OKP_CURVES_TABLE[crv].privkey
770-
except KeyError as e:
771-
raise InvalidJWKValue('Unknown curve "%s"' % crv) from e
803+
k = self._cache_pri_k
804+
if k is None:
805+
crv = self.get('crv')
806+
try:
807+
privkey = _OKP_CURVES_TABLE[crv].privkey
808+
except KeyError as e:
809+
raise InvalidJWKValue('Unknown curve "%s"' % crv) from e
772810

773-
d = base64url_decode(self.get('d'))
774-
return privkey.from_private_bytes(d)
811+
d = base64url_decode(self.get('d'))
812+
k = privkey.from_private_bytes(d)
813+
self._cache_pri_k = k
814+
return k
775815

776816
def _get_public_key(self, arg=None):
777817
ktype = self.get('kty')
778818
if ktype == 'oct':
779819
return self.get('k')
780820
elif ktype == 'RSA':
781-
return self._rsa_pub().public_key(default_backend())
821+
return self._rsa_pub()
782822
elif ktype == 'EC':
783-
return self._ec_pub(arg).public_key(default_backend())
823+
return self._ec_pub(arg)
784824
elif ktype == 'OKP':
785825
return self._okp_pub()
786826
else:
@@ -791,9 +831,9 @@ def _get_private_key(self, arg=None):
791831
if ktype == 'oct':
792832
return self.get('k')
793833
elif ktype == 'RSA':
794-
return self._rsa_pri().private_key(default_backend())
834+
return self._rsa_pri()
795835
elif ktype == 'EC':
796-
return self._ec_pri(arg).private_key(default_backend())
836+
return self._ec_pri(arg)
797837
elif ktype == 'OKP':
798838
return self._okp_pri()
799839
else:
@@ -969,6 +1009,9 @@ def __setitem__(self, item, value):
9691009

9701010
# Check if item is a key value and verify its format
9711011
if item in list(JWKValuesRegistry[kty].keys()):
1012+
# Invalidate cached keys if any
1013+
self._cache_pub_k = None
1014+
self._cache_pri_k = None
9721015
if JWKValuesRegistry[kty][item].type == ParmType.b64:
9731016
try:
9741017
v = base64url_decode(value)
@@ -1028,6 +1071,12 @@ def __delitem__(self, item):
10281071
if self.get(name) is not None:
10291072
raise KeyError("Cannot remove 'kty', values present")
10301073

1074+
kty = self.get('kty')
1075+
if kty is not None and item in list(JWKValuesRegistry[kty].keys()):
1076+
# Invalidate cached keys if any
1077+
self._cache_pub_k = None
1078+
self._cache_pri_k = None
1079+
10311080
super(JWK, self).__delitem__(item)
10321081

10331082
def __eq__(self, other):

0 commit comments

Comments
 (0)