Skip to content

Commit a45b5d9

Browse files
Pathways-on-Cloud Teamcopybara-github
authored andcommitted
Drop xc symbol dependency for persistence reads.
PiperOrigin-RevId: 719471786
1 parent a708d3a commit a45b5d9

File tree

1 file changed

+21
-26
lines changed

1 file changed

+21
-26
lines changed

pathwaysutils/persistence/helper.py

Lines changed: 21 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -21,35 +21,34 @@
2121

2222
import jax
2323
from jax import core
24-
from jax.lib import xla_client as xc
2524
import numpy as np
2625
from pathwaysutils import plugin_executable
2726

2827

29-
def dtype_to_etype(dtype: np.dtype) -> xc.PrimitiveType:
28+
def dtype_to_xla_primitive_type_str(dtype: np.dtype) -> str:
3029
"""Converts a numpy dtype to an xla PrimitiveType."""
3130
if dtype == np.dtype("bfloat16"):
32-
return xc.PrimitiveType.BF16
31+
return "BF16"
3332
elif dtype == np.dtype("float32"):
34-
return xc.PrimitiveType.F32
33+
return "F32"
3534
elif dtype == np.dtype("float64"):
36-
return xc.PrimitiveType.F64
35+
return "F64"
3736
elif dtype == np.dtype("int8"):
38-
return xc.PrimitiveType.S8
37+
return "S8"
3938
elif dtype == np.dtype("int16"):
40-
return xc.PrimitiveType.S16
39+
return "S16"
4140
elif dtype == np.dtype("int32"):
42-
return xc.PrimitiveType.S32
41+
return "S32"
4342
elif dtype == np.dtype("int64"):
44-
return xc.PrimitiveType.S64
43+
return "S64"
4544
elif dtype == np.dtype("uint8"):
46-
return xc.PrimitiveType.U8
45+
return "U8"
4746
elif dtype == np.dtype("uint16"):
48-
return xc.PrimitiveType.U16
47+
return "U16"
4948
elif dtype == np.dtype("uint32"):
50-
return xc.PrimitiveType.U32
49+
return "U32"
5150
elif dtype == np.dtype("uint64"):
52-
return xc.PrimitiveType.U64
51+
return "U64"
5352
else:
5453
raise ValueError(f"Unsupported dtype: {dtype}")
5554

@@ -91,19 +90,15 @@ def get_hlo_sharding_string(
9190
)
9291

9392

94-
def get_shape_string(
93+
def get_shape_info(
9594
dtype: np.dtype,
96-
shape: Sequence[int],
97-
) -> str:
98-
"""Serializes the shape, encodes it to base64 and returns the base-64 as an utf-8 string."""
99-
return base64_utf8_stringify(
100-
xc.Shape.array_shape(
101-
xc.PrimitiveType(dtype_to_etype(dtype)),
102-
shape,
103-
)
104-
.with_major_to_minor_layout_if_absent()
105-
.to_serialized_proto()
106-
)
95+
dimensions: Sequence[int],
96+
) -> dict[str, Union[Sequence[int], str]]:
97+
"""Returns shape info in the format expected by read requests."""
98+
return {
99+
"xla_primitive_type_str": dtype_to_xla_primitive_type_str(dtype),
100+
"dimensions": dimensions,
101+
}
107102

108103

109104
def get_write_request(
@@ -188,7 +183,7 @@ def get_read_request(
188183
d = {
189184
"persistenceReadRequest": {
190185
"b64_location": string_to_base64(location_path),
191-
"b64_shape_proto_string": get_shape_string(dtype, shape),
186+
"shape": get_shape_info(dtype, shape),
192187
"b64_name": string_to_base64(name),
193188
"b64_hlo_sharding_string": get_hlo_sharding_string(
194189
sharding, len(shape)

0 commit comments

Comments
 (0)