|
1 | 1 | import contextlib |
2 | 2 | import os |
3 | 3 |
|
| 4 | +import Any |
| 5 | +import Callable |
4 | 6 | import keras |
| 7 | +import ModuleType |
| 8 | +import Optional |
5 | 9 | import tensorflow as tf |
| 10 | +import Tuple |
| 11 | +import Union |
6 | 12 |
|
# Module-level handle for the optional `jax` module; it starts out as None.
# NOTE(review): presumably rebound to the imported module when JAX is
# available -- confirm against the code that follows this assignment.
7 | 13 | jax: Optional[ModuleType] = None |
8 | 14 |
|
@@ -35,8 +41,10 @@ def num_replicas_in_sync(self): |
35 | 41 | return 0 |
36 | 42 | return jax.device_count("tpu") |
37 | 43 |
|
| 44 | + |
# Type alias accepted wherever a distribution strategy is expected: a real
# tf.distribute.Strategy or one of the two dummy strategy classes. It is the
# return annotation of get_tpu_strategy.
38 | 45 | StrategyType = Union[tf.distribute.Strategy, DummyStrategy, JaxDummyStrategy] |
39 | 46 |
|
| 47 | + |
40 | 48 | def get_tpu_strategy(test_case: Any) -> StrategyType: |
41 | 49 | """Get TPU strategy if on TPU, otherwise return DummyStrategy.""" |
42 | 50 | if "TPU_NAME" not in os.environ: |
@@ -71,35 +79,30 @@ def run_with_strategy( |
71 | 79 | fn: Callable[..., Any], |
72 | 80 | *args: Any, |
73 | 81 | jit_compile: bool = False, |
74 | | - **kwargs: Any |
| 82 | + **kwargs: Any, |
75 | 83 | ) -> Any: |
76 | 84 | """ |
77 | | - Final wrapper fix: Flattens allowed kwargs into positional args before |
| 85 | + Final wrapper fix: Flattens allowed kwargs into positional args before |
78 | 86 | entering tf.function to guarantee a fixed graph signature. |
79 | 87 | """ |
80 | 88 | if keras.backend.backend() == "tensorflow": |
81 | | - # Extract sample_weight and treat it as an explicit third positional argument. |
82 | | - # If not present, use a placeholder (None). |
83 | | - sample_weight_value = kwargs.get('sample_weight', None) |
| 89 | + # Extract sample_weight and treat it as an explicit third positional |
| 90 | + # argument. If not present, use a placeholder (None). |
| 91 | + sample_weight_value = kwargs.get("sample_weight", None) |
84 | 92 | all_inputs = args + (sample_weight_value,) |
85 | 93 |
|
86 | 94 | @tf.function(jit_compile=jit_compile) |
87 | 95 | def tf_function_wrapper(input_tuple: Tuple[Any, ...]) -> Any: |
88 | 96 | num_original_args = len(args) |
89 | 97 | core_args = input_tuple[:num_original_args] |
90 | 98 | sw_value = input_tuple[-1] |
91 | | - |
| 99 | + |
92 | 100 | if sw_value is not None: |
93 | 101 | all_positional_args = core_args + (sw_value,) |
94 | | - return strategy.run( |
95 | | - fn, |
96 | | - args=all_positional_args |
97 | | - ) |
| 102 | + return strategy.run(fn, args=all_positional_args) |
98 | 103 | else: |
99 | | - return strategy.run( |
100 | | - fn, |
101 | | - args=core_args |
102 | | - ) |
| 104 | + return strategy.run(fn, args=core_args) |
| 105 | + |
103 | 106 | return tf_function_wrapper(all_inputs) |
104 | 107 | else: |
105 | 108 | assert not jit_compile |
|
0 commit comments