diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index 60c8797176e8..a91428bdbe71 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -85,9 +85,9 @@ def use_kernel_forward_from_hub(layer_name: str): ) }, "npu": { - Mode.INFERENCE: LayerRepository( - repo_id="kernels-community/liger_kernels", - layer_name="LigerRMSNorm", + Mode.TRAINING: LayerRepository( + repo_id="kernels-ext-npu/rmsnorm", + layer_name="rmsnorm", ) }, },