@@ -1106,6 +1106,171 @@ def __init__(self, var1, var2, var3=None, **kwargs):
11061106 with self .assertRaises (NotImplementedError ):
11071107 config = layer .get_config ()
11081108
1109+ def test_call_context_args_with_custom_layers_propagates_args (self ):
1110+ class Inner (layers .Layer ):
1111+ def __init__ (self ):
1112+ super ().__init__ ()
1113+ self ._register_call_context_args ("foo_mode" )
1114+
1115+ def call (self , x , foo_mode = None ):
1116+ return x + (1 if foo_mode else 0 )
1117+
1118+ class Outer (layers .Layer ):
1119+ def __init__ (self ):
1120+ super ().__init__ ()
1121+ self ._register_call_context_args ("foo_mode" )
1122+ self .inner = Inner ()
1123+
1124+ def call (self , x ):
1125+ # Outer doesn’t even need to re‑inject explicitly:
1126+ # our base class will propagate foo_mode automatically
1127+ return self .inner (x )
1128+
1129+ layer = Outer ()
1130+ self .assertEqual (int (layer (np .array (0 ), foo_mode = True )), 1 )
1131+ self .assertEqual (int (layer (np .array (0 ))), 0 )
1132+
1133+ def test_register_call_context_arguments_success (self ):
1134+ """Validate that registering call-context args works as expected."""
1135+
1136+ class MyLayer (layers .Layer ):
1137+ def call (self , x ):
1138+ return x
1139+
1140+ layer = MyLayer ()
1141+
1142+ layer ._register_call_context_args ("foo_mode" )
1143+
1144+ self .assertCountEqual (
1145+ layer ._call_context_args , ("foo_mode" , "training" )
1146+ )
1147+
1148+ def test_register_call_context_arguments_after_call_raises_error (self ):
1149+ """Validate that registering call-context args after the layer has
1150+ been called raises an error."""
1151+
1152+ class MyLayer (layers .Layer ):
1153+ def call (self , x ):
1154+ return x
1155+
1156+ layer = MyLayer ()
1157+ layer (np .array (0 ))
1158+ with self .assertRaisesRegex (
1159+ RuntimeError ,
1160+ "Cannot add call-context args after the layer has been called." ,
1161+ ):
1162+ layer ._register_call_context_args ("foo_mode" )
1163+
1164+ def test_nested_context_args_follow_priority_order (self ):
1165+ """Validate that call-context args are propagated correctly
1166+ through multiple layers, and that the most specific value is used
1167+ when multiple values are passed down the call-stack.
1168+ """
1169+
1170+ class Inner (base_layer .Layer ):
1171+ def __init__ (self ):
1172+ super ().__init__ (name = "inner_layer" )
1173+ self ._register_call_context_args ("foo_mode" )
1174+
1175+ def call (self , inputs , foo_mode = None ):
1176+ return inputs + (1 if foo_mode else 0 )
1177+
1178+ class Middle (base_layer .Layer ):
1179+ def __init__ (self ):
1180+ super ().__init__ (name = "middle_layer" )
1181+ self ._inner_layer = Inner ()
1182+
1183+ def call (self , inputs ):
1184+ return self ._inner_layer (inputs )
1185+
1186+ class Outer (base_layer .Layer ):
1187+ def __init__ (self ):
1188+ super ().__init__ (name = "outer_layer" )
1189+ self ._middle = Middle ()
1190+
1191+ def call (self , inputs ):
1192+ return self ._middle (inputs )
1193+
1194+ layer = Outer ()
1195+ layer ._register_call_context_args ("foo_mode" )
1196+
1197+ # The value of foo_mode is set to True in the call to Outer,
1198+ # so it should automatically propagate to Inner through Middle.
1199+ self .assertEqual (int (layer (np .array (0 ), foo_mode = True )), 1 )
1200+ self .assertEqual (int (layer (np .array (0 ))), 0 )
1201+
1202+ def test_context_arg_propagation_without_declaration_does_not_resolve (self ):
1203+ """Validate that layer does not resolve a propagated arg if it is not
1204+ declared as a call-context arg in the layer itself."""
1205+
1206+ class Inner (layers .Layer ):
1207+ def call (self , x , foo_mode = None ):
1208+ return x + (1 if foo_mode else 0 )
1209+
1210+ class Wrapper (layers .Layer ):
1211+ def __init__ (self ):
1212+ super ().__init__ ()
1213+ self .inner = Inner ()
1214+
1215+ def call (self , x ):
1216+ return self .inner (x )
1217+
1218+ layer = Wrapper ()
1219+ layer ._register_call_context_args ("foo_mode" )
1220+
1221+ # The value of foo_mode is set to True in the call to Wrapper,
1222+ # However, it is not declared as a call-context arg in Inner,
1223+ # so it should not resolve to True inside Inner (and instead
1224+ # default to False).
1225+ self .assertEqual (int (layer (np .array (0 ), foo_mode = True )), 0 )
1226+
1227+ def test_call_context_args_with_models_as_layers_propagates_args (self ):
1228+ """Validate that call-context args are propagated correctly
1229+ through functional and sequential models when used as layers.
1230+ """
1231+
1232+ class InnerLayer (base_layer .Layer ):
1233+ def __init__ (self ):
1234+ super ().__init__ (name = "inner_layer" )
1235+ self ._register_call_context_args ("foo" )
1236+
1237+ def call (self , inputs , foo = None ):
1238+ if foo :
1239+ return inputs + 1.0
1240+ return inputs
1241+
1242+ class OuterLayer (base_layer .Layer ):
1243+ def __init__ (self ):
1244+ super ().__init__ (name = "outer_layer" )
1245+ self ._inner_layer = InnerLayer ()
1246+
1247+ def call (self , inputs ):
1248+ return self ._inner_layer (inputs )
1249+
1250+ sample_input = tf .constant ([[1.0 , 2.0 ], [3.0 , 4.0 ]], dtype = "float32" )
1251+
1252+ # Sequential model
1253+ seq = sequential .Sequential ([OuterLayer ()])
1254+ seq ._register_call_context_args ("foo" )
1255+
1256+ out_true = seq (sample_input , foo = True )
1257+ self .assertAllEqual (out_true , sample_input + 1.0 )
1258+
1259+ out_false = seq (sample_input , foo = False )
1260+ self .assertAllEqual (out_false , sample_input )
1261+
1262+ # Functional model
1263+ inp = input_layer .Input ((2 ,))
1264+ outer = OuterLayer ()(inp )
1265+ model = training_lib .Model (inputs = [inp ], outputs = [outer ])
1266+ model ._register_call_context_args ("foo" )
1267+
1268+ out_true = model (sample_input , foo = True )
1269+ self .assertAllEqual (out_true , sample_input + 1.0 )
1270+
1271+ out_false = model (sample_input , foo = False )
1272+ self .assertAllEqual (out_false , sample_input )
1273+
11091274
11101275@test_utils .run_v2_only
11111276class SymbolicSupportTest (test_combinations .TestCase ):
0 commit comments