@@ -301,6 +301,8 @@ def __init__(self, **kwargs):
301301 are provided.
302302 """
303303 super (JWK , self ).__init__ ()
304+ self ._cache_pub_k = None
305+ self ._cache_pri_k = None
304306
305307 if 'generate' in kwargs :
306308 self .generate_key (** kwargs )
@@ -485,6 +487,8 @@ def _import_pyca_pub_okp(self, key, **params):
485487 def import_key (self , ** kwargs ):
486488 newkey = {}
487489 key_vals = 0
490+ self ._cache_pub_k = None
491+ self ._cache_pri_k = None
488492
489493 names = list (kwargs .keys ())
490494
@@ -730,57 +734,93 @@ def _check_constraints(self, usage, operation):
730734 def _decode_int (self , n ):
731735 return int (hexlify (base64url_decode (n )), 16 )
732736
733- def _rsa_pub (self ):
737+ def _rsa_pub_n (self ):
734738 e = self ._decode_int (self .get ('e' ))
735739 n = self ._decode_int (self .get ('n' ))
736740 return rsa .RSAPublicNumbers (e , n )
737741
738- def _rsa_pri (self ):
742+ def _rsa_pri_n (self ):
739743 p = self ._decode_int (self .get ('p' ))
740744 q = self ._decode_int (self .get ('q' ))
741745 d = self ._decode_int (self .get ('d' ))
742746 dp = self ._decode_int (self .get ('dp' ))
743747 dq = self ._decode_int (self .get ('dq' ))
744748 qi = self ._decode_int (self .get ('qi' ))
745- return rsa .RSAPrivateNumbers (p , q , d , dp , dq , qi , self ._rsa_pub ())
749+ return rsa .RSAPrivateNumbers (p , q , d , dp , dq , qi , self ._rsa_pub_n ())
746750
747- def _ec_pub (self , curve ):
751+ def _rsa_pub (self ):
752+ k = self ._cache_pub_k
753+ if k is None :
754+ k = self ._rsa_pub_n ().public_key (default_backend ())
755+ self ._cache_pub_k = k
756+ return k
757+
758+ def _rsa_pri (self ):
759+ k = self ._cache_pri_k
760+ if k is None :
761+ k = self ._rsa_pri_n ().private_key (default_backend ())
762+ self ._cache_pri_k = k
763+ return k
764+
765+ def _ec_pub_n (self , curve ):
748766 x = self ._decode_int (self .get ('x' ))
749767 y = self ._decode_int (self .get ('y' ))
750768 return ec .EllipticCurvePublicNumbers (x , y , self .get_curve (curve ))
751769
752- def _ec_pri (self , curve ):
770+ def _ec_pri_n (self , curve ):
753771 d = self ._decode_int (self .get ('d' ))
754- return ec .EllipticCurvePrivateNumbers (d , self ._ec_pub (curve ))
772+ return ec .EllipticCurvePrivateNumbers (d , self ._ec_pub_n (curve ))
773+
774+ def _ec_pub (self , curve ):
775+ k = self ._cache_pub_k
776+ if k is None :
777+ k = self ._ec_pub_n (curve ).public_key (default_backend ())
778+ self ._cache_pub_k = k
779+ return k
780+
781+ def _ec_pri (self , curve ):
782+ k = self ._cache_pri_k
783+ if k is None :
784+ k = self ._ec_pri_n (curve ).private_key (default_backend ())
785+ self ._cache_pri_k = k
786+ return k
755787
756788 def _okp_pub (self ):
757- crv = self .get ('crv' )
758- try :
759- pubkey = _OKP_CURVES_TABLE [crv ].pubkey
760- except KeyError as e :
761- raise InvalidJWKValue ('Unknown curve "%s"' % crv ) from e
789+ k = self ._cache_pub_k
790+ if k is None :
791+ crv = self .get ('crv' )
792+ try :
793+ pubkey = _OKP_CURVES_TABLE [crv ].pubkey
794+ except KeyError as e :
795+ raise InvalidJWKValue ('Unknown curve "%s"' % crv ) from e
762796
763- x = base64url_decode (self .get ('x' ))
764- return pubkey .from_public_bytes (x )
797+ x = base64url_decode (self .get ('x' ))
798+ k = pubkey .from_public_bytes (x )
799+ self ._cache_pub_k = k
800+ return k
765801
766802 def _okp_pri (self ):
767- crv = self .get ('crv' )
768- try :
769- privkey = _OKP_CURVES_TABLE [crv ].privkey
770- except KeyError as e :
771- raise InvalidJWKValue ('Unknown curve "%s"' % crv ) from e
803+ k = self ._cache_pri_k
804+ if k is None :
805+ crv = self .get ('crv' )
806+ try :
807+ privkey = _OKP_CURVES_TABLE [crv ].privkey
808+ except KeyError as e :
809+ raise InvalidJWKValue ('Unknown curve "%s"' % crv ) from e
772810
773- d = base64url_decode (self .get ('d' ))
774- return privkey .from_private_bytes (d )
811+ d = base64url_decode (self .get ('d' ))
812+ k = privkey .from_private_bytes (d )
813+ self ._cache_pri_k = k
814+ return k
775815
776816 def _get_public_key (self , arg = None ):
777817 ktype = self .get ('kty' )
778818 if ktype == 'oct' :
779819 return self .get ('k' )
780820 elif ktype == 'RSA' :
781- return self ._rsa_pub (). public_key ( default_backend ())
821+ return self ._rsa_pub ()
782822 elif ktype == 'EC' :
783- return self ._ec_pub (arg ). public_key ( default_backend ())
823+ return self ._ec_pub (arg )
784824 elif ktype == 'OKP' :
785825 return self ._okp_pub ()
786826 else :
@@ -791,9 +831,9 @@ def _get_private_key(self, arg=None):
791831 if ktype == 'oct' :
792832 return self .get ('k' )
793833 elif ktype == 'RSA' :
794- return self ._rsa_pri (). private_key ( default_backend ())
834+ return self ._rsa_pri ()
795835 elif ktype == 'EC' :
796- return self ._ec_pri (arg ). private_key ( default_backend ())
836+ return self ._ec_pri (arg )
797837 elif ktype == 'OKP' :
798838 return self ._okp_pri ()
799839 else :
@@ -969,6 +1009,9 @@ def __setitem__(self, item, value):
9691009
9701010 # Check if item is a key value and verify its format
9711011 if item in list (JWKValuesRegistry [kty ].keys ()):
1012+ # Invalidate cached keys if any
1013+ self ._cache_pub_k = None
1014+ self ._cache_pri_k = None
9721015 if JWKValuesRegistry [kty ][item ].type == ParmType .b64 :
9731016 try :
9741017 v = base64url_decode (value )
@@ -1028,6 +1071,12 @@ def __delitem__(self, item):
10281071 if self .get (name ) is not None :
10291072 raise KeyError ("Cannot remove 'kty', values present" )
10301073
1074+ kty = self .get ('kty' )
1075+ if kty is not None and item in list (JWKValuesRegistry [kty ].keys ()):
1076+ # Invalidate cached keys if any
1077+ self ._cache_pub_k = None
1078+ self ._cache_pri_k = None
1079+
10311080 super (JWK , self ).__delitem__ (item )
10321081
10331082 def __eq__ (self , other ):
0 commit comments