Skip to content

Commit 58dcc97

Browse files
authored
ops: limit return of requants (Comfy-Org#12506)
This check was far too broad and the dtype is not a reliable indicator of wanting the requant (as QT returns the compute dtype as the dtype). So explicitly plumb whether fp8mm wants the requant or not.
1 parent 19236ed commit 58dcc97

1 file changed

Lines changed: 8 additions & 9 deletions

File tree

comfy/ops.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
7979
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
8080

8181

82-
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype):
82+
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
8383
offload_stream = None
8484
xfer_dest = None
8585

@@ -170,10 +170,10 @@ def to_dequant(tensor, dtype):
170170
#FIXME: this is not accurate, we need to be sensitive to the compute dtype
171171
x = lowvram_fn(x)
172172
if (isinstance(orig, QuantizedTensor) and
173-
(orig.dtype == dtype and len(fns) == 0 or update_weight)):
173+
(want_requant and len(fns) == 0 or update_weight)):
174174
seed = comfy.utils.string_to_seed(s.seed_key)
175175
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
176-
if orig.dtype == dtype and len(fns) == 0:
176+
if want_requant and len(fns) == 0:
177177
#The layer actually wants our freshly saved QT
178178
x = y
179179
elif update_weight:
@@ -194,7 +194,7 @@ def to_dequant(tensor, dtype):
194194
return weight, bias, (offload_stream, device if signature is not None else None, None)
195195

196196

197-
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None):
197+
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False):
198198
# NOTE: offloadable=False is a legacy and if you are a custom node author reading this please pass
199199
# offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
200200
# will add async-offload support to your cast and improve performance.
@@ -212,7 +212,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
212212
non_blocking = comfy.model_management.device_supports_non_blocking(device)
213213

214214
if hasattr(s, "_v"):
215-
return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype)
215+
return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)
216216

217217
if offloadable and (device != s.weight.device or
218218
(s.bias is not None and device != s.bias.device)):
@@ -850,8 +850,8 @@ def state_dict(self, *args, destination=None, prefix="", **kwargs):
850850
def _forward(self, input, weight, bias):
851851
return torch.nn.functional.linear(input, weight, bias)
852852

853-
def forward_comfy_cast_weights(self, input, compute_dtype=None):
854-
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype)
853+
def forward_comfy_cast_weights(self, input, compute_dtype=None, want_requant=False):
854+
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype, want_requant=want_requant)
855855
x = self._forward(input, weight, bias)
856856
uncast_bias_weight(self, weight, bias, offload_stream)
857857
return x
@@ -881,8 +881,7 @@ def forward(self, input, *args, **kwargs):
881881
scale = comfy.model_management.cast_to_device(scale, input.device, None)
882882
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
883883

884-
885-
output = self.forward_comfy_cast_weights(input, compute_dtype)
884+
output = self.forward_comfy_cast_weights(input, compute_dtype, want_requant=isinstance(input, QuantizedTensor))
886885

887886
# Reshape output back to 3D if input was 3D
888887
if reshaped_3d:

0 commit comments

Comments
 (0)