Skip to content
14 changes: 2 additions & 12 deletions eth_abi/codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,6 @@

from eth_abi.decoding import (
ContextFramesBytesIO,
TupleDecoder,
)
from eth_abi.encoding import (
TupleEncoder,
)
from eth_abi.exceptions import (
EncodingError,
Expand Down Expand Up @@ -68,9 +64,7 @@ def encode(self, types: Iterable[TypeStr], args: Iterable[Any]) -> bytes:
validate_list_like_param(types, "types")
validate_list_like_param(args, "args")

encoders = [self._registry.get_encoder(type_str) for type_str in types]

encoder = TupleEncoder(encoders=encoders)
encoder = self._registry.get_tuple_encoder(*types)

return encoder(args)

Expand Down Expand Up @@ -152,11 +146,7 @@ def decode(
validate_list_like_param(types, "types")
validate_bytes_param(data, "data")

decoders = [
self._registry.get_decoder(type_str, strict=strict) for type_str in types
]

decoder = TupleDecoder(decoders=decoders)
decoder = self._registry.get_tuple_decoder(*types, strict=strict)
stream = self.stream_class(data)

return cast(Tuple[Any, ...], decoder(stream))
Expand Down
38 changes: 33 additions & 5 deletions eth_abi/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Optional,
Type,
Union,
cast,
)

from eth_typing import (
Expand Down Expand Up @@ -281,6 +282,7 @@ def _clear_encoder_cache(old_method: Callable[..., None]) -> Callable[..., None]
@functools.wraps(old_method)
def new_method(self: "ABIRegistry", *args: Any, **kwargs: Any) -> None:
self.get_encoder.cache_clear()
self.get_tuple_encoder.cache_clear()
return old_method(self, *args, **kwargs)

return new_method
Expand All @@ -290,6 +292,7 @@ def _clear_decoder_cache(old_method: Callable[..., None]) -> Callable[..., None]
@functools.wraps(old_method)
def new_method(self: "ABIRegistry", *args: Any, **kwargs: Any) -> None:
self.get_decoder.cache_clear()
self.get_tuple_decoder.cache_clear()
return old_method(self, *args, **kwargs)

return new_method
Expand Down Expand Up @@ -346,6 +349,12 @@ def __init__(self):
self._decoders = PredicateMapping("decoder registry")
self.get_encoder = functools.lru_cache(maxsize=None)(self._get_encoder_uncached)
self.get_decoder = functools.lru_cache(maxsize=None)(self._get_decoder_uncached)
self.get_tuple_encoder = functools.lru_cache(maxsize=None)(
self._get_tuple_encoder_uncached
)
self.get_tuple_decoder = functools.lru_cache(maxsize=None)(
self._get_tuple_decoder_uncached
)

def _get_registration(self, mapping, type_str):
coder = super()._get_registration(mapping, type_str)
Expand Down Expand Up @@ -453,8 +462,16 @@ def unregister(self, label: Optional[str]) -> None:
self.unregister_encoder(label)
self.unregister_decoder(label)

def _get_encoder_uncached(self, type_str):
return self._get_registration(self._encoders, type_str)
def _get_encoder_uncached(self, type_str: abi.TypeStr) -> Encoder:
return cast(Encoder, self._get_registration(self._encoders, type_str))

def _get_tuple_encoder_uncached(
self,
*type_strs: abi.TypeStr,
) -> encoding.TupleEncoder:
return encoding.TupleEncoder(
encoders=[self.get_encoder(type_str) for type_str in type_strs]
)

def has_encoder(self, type_str: abi.TypeStr) -> bool:
"""
Expand All @@ -472,17 +489,28 @@ def has_encoder(self, type_str: abi.TypeStr) -> bool:

return True

def _get_decoder_uncached(self, type_str, strict=True):
decoder = self._get_registration(self._decoders, type_str)
def _get_decoder_uncached(
self, type_str: abi.TypeStr, strict: bool = True
) -> Decoder:
decoder = cast(Decoder, self._get_registration(self._decoders, type_str))

if hasattr(decoder, "is_dynamic") and decoder.is_dynamic:
if getattr(decoder, "is_dynamic", False):
# Set a transient flag each time a call is made to ``get_decoder()``.
# Only dynamic decoders should be allowed these looser constraints. All
# other decoders should keep the default value of ``True``.
decoder.strict = strict

return decoder

def _get_tuple_decoder_uncached(
self,
*type_strs: abi.TypeStr,
strict: bool = True,
) -> decoding.TupleDecoder:
return decoding.TupleDecoder(
decoders=[self.get_decoder(type_str, strict) for type_str in type_strs]
)

def copy(self):
"""
Copies a registry such that new registrations can be made or existing
Expand Down