diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py index e7c63d81075..830249ffb10 100644 --- a/src/accelerate/utils/modeling.py +++ b/src/accelerate/utils/modeling.py @@ -404,7 +404,7 @@ def set_module_tensor_to_device( module.weight = module.weight.cuda(device_index) # clean pre and post forward hook - if clear_cache and device != "cpu": + if clear_cache and device not in ("cpu", "meta"): clear_device_cache() # When handling tied weights, we update tied_params_map to keep track of the tied weights that have already been allocated on the device in