Skip to content

Commit 4c4f8d1

Browse files
authored
[LLM]Fix Arc falcon abnormal output issue (intel#9096)
* update * update * fix error & style * fix style * update train * to input_seq_size
1 parent 548e4dd commit 4c4f8d1

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

python/llm/src/bigdl/llm/transformers/low_bit_linear.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -288,10 +288,10 @@ def ggml_matmul_src1_x_src0_t(src0: torch.Tensor,
288288
class MatMulLowBit(torch.autograd.Function):
289289

290290
@staticmethod
291-
def forward(ctx, A, weight):
291+
def forward(ctx, A, weight, input_seq_size):
292292
ctx.is_empty = False
293293
import linear_q4_0
294-
result = linear_q4_0.forward_new(A, weight.data, weight.qtype)
294+
result = linear_q4_0.forward_new(A, weight.data, weight.qtype, input_seq_size)
295295
if any(ctx.needs_input_grad[:2]):
296296
ctx.tensors = (A, weight)
297297
else:
@@ -304,14 +304,14 @@ def backward(ctx, grad_output):
304304
if ctx.is_empty:
305305
bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
306306
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
307-
req_gradA, _ = ctx.needs_input_grad
307+
req_gradA, _, _ = ctx.needs_input_grad
308308
A, weight = ctx.tensors
309309
grad_A, grad_weight = None, None
310310
if req_gradA:
311311
dequant_weight = linear_q4_0.dequant(A, weight.data, weight.qtype)
312312
grad_A = torch.matmul(grad_output, dequant_weight.reshape(weight._shape))
313313

314-
return grad_A, grad_weight
314+
return grad_A, grad_weight, None
315315

316316

317317
class LowBitLinear(nn.Linear):
@@ -353,10 +353,12 @@ def forward(self, x: torch.Tensor):
353353
# disable the conversion when training
354354
if self.conver_to_half and x_2d.shape[0] > 1 and x_2d.dtype == torch.float32:
355355
x_2d = x_2d.half()
356+
input_seq_size = x_shape[1]
356357
if self.training and x_2d.requires_grad:
357-
result = MatMulLowBit.apply(x_2d, self.weight)
358+
result = MatMulLowBit.apply(x_2d, self.weight, input_seq_size)
358359
else:
359-
result = linear_q4_0.forward_new(x_2d, self.weight.data, self.weight.qtype)
360+
result = linear_q4_0.forward_new(x_2d, self.weight.data, self.weight.qtype,
361+
input_seq_size)
360362
new_shape = x_shape[:-1] + (self.out_len,)
361363
result = result.view(new_shape)
362364
if self.bias is not None:

0 commit comments

Comments
 (0)