Skip to content

Commit 5691b4d

Browse files
committed
format
1 parent c4df618 commit 5691b4d

File tree

1 file changed

+17
-14
lines changed

1 file changed

+17
-14
lines changed

keras_rs/src/utils/tpu_test_utils.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11
import contextlib
22
import os
33

4+
import Any
5+
import Callable
46
import keras
7+
import ModuleType
8+
import Optional
59
import tensorflow as tf
10+
import Tuple
11+
import Union
612

713
jax: Optional[ModuleType] = None
814

@@ -35,8 +41,10 @@ def num_replicas_in_sync(self):
3541
return 0
3642
return jax.device_count("tpu")
3743

44+
3845
StrategyType = Union[tf.distribute.Strategy, DummyStrategy, JaxDummyStrategy]
3946

47+
4048
def get_tpu_strategy(test_case: Any) -> StrategyType:
4149
"""Get TPU strategy if on TPU, otherwise return DummyStrategy."""
4250
if "TPU_NAME" not in os.environ:
@@ -71,35 +79,30 @@ def run_with_strategy(
7179
fn: Callable[..., Any],
7280
*args: Any,
7381
jit_compile: bool = False,
74-
**kwargs: Any
82+
**kwargs: Any,
7583
) -> Any:
7684
"""
77-
Final wrapper fix: Flattens allowed kwargs into positional args before
85+
Final wrapper fix: Flattens allowed kwargs into positional args before
7886
entering tf.function to guarantee a fixed graph signature.
7987
"""
8088
if keras.backend.backend() == "tensorflow":
81-
# Extract sample_weight and treat it as an explicit third positional argument.
82-
# If not present, use a placeholder (None).
83-
sample_weight_value = kwargs.get('sample_weight', None)
89+
# Extract sample_weight and treat it as an explicit third positional
90+
# argument. If not present, use a placeholder (None).
91+
sample_weight_value = kwargs.get("sample_weight", None)
8492
all_inputs = args + (sample_weight_value,)
8593

8694
@tf.function(jit_compile=jit_compile)
8795
def tf_function_wrapper(input_tuple: Tuple[Any, ...]) -> Any:
8896
num_original_args = len(args)
8997
core_args = input_tuple[:num_original_args]
9098
sw_value = input_tuple[-1]
91-
99+
92100
if sw_value is not None:
93101
all_positional_args = core_args + (sw_value,)
94-
return strategy.run(
95-
fn,
96-
args=all_positional_args
97-
)
102+
return strategy.run(fn, args=all_positional_args)
98103
else:
99-
return strategy.run(
100-
fn,
101-
args=core_args
102-
)
104+
return strategy.run(fn, args=core_args)
105+
103106
return tf_function_wrapper(all_inputs)
104107
else:
105108
assert not jit_compile

0 commit comments

Comments
 (0)