|
21 | 21 |
|
22 | 22 | import jax
|
23 | 23 | from jax import core
|
24 |
| -from jax.lib import xla_client as xc |
25 | 24 | import numpy as np
|
26 | 25 | from pathwaysutils import plugin_executable
|
27 | 26 |
|
28 | 27 |
|
29 |
| -def dtype_to_etype(dtype: np.dtype) -> xc.PrimitiveType: |
| 28 | +def dtype_to_xla_primitive_type_str(dtype: np.dtype) -> str: |
30 | 29 | """Converts a numpy dtype to an xla PrimitiveType."""
|
31 | 30 | if dtype == np.dtype("bfloat16"):
|
32 |
| - return xc.PrimitiveType.BF16 |
| 31 | + return "BF16" |
33 | 32 | elif dtype == np.dtype("float32"):
|
34 |
| - return xc.PrimitiveType.F32 |
| 33 | + return "F32" |
35 | 34 | elif dtype == np.dtype("float64"):
|
36 |
| - return xc.PrimitiveType.F64 |
| 35 | + return "F64" |
37 | 36 | elif dtype == np.dtype("int8"):
|
38 |
| - return xc.PrimitiveType.S8 |
| 37 | + return "S8" |
39 | 38 | elif dtype == np.dtype("int16"):
|
40 |
| - return xc.PrimitiveType.S16 |
| 39 | + return "S16" |
41 | 40 | elif dtype == np.dtype("int32"):
|
42 |
| - return xc.PrimitiveType.S32 |
| 41 | + return "S32" |
43 | 42 | elif dtype == np.dtype("int64"):
|
44 |
| - return xc.PrimitiveType.S64 |
| 43 | + return "S64" |
45 | 44 | elif dtype == np.dtype("uint8"):
|
46 |
| - return xc.PrimitiveType.U8 |
| 45 | + return "U8" |
47 | 46 | elif dtype == np.dtype("uint16"):
|
48 |
| - return xc.PrimitiveType.U16 |
| 47 | + return "U16" |
49 | 48 | elif dtype == np.dtype("uint32"):
|
50 |
| - return xc.PrimitiveType.U32 |
| 49 | + return "U32" |
51 | 50 | elif dtype == np.dtype("uint64"):
|
52 |
| - return xc.PrimitiveType.U64 |
| 51 | + return "U64" |
53 | 52 | else:
|
54 | 53 | raise ValueError(f"Unsupported dtype: {dtype}")
|
55 | 54 |
|
@@ -91,19 +90,15 @@ def get_hlo_sharding_string(
|
91 | 90 | )
|
92 | 91 |
|
93 | 92 |
|
94 |
| -def get_shape_string( |
| 93 | +def get_shape_info( |
95 | 94 | dtype: np.dtype,
|
96 |
| - shape: Sequence[int], |
97 |
| -) -> str: |
98 |
| - """Serializes the shape, encodes it to base64 and returns the base-64 as an utf-8 string.""" |
99 |
| - return base64_utf8_stringify( |
100 |
| - xc.Shape.array_shape( |
101 |
| - xc.PrimitiveType(dtype_to_etype(dtype)), |
102 |
| - shape, |
103 |
| - ) |
104 |
| - .with_major_to_minor_layout_if_absent() |
105 |
| - .to_serialized_proto() |
106 |
| - ) |
| 95 | + dimensions: Sequence[int], |
| 96 | +) -> dict[str, Union[Sequence[int], str]]: |
| 97 | + """Returns shape info in the format expected by read requests.""" |
| 98 | + return { |
| 99 | + "xla_primitive_type_str": dtype_to_xla_primitive_type_str(dtype), |
| 100 | + "dimensions": dimensions, |
| 101 | + } |
107 | 102 |
|
108 | 103 |
|
109 | 104 | def get_write_request(
|
@@ -188,7 +183,7 @@ def get_read_request(
|
188 | 183 | d = {
|
189 | 184 | "persistenceReadRequest": {
|
190 | 185 | "b64_location": string_to_base64(location_path),
|
191 |
| - "b64_shape_proto_string": get_shape_string(dtype, shape), |
| 186 | + "shape": get_shape_info(dtype, shape), |
192 | 187 | "b64_name": string_to_base64(name),
|
193 | 188 | "b64_hlo_sharding_string": get_hlo_sharding_string(
|
194 | 189 | sharding, len(shape)
|
|
0 commit comments