@@ -288,10 +288,10 @@ def ggml_matmul_src1_x_src0_t(src0: torch.Tensor,
class MatMulLowBit(torch.autograd.Function):

    @staticmethod
-    def forward(ctx, A, weight):
+    def forward(ctx, A, weight, input_seq_size):
        ctx.is_empty = False
        import linear_q4_0
-        result = linear_q4_0.forward_new(A, weight.data, weight.qtype)
+        result = linear_q4_0.forward_new(A, weight.data, weight.qtype, input_seq_size)
        if any(ctx.needs_input_grad[:2]):
            ctx.tensors = (A, weight)
        else:
@@ -304,14 +304,14 @@ def backward(ctx, grad_output):
        if ctx.is_empty:
            bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
-        req_gradA, _ = ctx.needs_input_grad
+        req_gradA, _, _ = ctx.needs_input_grad
        A, weight = ctx.tensors
        grad_A, grad_weight = None, None
        if req_gradA:
            dequant_weight = linear_q4_0.dequant(A, weight.data, weight.qtype)
            grad_A = torch.matmul(grad_output, dequant_weight.reshape(weight._shape))

-        return grad_A, grad_weight
+        return grad_A, grad_weight, None


class LowBitLinear(nn.Linear):
@@ -353,10 +353,12 @@ def forward(self, x: torch.Tensor):
            # disable the conversion when training
            if self.conver_to_half and x_2d.shape[0] > 1 and x_2d.dtype == torch.float32:
                x_2d = x_2d.half()
+            input_seq_size = x_shape[1]
            if self.training and x_2d.requires_grad:
-                result = MatMulLowBit.apply(x_2d, self.weight)
+                result = MatMulLowBit.apply(x_2d, self.weight, input_seq_size)
            else:
-                result = linear_q4_0.forward_new(x_2d, self.weight.data, self.weight.qtype)
+                result = linear_q4_0.forward_new(x_2d, self.weight.data, self.weight.qtype,
+                                                 input_seq_size)
            new_shape = x_shape[:-1] + (self.out_len,)
            result = result.view(new_shape)
            if self.bias is not None:
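The backward-related edits above follow the standard torch.autograd.Function contract: once forward() takes an extra non-tensor argument (here input_seq_size), ctx.needs_input_grad has three entries and backward() must return one gradient slot per forward input, with None for the non-differentiable argument. A minimal self-contained sketch of that pattern, using a plain matmul as a stand-in for the project's compiled linear_q4_0 extension (the class and variable names below are illustrative only):

    import torch

    class ScaledMatMulSketch(torch.autograd.Function):
        @staticmethod
        def forward(ctx, A, weight, input_seq_size):
            # input_seq_size is a plain int, not a tensor, so it needs no gradient
            ctx.save_for_backward(weight)
            return A @ weight

        @staticmethod
        def backward(ctx, grad_output):
            (weight,) = ctx.saved_tensors
            req_gradA, _, _ = ctx.needs_input_grad   # one flag per forward input
            grad_A = grad_output @ weight.t() if req_gradA else None
            # one return value per forward input; trailing None covers the int argument
            return grad_A, None, None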