
Optimization: Replace (a * b).sum(axis=-1) with np.einsum('...i,...i->...') in mydot for better performance and clarity #11

@SaFE-APIOpt

Description


Current implementation:

```python
def mydot(a, b):  # fighting with numpy to get vectorized dot product
    return (a * b).sum(axis=-1)
```

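Recommended replacement:

```python
import numpy as np

def mydot(a, b):
    # Contract the last axis of a and b; '...' carries any leading batch dimensions.
    return np.einsum('...i,...i->...', a, b)
```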

The current implementation first computes the element-wise product a * b, which allocates a temporary array with the broadcast shape of a and b, and then reduces that array by summing over the last axis. While correct, this approach incurs an extra memory allocation and an additional pass over the data, which becomes costly for large arrays or for inputs with many batch dimensions.
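To make the temporary concrete, the sketch below splits the current implementation into its two steps. The shapes are illustrative assumptions, chosen so that a leading axis of length 1 in `a` broadcasts against `b`; in that case the temporary is far larger than either input.

```python
import numpy as np

def mydot(a, b):
    # Current implementation from this issue: multiply, then reduce the last axis.
    return (a * b).sum(axis=-1)

# Illustrative shapes (assumptions): 1000 vectors paired against 500 vectors,
# all of length 3, combined via broadcasting over the leading axes.
rng = np.random.default_rng(0)
a = rng.standard_normal((1000, 1, 3))
b = rng.standard_normal((500, 3))

# Step 1: the element-wise product is materialized as a full temporary array
# with the broadcast shape of a and b.
tmp = a * b
print(tmp.shape, tmp.nbytes)  # (1000, 500, 3) 12000000

# Step 2: that temporary is reduced over the last axis and then discarded.
result = tmp.sum(axis=-1)
print(result.shape)  # (1000, 500)

# The inputs together occupy far less memory than the temporary.
print(a.nbytes + b.nbytes)  # 36000

assert np.array_equal(result, mydot(a, b))
```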

In contrast, np.einsum('...i,...i->...', a, b) computes the same result (a dot product over the last axis, up to floating-point rounding) but does so more efficiently: it performs the contraction in compiled C code that accumulates the products directly, without materializing the full product array, which reduces memory usage and typically runtime. This form is also semantically clearer: it explicitly states the intent, a dot product over the last dimension, while supporting arbitrary leading batch dimensions.
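A quick way to check the equivalence claim is to run both versions side by side. In this sketch the two variants are renamed `mydot_sum` and `mydot_einsum` (hypothetical names, used only so both can coexist), and the input shapes are arbitrary assumptions. It shows the single-vector case and confirms that leading batch dimensions broadcast under `...` the same way they do for `a * b`.

```python
import numpy as np

def mydot_sum(a, b):
    # Current implementation: temporary product, then reduce the last axis.
    return (a * b).sum(axis=-1)

def mydot_einsum(a, b):
    # Proposed implementation: contract the last axis directly.
    return np.einsum('...i,...i->...', a, b)

# Single pair of vectors: 1*4 + 2*5 + 3*6 = 32, returned as a 0-d result.
u = np.array([1.0, 2.0, 3.0])
v = np.array([4.0, 5.0, 6.0])
print(mydot_einsum(u, v))  # 32.0

# Leading batch dimensions broadcast under '...' just as with a * b:
# (4, 1) and (5,) broadcast to (4, 5).
rng = np.random.default_rng(42)
a = rng.standard_normal((4, 1, 8))
b = rng.standard_normal((5, 8))
out = mydot_einsum(a, b)
print(out.shape)  # (4, 5)

# Same values as the original implementation, up to floating-point rounding.
assert np.allclose(out, mydot_sum(a, b))
```

`np.allclose` is used rather than exact equality because the two code paths may accumulate the products in a different order.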

For vectorized dot products in high-performance or large-scale settings, einsum is the more efficient and idiomatic choice.
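Actual speedups depend on array shape, dtype, memory layout, and the NumPy build, so it is worth measuring on representative inputs before adopting the change. The sketch below is one way to do that with the standard library's `timeit`. The helper name `compare`, the shapes, and the repeat counts are illustrative assumptions, and no particular speedup is implied.

```python
import timeit
import numpy as np

def mydot_sum(a, b):
    # Current implementation: temporary product, then reduce the last axis.
    return (a * b).sum(axis=-1)

def mydot_einsum(a, b):
    # Proposed implementation: contract the last axis directly.
    return np.einsum('...i,...i->...', a, b)

def compare(shape, number=20, seed=0):
    """Time both variants on random float64 inputs of the given shape.

    Returns (seconds_sum, seconds_einsum): the best of 5 repeats,
    each averaged over `number` calls.
    """
    rng = np.random.default_rng(seed)
    a = rng.standard_normal(shape)
    b = rng.standard_normal(shape)
    # Sanity check before timing: both variants must agree.
    assert np.allclose(mydot_sum(a, b), mydot_einsum(a, b))
    t_sum = min(timeit.repeat(lambda: mydot_sum(a, b), number=number, repeat=5)) / number
    t_ein = min(timeit.repeat(lambda: mydot_einsum(a, b), number=number, repeat=5)) / number
    return t_sum, t_ein

if __name__ == "__main__":
    # Hypothetical shapes: many short vectors, a few long ones, and a 3-D batch.
    for shape in [(100_000, 3), (100, 10_000), (64, 64, 128)]:
        t_sum, t_ein = compare(shape)
        print(f"{shape}: sum {t_sum * 1e6:.1f} us, einsum {t_ein * 1e6:.1f} us")
```

Taking the minimum over repeats reduces noise from other processes; comparing several shapes matters because the relative cost of the temporary allocation varies with the length of the contracted axis.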
