@@ -79,7 +79,7 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
7979 return comfy .model_management .cast_to (weight , input .dtype , input .device , non_blocking = non_blocking , copy = copy )
8080
8181
82- def cast_bias_weight_with_vbar (s , dtype , device , bias_dtype , non_blocking , compute_dtype ):
82+ def cast_bias_weight_with_vbar (s , dtype , device , bias_dtype , non_blocking , compute_dtype , want_requant ):
8383 offload_stream = None
8484 xfer_dest = None
8585
@@ -170,10 +170,10 @@ def to_dequant(tensor, dtype):
170170 #FIXME: this is not accurate, we need to be sensitive to the compute dtype
171171 x = lowvram_fn (x )
172172 if (isinstance (orig , QuantizedTensor ) and
173- (orig . dtype == dtype and len (fns ) == 0 or update_weight )):
173+ (want_requant and len (fns ) == 0 or update_weight )):
174174 seed = comfy .utils .string_to_seed (s .seed_key )
175175 y = QuantizedTensor .from_float (x , s .layout_type , scale = "recalculate" , stochastic_rounding = seed )
176- if orig . dtype == dtype and len (fns ) == 0 :
176+ if want_requant and len (fns ) == 0 :
177177 #The layer actually wants our freshly saved QT
178178 x = y
179179 elif update_weight :
@@ -194,7 +194,7 @@ def to_dequant(tensor, dtype):
194194 return weight , bias , (offload_stream , device if signature is not None else None , None )
195195
196196
197- def cast_bias_weight (s , input = None , dtype = None , device = None , bias_dtype = None , offloadable = False , compute_dtype = None ):
197+ def cast_bias_weight (s , input = None , dtype = None , device = None , bias_dtype = None , offloadable = False , compute_dtype = None , want_requant = False ):
198198 # NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass
199199 # offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
200200 # will add async-offload support to your cast and improve performance.
@@ -212,7 +212,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
212212 non_blocking = comfy .model_management .device_supports_non_blocking (device )
213213
214214 if hasattr (s , "_v" ):
215- return cast_bias_weight_with_vbar (s , dtype , device , bias_dtype , non_blocking , compute_dtype )
215+ return cast_bias_weight_with_vbar (s , dtype , device , bias_dtype , non_blocking , compute_dtype , want_requant )
216216
217217 if offloadable and (device != s .weight .device or
218218 (s .bias is not None and device != s .bias .device )):
@@ -850,8 +850,8 @@ def state_dict(self, *args, destination=None, prefix="", **kwargs):
850850 def _forward (self , input , weight , bias ):
851851 return torch .nn .functional .linear (input , weight , bias )
852852
853- def forward_comfy_cast_weights (self , input , compute_dtype = None ):
854- weight , bias , offload_stream = cast_bias_weight (self , input , offloadable = True , compute_dtype = compute_dtype )
853+ def forward_comfy_cast_weights (self , input , compute_dtype = None , want_requant = False ):
854+ weight , bias , offload_stream = cast_bias_weight (self , input , offloadable = True , compute_dtype = compute_dtype , want_requant = want_requant )
855855 x = self ._forward (input , weight , bias )
856856 uncast_bias_weight (self , weight , bias , offload_stream )
857857 return x
@@ -881,8 +881,7 @@ def forward(self, input, *args, **kwargs):
881881 scale = comfy .model_management .cast_to_device (scale , input .device , None )
882882 input = QuantizedTensor .from_float (input_reshaped , self .layout_type , scale = scale )
883883
884-
885- output = self .forward_comfy_cast_weights (input , compute_dtype )
884+ output = self .forward_comfy_cast_weights (input , compute_dtype , want_requant = isinstance (input , QuantizedTensor ))
886885
887886 # Reshape output back to 3D if input was 3D
888887 if reshaped_3d :
0 commit comments