Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
906 changes: 455 additions & 451 deletions docs/source/01-decode.ipynb

Large diffs are not rendered by default.

1,335 changes: 844 additions & 491 deletions docs/source/02-value-stats.ipynb

Large diffs are not rendered by default.

13,944 changes: 11,650 additions & 2,294 deletions docs/source/03-value-tables.ipynb

Large diffs are not rendered by default.

23 changes: 15 additions & 8 deletions docs/source/04-benchmark.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -34,17 +34,24 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:2025-08-20 15:40:01,949:jax._src.xla_bridge:872: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"GFloat scalar : 7510.22 nsec (25 runs at size 10000)\n",
"GFloat vectorized, numpy arrays: 43.82 nsec (25 runs at size 1000000)\n",
"GFloat vectorized, JAX JIT : 2.69 nsec (500 runs at size 1000000)\n",
"ML_dtypes : 2.57 nsec (500 runs at size 1000000)\n"
"GFloat scalar : 2605.38 nsec (50 runs at size 10000)\n",
"GFloat vectorized, numpy arrays: 50.20 nsec (25 runs at size 1000000)\n",
"GFloat vectorized, JAX JIT : 3.79 nsec (500 runs at size 1000000)\n",
"ML_dtypes : 2.60 nsec (500 runs at size 1000000)\n"
]
}
],
Expand Down Expand Up @@ -101,7 +108,7 @@
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"display_name": "gfloat",
"language": "python",
"name": "python3"
},
Expand All @@ -115,7 +122,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.0"
"version": "3.12.3"
}
},
"nbformat": 4,
Expand Down
37 changes: 17 additions & 20 deletions docs/source/05-stochastic-rounding.ipynb

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,19 @@ formats in Python. Headline features:
Provided Formats
----------------

Formats are parameterized by the primary IEEE-754 parameters of:
Formats are parameterized by the primary parameters of:

* Width in bits (k)
* Precision (p)
* Maximum exponent (emax)
* Exponent bias (bias)

with additional fields defining the presence/encoding of:

* Infinities
* Domain (Finite vs Extended)
* Signed/unsigned
* Not-a-number (NaN) values
* Negative zero
* Subnormal numbers
* Signed/unsigned
* Two's complement encoding (of the significand)

This allows an implementation of generic floating point encode/decode logic,
Expand Down
2 changes: 1 addition & 1 deletion src/gfloat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .round_ndarray import round_ndarray
from .encode_ndarray import encode_ndarray
from .decode_ndarray import decode_ndarray
from .types import FloatClass, FloatValue, FormatInfo, RoundMode
from .types import FloatClass, FloatValue, FormatInfo, Domain, RoundMode

# Don't automatically import from .formats.
# If the user wants them in their namespace, they can explicitly import
Expand Down
27 changes: 15 additions & 12 deletions src/gfloat/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from .types import FloatClass, FloatValue, FormatInfo
from .types import FloatClass, FloatValue, FormatInfo, Domain


def decode_float(fi: FormatInfo, i: int) -> FloatValue:
Expand Down Expand Up @@ -46,29 +46,32 @@ def decode_float(fi: FormatInfo, i: int) -> FloatValue:
if fi.is_twos_complement and signbit:
significand = (1 << t) - significand

expBias = fi.expBias
bias = fi.bias

iszero = exp == 0 and significand == 0 and fi.has_zero
issubnormal = fi.has_subnormals and (exp == 0) and (significand != 0)
isnormal = not iszero and not issubnormal
if iszero or issubnormal:
expval = 1 - expBias
expval = 1 - bias
fsignificand = significand * 2**-t
else:
expval = exp - expBias
expval = exp - bias
fsignificand = 1.0 + significand * 2**-t

# Handle specials: Infs, NaN, -0, NaN_0
signed_infinity = -np.inf if signbit else np.inf

# High NaNs
fval = None
# All-bits-special exponent (ABSE)
if w > 0 and exp == 2**w - 1:
min_i_with_nan = 2 ** (p - 1) - fi.num_high_nans
if significand >= min_i_with_nan:
fval = np.nan
if fi.has_infs and significand == min_i_with_nan - 1:
fval = signed_infinity
max_positive_code = (1 << (k - fi.signBits)) - 1
code_without_sign = i & max_positive_code
if code_without_sign > max_positive_code - fi.num_high_nans:
# Return nan, ignore sign
fval = np.nan

# Infinities
if fi.domain == Domain.Extended:
if code_without_sign == max_positive_code - fi.num_high_nans:
fval = -np.inf if signbit else np.inf

