
Bug: jnp.bfloat16 f-string formatting truncates exponent #341

@yoeldr

Description:

When a floating-point number (e.g., 1e-6) is converted to jnp.bfloat16 and then printed with an f-string format specifier (e.g., :.8), the resulting scientific-notation output is truncated: the exponent is missing.

Steps to Reproduce:

Run:

```python
import jax.numpy as jnp

ja = jnp.array(1e-6, dtype=jnp.bfloat16)
print(f'{ja:.8}')
```

Observed Behavior:

The exponent of the scientific notation is cut off. The observed output is:

9.98378e
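One plausible reading of this output, assuming the bfloat16 scalar's `__format__` falls back to formatting its string form (an inference from the output, not confirmed against the JAX or ml_dtypes source): for a `str`, the `.8` in a format spec is a maximum length rather than a number of significant digits, so `"9.98378e-07"` is cut to its first 8 characters. The sketch below reproduces that effect in plain Python; the string value is assumed.

```python
# Assumption: str() of the bfloat16 value nearest to 1e-6 is "9.98378e-07".
shown = "9.98378e-07"

# For a str, ".8" means "keep at most 8 characters", which drops "-07".
truncated = f"{shown:.8}"
print(truncated)  # 9.98378e

# For a float, ".8" means 8 significant digits, so the exponent survives.
as_float = f"{float(shown):.8}"
print(as_float)  # 9.98378e-07
```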

Expected Behavior:

The output should display the full scientific notation, including the exponent, even if some precision is lost in the bfloat16 conversion. For example, formatting the same bfloat16 value as a Python float with :.8 gives:

9.983778e-07
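For reference, the bfloat16 value that 1e-6 rounds to can be computed without JAX. The sketch below keeps the upper 16 bits of the float32 encoding with round-half-to-even. `to_bfloat16` is a hypothetical helper, and it assumes this matches the cast ml_dtypes performs for ordinary finite inputs (NaN and overflow are not handled).

```python
import struct


def to_bfloat16(x: float) -> float:
    """Round x to the nearest bfloat16 value and return it as a Python float.

    Sketch only: intended for finite values; NaN/inf and overflow are ignored.
    """
    # float32 bit pattern of x (struct first rounds x to float32)
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Round-half-to-even on the 16 low bits that bfloat16 drops
    lsb = (bits >> 16) & 1
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


value = to_bfloat16(1e-6)
print(f"{value:.8}")  # 9.983778e-07  (8 significant digits)
print(f"{value:g}")   # 9.98378e-07   (6 significant digits, same prefix as the observed output)
```

If the truncation does come from string formatting, converting the 0-d array to a Python float before formatting, e.g. `f'{float(ja):.8}'`, should print `9.983778e-07`; this is a suggested workaround, not verified against JAX 0.7.2.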

System info (python version, jaxlib version, accelerator, etc.)

  • JAX version: 0.7.2
  • Python version: 3.12


Labels: bug (Something isn't working)
