     Dict,
     Optional,
     Protocol,
+    Set,
     Tuple,
     Type,
     Union,
@@ -868,9 +869,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         raise ValueError("kernelize mode must contain Mode.INFERENCE or Mode.TRAINING.")
 
     if device is None:
-        device_type = _find_device(model)
+        device = _find_device(model)
+        device_type = _find_device_type(model)
     elif isinstance(device, str):
         _validate_device_type(device)
+        import torch
+
+        device = torch.device(device)
         device_type = Device(type=device)
     else:
         device_type = Device(device.type)
@@ -884,7 +889,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         layer_name = module_class.kernel_layer_name
 
         if _DISABLE_KERNEL_MAPPING:
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue
 
         kernel = _KERNEL_MAPPING.get().get(str(layer_name))
@@ -898,7 +903,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             )
             if not use_fallback:
                 raise ValueError(f"No layer mapping for `{layer_name}`")
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue
 
         # Get kernel options for the device
@@ -909,7 +914,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 raise ValueError(
                     f"No layer mapping for `{layer_name}` with device type `{device_type}`"
                 )
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue
 
         repos = property_repos.repos
@@ -919,7 +924,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 raise ValueError(
                     f"No layer mapping for `{layer_name}` device `{device_type}` with the right properties"
                 )
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue
 
         repo_with_mode = _select_repository(
@@ -932,7 +937,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 raise ValueError(
                     f"No repository for `{layer_name}` for configuration mode={mode}"
                 )
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue
 
         repo, repo_mode = repo_with_mode
@@ -951,6 +956,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         )
 
         _conditionally_replace_forward(
+            device=device,
             module=module,
             layer=layer,
             mode=mode,
@@ -1037,19 +1043,26 @@ def _validate_layer(*, check_cls, cls, repo: LayerRepositoryProtocol):
         raise TypeError(f"{repo} must not override nn.Module constructor.")
 
     # ... or predefined member variables.
-    torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
-    cls_members = {name for name, _ in inspect.getmembers(cls)}
-    difference = cls_members - torch_module_members
+    unique_members = _unique_layer_members(cls)
-    # verify if : difference ⊄ {"can_torch_compile", "has_backward"}
+    # verify: unique_members ⊆ {"can_torch_compile", "create_state", "has_backward", "forward_with_state"}
-    if not difference <= {"can_torch_compile", "has_backward"}:
+    if not unique_members <= {
+        "can_torch_compile",
+        "create_state",
+        "has_backward",
+        "forward_with_state",
+    }:
         raise TypeError(
             f"{repo} must not contain additional members compared to `{check_cls.__name__}`."
         )
 
     # Check whether the forward signatures are similar.
-    params = inspect.signature(cls.forward).parameters
     ref_params = inspect.signature(check_cls.forward).parameters
 
+    if _is_stateful_layer(cls):
+        params = inspect.signature(cls.forward_with_state).parameters
+    else:
+        params = inspect.signature(cls.forward).parameters
+
     if len(params) != len(ref_params):
         raise TypeError(
             f"Forward signature of {repo} does not match `{check_cls.__name__}`: different number of arguments."
@@ -1074,15 +1087,21 @@ def _is_rocm_platform():
     return torch.version.hip is not None
 
 
-def _find_device(model: "nn.Module") -> Device:
+def _find_device(model: "nn.Module") -> "torch.device":
     try:
         param = next(model.parameters())
     except StopIteration:
         raise ValueError(
             "Cannot determine model device, provide as `device` argument to `kernelize`."
         )
 
-    dev_type = param.device.type
+    return param.device
+
+
+def _find_device_type(model: "nn.Module") -> Device:
+    device = _find_device(model)
+
+    dev_type = device.type
     if dev_type == "cuda":
         # Refine based on actual platform
         if _is_rocm_platform():
@@ -1103,6 +1122,7 @@ def _find_capability() -> int:
 
 def _conditionally_replace_forward(
     *,
+    device: "torch.device",
     module: "nn.Module",
     layer: Type["nn.Module"],
     mode: Mode,
@@ -1128,15 +1148,28 @@ def _conditionally_replace_forward(
                 logging.info("Layer does not support torch.compile, using fallback")
             if needs_fallback_for_backward:
                 logging.info("Layer does not support backward, using fallback")
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
         else:
             raise ValueError(f"Available kernel does not support mode: {mode}")
     else:
-        _replace_forward(module, layer)
+        _replace_forward(device, module, layer)
 
 
-def _replace_forward(module: "nn.Module", layer: Type["nn.Module"]):
-    module.forward = MethodType(layer.forward, module)  # type: ignore[method-assign]
+def _replace_forward(
+    device: "torch.device", module: "nn.Module", layer: Type["nn.Module"]
+):
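+    # For stateful layers, the kernel state is built once at kernelize time
+    # and captured by the replacement forward (see `_is_stateful_layer`).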
+    if _is_stateful_layer(layer):
+        state = layer.create_state(module, device)
+
+        def forward(self, *args, **kwargs):
+            return layer.forward_with_state(self, state, *args, **kwargs)
+
+        # Bind the wrapper so `self` is the module instance when it is called.
+        module.forward = MethodType(forward, module)  # type: ignore[method-assign]
+    else:
+        module.forward = MethodType(layer.forward, module)  # type: ignore[method-assign]
 
 
 def _validate_layer_has_mode(
@@ -1179,3 +1212,28 @@ def _get_layer_memoize(
         _CACHED_LAYER[repo] = layer
 
     return layer
+
+
+def _unique_layer_members(layer: Type["nn.Module"]) -> Set[str]:
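+    # Members that the layer defines itself, i.e. not inherited from nn.Module.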
+    import torch.nn as nn
+
+    torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
+    cls_members = {name for name, _ in inspect.getmembers(layer)}
+    return cls_members - torch_module_members
+
+
+def _is_stateful_layer(layer: Type["nn.Module"]) -> bool:
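+    # A stateful layer pairs `create_state` with `forward_with_state`;
+    # kernelize builds the state once and threads it into every call.
+    # Hypothetical sketch of such a layer:
+    #     class LayerWithState(nn.Module):
+    #         def create_state(self, device): ...           # self: the target module
+    #         def forward_with_state(self, state, x): ...   # state: from create_state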
+    unique = _unique_layer_members(layer)
+    is_stateful = "forward_with_state" in unique
+    if is_stateful and len(unique & {"create_state", "forward_with_state"}) != 2:
+        raise TypeError(
+            f"Stateful layer `{layer.__name__}` must implement both `create_state` and `forward_with_state` or neither."
+        )
+    return is_stateful