Description:
When converting a floating-point number (e.g., 1e-6) to jnp.bfloat16 and then printing it with an f-string format specifier (e.g., :.8), the scientific-notation output is truncated: the exponent part is missing.
Steps to Reproduce:
Run the following:
import jax.numpy as jnp
ja = jnp.array(1e-6, dtype=jnp.bfloat16)
print(f'{ja:.8}')
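To check whether the truncation is specific to bfloat16, a minimal comparison sketch (not part of the original reproduction) formats the same value with the same specifier at float32 and at bfloat16. If only the bfloat16 line lacks an exponent, the problem is tied to that dtype rather than to f-string formatting of JAX arrays in general:

```python
import jax.numpy as jnp

# Format the same value with the same spec at two dtypes.
# Printing with repr() makes any truncation easy to spot.
for dtype in (jnp.float32, jnp.bfloat16):
    x = jnp.array(1e-6, dtype=dtype)
    print(x.dtype, repr(f'{x:.8}'))
```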
Observed Behavior:
The exponent part of the scientific notation is cut off. For example, the observed output is:
Expected Behavior:
The output should display the full scientific notation, including the exponent, even if there is some precision loss due to the bfloat16 conversion. For example, an expected output (after precision loss) would be:
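A possible workaround, assumed rather than confirmed by the JAX maintainers, is to convert the 0-d array to a Python float before formatting, so that Python's standard float formatting is applied to the already-rounded bfloat16 value:

```python
import jax.numpy as jnp

ja = jnp.array(1e-6, dtype=jnp.bfloat16)

# float() on a 0-d array returns a Python float holding the
# bfloat16-rounded value, so the built-in float formatting is used.
print(f'{float(ja):.8}')  # prints 9.983778e-07
```

This adds no error beyond the bfloat16 rounding itself, because every bfloat16 value is exactly representable as a Python float.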
System info (python version, jaxlib version, accelerator, etc.)
- JAX version: 0.7.2
- Python version: 3.12