# Negative zero or NaN
if iszero and i == signmask and not fi.is_twos_complement:
Expand Down
13 changes: 7 additions & 6 deletions src/gfloat/decode_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from types import ModuleType
import numpy as np
import numpy.typing as npt
from .types import FormatInfo
from .types import FormatInfo, Domain


def decode_ndarray(
Expand Down Expand Up @@ -47,16 +47,17 @@ def decode_ndarray(
if fi.is_twos_complement:
significand = np.where(sign < 0, (1 << t) - significand, significand)

expBias = fi.expBias
bias = fi.bias

fval = np.zeros_like(codes, dtype=np.float64)
isspecial = np.zeros_like(codes, dtype=bool)

if fi.has_infs:
if fi.domain == Domain.Extended:
fval = np.where(codes == fi.code_of_posinf, np.inf, fval)
isspecial |= codes == fi.code_of_posinf
fval = np.where(codes == fi.code_of_neginf, -np.inf, fval)
isspecial |= codes == fi.code_of_neginf
if fi.is_signed:
fval = np.where(codes == fi.code_of_neginf, -np.inf, fval)
isspecial |= codes == fi.code_of_neginf

if fi.num_nans > 0:
code_is_nan = codes == fi.code_of_nan
Expand All @@ -76,7 +77,7 @@ def decode_ndarray(
fval = np.where(iszero & (sign < 0), -0.0, fval)

issubnormal = (exp == 0) & (significand != 0) & fi.has_subnormals
expval = np.where(issubnormal, 1 - expBias, exp - expBias)
expval = np.where(issubnormal, 1 - bias, exp - bias)
fsignificand = np.where(issubnormal, 0.0, 1.0) + np.ldexp(significand, -t)

# Normal/Subnormal/Zero case, other values will be overwritten
Expand Down
10 changes: 5 additions & 5 deletions src/gfloat/encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import numpy as np

from .types import FormatInfo
from .types import FormatInfo, Domain


def encode_float(fi: FormatInfo, v: float) -> int:
Expand Down Expand Up @@ -36,14 +36,14 @@ def encode_float(fi: FormatInfo, v: float) -> int:

# Overflow/underflow
if v > fi.max:
if fi.has_infs:
if fi.domain == Domain.Extended:
return fi.code_of_posinf
if fi.num_nans > 0:
return fi.code_of_nan
return fi.code_of_max

if v < fi.min:
if fi.has_infs:
if fi.domain == Domain.Extended:
return fi.code_of_neginf
if fi.num_nans > 0:
return fi.code_of_nan
Expand All @@ -65,12 +65,12 @@ def encode_float(fi: FormatInfo, v: float) -> int:
exp -= 1
# now sig in range [1, 2)

biased_exp = exp + fi.expBias
biased_exp = exp + fi.bias
if biased_exp < 1 and fi.has_subnormals:
# subnormal
sig *= 2.0 ** (biased_exp - 1)
biased_exp = 0
assert vpos == sig * 2 ** (1 - fi.expBias)
assert vpos == sig * 2 ** (1 - fi.bias)
else:
if sig > 0:
sig -= 1.0
Expand Down
12 changes: 7 additions & 5 deletions src/gfloat/encode_ndarray.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) 2024 Graphcore Ltd. All rights reserved.

from .types import FormatInfo
from .types import FormatInfo, Domain
import numpy as np
import numpy.typing as npt

Expand Down Expand Up @@ -40,12 +40,14 @@ def encode_ndarray(fi: FormatInfo, v: npt.NDArray) -> npt.NDArray:
else:
assert not np.any(nan_mask)

if fi.has_infs:
if fi.domain == Domain.Extended:
code[v > fi.max] = fi.code_of_posinf
code[v < fi.min] = fi.code_of_neginf
if fi.is_signed:
code[v < fi.min] = fi.code_of_neginf
else:
code[v > fi.max] = fi.code_of_nan if fi.num_nans > 0 else fi.code_of_max
code[v < fi.min] = fi.code_of_nan if fi.num_nans > 0 else fi.code_of_min
if fi.is_signed:
code[v < fi.min] = fi.code_of_nan if fi.num_nans > 0 else fi.code_of_min

if fi.has_zero:
if fi.has_nz:
Expand All @@ -61,7 +63,7 @@ def encode_ndarray(fi: FormatInfo, v: npt.NDArray) -> npt.NDArray:

sig, exp = np.frexp(finite_vpos)

biased_exp = exp.astype(np.int64) + (fi.expBias - 1)
biased_exp = exp.astype(np.int64) + (fi.bias - 1)
subnormal_mask = (biased_exp < 1) & fi.has_subnormals

biased_exp_safe = np.where(subnormal_mask, biased_exp, 0)
Expand Down
Loading