ml_dtypes is a stand-alone implementation of several NumPy dtype extensions used in machine learning libraries, including:
- bfloat16: an alternative to the standard float16 format
- float8_*: several experimental 8-bit floating point representations, including:
  - float8_e4m3b11fnuz
  - float8_e4m3fn
  - float8_e4m3fnuz
  - float8_e5m2
  - float8_e5m2fnuz
- int4 and uint4: low precision integer types.
See below for specifications of these number formats.
The ml_dtypes package is tested with Python versions 3.9-3.12, and can be installed
with the following command:
pip install ml_dtypes
To test your installation, you can run the following:
pip install absl-py pytest
pytest --pyargs ml_dtypes
To build from source, clone the repository and run:
git submodule init
git submodule update
pip install .
>>> from ml_dtypes import bfloat16
>>> import numpy as np
>>> np.zeros(4, dtype=bfloat16)
array([0, 0, 0, 0], dtype=bfloat16)
Importing ml_dtypes also registers the data types with NumPy, so that they may
be referred to by their string name:
>>> np.dtype('bfloat16')
dtype(bfloat16)
>>> np.dtype('float8_e5m2')
dtype(float8_e5m2)
bfloat16
A bfloat16 number is a single-precision float truncated at 16 bits.
Exponent: 8, Mantissa: 7, exponent bias: 127. IEEE 754, with NaN and inf.
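To make the "single-precision float truncated at 16 bits" description concrete, here is a minimal pure-Python sketch (standard library only, not ml_dtypes' implementation) that keeps the upper 16 bits of a float32 bit pattern and decodes bfloat16 bits back to a float. The helper names are hypothetical. ml_dtypes' own casts may round to nearest rather than truncate, so treat this as an illustration of the bit layout only.

```python
import math
import struct

# Sketch of the bfloat16 layout: 1 sign bit, 8 exponent bits (bias 127),
# 7 mantissa bits, i.e. the upper half of a float32 bit pattern.
# Helper names are hypothetical; this is not ml_dtypes' implementation.


def float32_bits(x: float) -> int:
    """Return the 32-bit pattern of x after rounding it to float32."""
    return struct.unpack(">I", struct.pack(">f", x))[0]


def bfloat16_bits_truncated(x: float) -> int:
    """Keep only the upper 16 bits of the float32 pattern (plain truncation)."""
    return float32_bits(x) >> 16


def bfloat16_value(bits: int) -> float:
    """Decode 16 bfloat16 bits by zero-filling the low half of a float32."""
    return struct.unpack(">f", struct.pack(">I", (bits & 0xFFFF) << 16))[0]


print(hex(bfloat16_bits_truncated(1.0)))  # 0x3f80
print(bfloat16_value(bfloat16_bits_truncated(math.pi)))  # 3.140625

# Consecutive bit patterns near 256 decode to values 2 apart.
print(bfloat16_value(0x4380), bfloat16_value(0x4381))  # 256.0 258.0
```

With only 7 stored mantissa bits, adjacent bfloat16 values at 256 are already 2 apart, which is the spacing that matters for the summation example later in this document.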
float8_e4m3b11fnuz
Exponent: 4, Mantissa: 3, bias: 11.
Extended range: no inf, NaN represented by 0b1000'0000.
float8_e4m3fn
Exponent: 4, Mantissa: 3, bias: 7.
Extended range: no inf, NaN represented by 0bS111'1111.
The fn suffix is for consistency with the corresponding LLVM/MLIR type, signaling that this type is not consistent with IEEE-754. The f indicates that it has finite values only; the n indicates that it includes NaNs, but only at the outer range.
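The following decoder is a hypothetical pure-Python sketch of float8_e4m3fn (exponent 4, mantissa 3, bias 7), written from the rules stated here rather than from ml_dtypes' source. Because no bit pattern is reserved for infinity, the largest exponent still encodes finite values and only 0bS111'1111 is NaN. The sketch assumes IEEE-style subnormals when the exponent bits are 0.

```python
import math

# Hypothetical decoder for float8_e4m3fn: 1 sign bit, 4 exponent bits
# (bias 7), 3 mantissa bits, no infinities, NaN only at 0bS111'1111.
# Assumes IEEE-style subnormals when the exponent bits are 0.


def decode_e4m3fn(byte: int) -> float:
    sign = -1.0 if byte & 0x80 else 1.0
    exponent = (byte >> 3) & 0xF
    mantissa = byte & 0x7
    if exponent == 0xF and mantissa == 0x7:
        return math.nan  # the only NaN encodings: 0x7F and 0xFF
    if exponent == 0:
        # Subnormal: no implicit leading 1, exponent fixed at 1 - bias.
        return sign * (mantissa / 8) * 2.0 ** (1 - 7)
    # Normal values, including exponent 15, which still encodes finite numbers.
    return sign * (1 + mantissa / 8) * 2.0 ** (exponent - 7)


values = [decode_e4m3fn(b) for b in range(256)]
finite = [v for v in values if not math.isnan(v)]
print(max(finite))  # 448.0
print(sum(math.isnan(v) for v in values))  # 2
print(decode_e4m3fn(0x01))  # 0.001953125, i.e. 2**-9
```

Enumerating all 256 codes this way gives a largest finite value of 448 and exactly two NaN encodings, one per sign.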
float8_e4m3fnuz
8-bit floating point with 3 bit mantissa.
An 8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits mantissa. The suffix fnuz is consistent with LLVM/MLIR naming and is derived from the differences to IEEE floating point conventions: F is for "finite" (no infinities), N for "with special NaN encoding", UZ for "unsigned zero".
This type has the following characteristics:
- bit encoding: S1E4M3 (0bSEEEEMMM)
- exponent bias: 8
- infinities: not supported
- NaNs: supported with sign bit set to 1 and exponent and mantissa bits all set to 0 (0b10000000)
- denormals when exponent is 0
float8_e5m2
Exponent: 5, Mantissa: 2, bias: 15. IEEE 754, with NaN and inf.
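Since float8_e5m2 (exponent 5, mantissa 2, bias 15) follows IEEE 754 conventions, it can be sketched as an IEEE-style decoder in which an all-ones exponent means infinity (mantissa 0) or NaN (mantissa nonzero). The function name is hypothetical. Because the exponent width and bias match IEEE float16, each float8_e5m2 code should equal the high byte of a float16, which the loop checks with the standard library's half-precision struct format ("e").

```python
import math
import struct

# Hypothetical decoder for float8_e5m2: 1 sign bit, 5 exponent bits
# (bias 15), 2 mantissa bits, with IEEE 754 inf and NaN.


def decode_e5m2(byte: int) -> float:
    sign = -1.0 if byte & 0x80 else 1.0
    exponent = (byte >> 2) & 0x1F
    mantissa = byte & 0x3
    if exponent == 0x1F:
        # All-ones exponent: infinity if the mantissa is 0, otherwise NaN.
        return sign * math.inf if mantissa == 0 else math.nan
    if exponent == 0:
        # Subnormal: no implicit leading 1, exponent fixed at 1 - bias.
        return sign * (mantissa / 4) * 2.0 ** (1 - 15)
    return sign * (1 + mantissa / 4) * 2.0 ** (exponent - 15)


# Cross-check: a float8_e5m2 code is the high byte of an IEEE float16.
for code in range(256):
    half = struct.unpack(">e", bytes([code, 0]))[0]
    ours = decode_e5m2(code)
    assert (math.isnan(half) and math.isnan(ours)) or half == ours, code

print(decode_e5m2(0x7B), decode_e5m2(0x7C))  # 57344.0 inf
```

Under these rules the largest finite float8_e5m2 value is 57344 (exponent bits 11110, mantissa bits 11), and six codes are NaN.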
float8_e5m2fnuz
8-bit floating point with 2 bit mantissa.
An 8-bit floating point type with 1 sign bit, 5 bits exponent and 2 bits mantissa. The suffix fnuz is consistent with LLVM/MLIR naming and is derived from the differences to IEEE floating point conventions: F is for "finite" (no infinities), N for "with special NaN encoding", UZ for "unsigned zero".
This type has the following characteristics:
- bit encoding: S1E5M2 (0bSEEEEEMM)
- exponent bias: 16
- infinities: not supported
- NaNs: supported with sign bit set to 1 and exponent and mantissa bits all set to 0 (0b10000000)
- denormals when exponent is 0
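The fnuz types described here share the same special-value rules: no infinities, a single NaN at 0b1000'0000 (the code that would otherwise be negative zero), and denormals when the exponent bits are 0. The sketch below is a hypothetical pure-Python decoder parameterized by exponent width, mantissa width and bias, using the values stated in this document (float8_e4m3fnuz: 4, 3, 8; float8_e5m2fnuz: 5, 2, 16; float8_e4m3b11fnuz: 4, 3, 11). The text does not spell out denormal handling for float8_e4m3b11fnuz, so applying the same rule to it is an assumption.

```python
import math

# Hypothetical decoder for the fnuz family: no infinities, a single NaN at
# 0b1000'0000 (so no negative zero), and denormals when the exponent is 0.
# (exponent bits, mantissa bits, bias) are taken from the specifications;
# applying the denormal rule to float8_e4m3b11fnuz is an assumption.
FNUZ_FORMATS = {
    "float8_e4m3fnuz": (4, 3, 8),
    "float8_e5m2fnuz": (5, 2, 16),
    "float8_e4m3b11fnuz": (4, 3, 11),
}


def decode_fnuz(byte: int, exp_bits: int, man_bits: int, bias: int) -> float:
    if byte == 0x80:
        return math.nan  # the only NaN code; it replaces negative zero
    sign = -1.0 if byte & 0x80 else 1.0
    exponent = (byte >> man_bits) & ((1 << exp_bits) - 1)
    mantissa = byte & ((1 << man_bits) - 1)
    scale = 1 << man_bits  # 2**man_bits
    if exponent == 0:
        return sign * (mantissa / scale) * 2.0 ** (1 - bias)
    # The all-ones exponent is an ordinary finite exponent here.
    return sign * (1 + mantissa / scale) * 2.0 ** (exponent - bias)


for name, params in FNUZ_FORMATS.items():
    finite = [decode_fnuz(code, *params) for code in range(256) if code != 0x80]
    smallest = min(v for v in finite if v > 0)
    print(name, "max:", max(finite), "smallest positive:", smallest)
```

Under these assumptions the largest finite values come out as 240 (float8_e4m3fnuz), 57344 (float8_e5m2fnuz) and 30 (float8_e4m3b11fnuz).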
int4 and uint4
4-bit integer types, where each element is represented unpacked (i.e., padded up to a byte in memory).
NumPy does not support types smaller than a single byte. For example, the
distance between adjacent elements in an array (.strides) is expressed in
bytes. Relaxing this restriction would be a considerable engineering project.
The int4 and uint4 types therefore use an unpacked representation, where
each element of the array is padded up to a byte in memory. The lower four bits
of each byte contain the representation of the number, whereas the upper four
bits are ignored.
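The unpacked layout can be sketched in plain Python: one byte per element, the value in the low four bits, and the upper four bits masked off when reading. Reading int4 as two's complement in [-8, 7] and writing zeros into the unused upper bits are assumptions of this sketch; the helper names are hypothetical and not part of ml_dtypes.

```python
# Hypothetical helpers for the unpacked 4-bit layout: one byte per element,
# value in the low nibble, upper nibble ignored on read.


def read_uint4(byte: int) -> int:
    return byte & 0x0F  # drop the ignored upper four bits


def read_int4(byte: int) -> int:
    low = byte & 0x0F
    return low - 16 if low & 0x08 else low  # sign-extend from bit 3


def write_int4(value: int) -> int:
    if not -8 <= value <= 7:
        raise ValueError("int4 holds values in [-8, 7]")
    return value & 0x0F  # two's-complement nibble; upper bits left as zero


# The upper nibble does not affect the decoded value.
print(read_int4(0x0F), read_int4(0xFF), read_uint4(0xAF))  # -1 -1 15

# Four int4 elements still take four bytes (a stride of 1 byte each).
buffer = bytes(write_int4(v) for v in [-8, -1, 0, 7])
print(list(buffer), [read_int4(b) for b in buffer])  # [8, 15, 0, 7] [-8, -1, 0, 7]
```

Packing two values per byte would halve the memory, but it would require sub-byte strides, which NumPy does not support.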
If you're exploring the use of low-precision dtypes in your code, you should be
careful to anticipate when the precision loss might lead to surprising results.
One example is the behavior of aggregations like sum; consider this bfloat16
summation in NumPy (run with version 1.24.2):
>>> from ml_dtypes import bfloat16
>>> import numpy as np
>>> rng = np.random.default_rng(seed=0)
>>> vals = rng.uniform(size=10000).astype(bfloat16)
>>> vals.sum()
256
The true sum should be close to 5000, but NumPy returns exactly 256: this is
because bfloat16 does not have the precision to increment 256 by values less than 1:
>>> bfloat16(256) + bfloat16(1)
256
After 256, the next representable value in bfloat16 is 258:
>>> np.nextafter(bfloat16(256), bfloat16(np.inf))
258
For better results you can specify that the accumulation should happen in a
higher-precision type like float32:
>>> vals.sum(dtype='float32').astype(bfloat16)
4992
In contrast to NumPy, projects like JAX, which support low-precision arithmetic more natively, will often do these kinds of higher-precision accumulations automatically:
>>> import jax.numpy as jnp
>>> jnp.array(vals).sum()
Array(4992, dtype=bfloat16)
This is not an officially supported Google product.
The ml_dtypes source code is licensed under the Apache 2.0 license
(see LICENSE). Pre-compiled wheels are built with the
EIGEN project, which is released under the
MPL 2.0 license (see LICENSE.eigen).