
Commit 6bddb1c

JackWeiw authored, with yao-fengchen

[Dlinfer][Ascend] Optimize performance of 310P device (#3486)

* support 310P
* format code
* fix accuracy of eager mode
* update code
* [dlinfer]fix tp for Ascend310P device
* [dlinfer][ascend]lazy import torch_npu
* [ascend]use safe device check
* lint
* lint
* [dlinfer][ascend]convert linear weight to NZ at initial time
* [ascend]fix tp2 lm compile transdata
* [ascend]set transdata linear weight by default
* [dlinfer][ascend]fix Transdata linear weight device check

---------

Co-authored-by: yaofengchen <[email protected]>
Co-authored-by: JackWeiw <[email protected]>

1 parent 8e0c15d · commit 6bddb1c

File tree

3 files changed: +50 −9 lines

lmdeploy/pytorch/backends/dlinfer/ascend/graph_runner.py (6 additions, 0 deletions)

```diff
@@ -49,6 +49,12 @@ def __init__(self, model: torch.nn.Module, model_config: ModelConfig, cache_conf
                                             backend='atbgraph')
         else:
             self.model = torch.compile(self.model, fullgraph=True, dynamic=True, backend='atbgraph')
+            if SocVersion.is_Ascend310P() and hasattr(self.model, 'get_logits'):
+                # Compile get_logits for Ascend310P to use ATB linear since we would convert weight to NZ format
+                self.model.get_logits = torch.compile(self.model.get_logits,
+                                                      fullgraph=True,
+                                                      dynamic=True,
+                                                      backend='atbgraph')
 
     def check_enable_graph(self):
         """check enable graph."""
```

lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py (35 additions, 8 deletions)

```diff
@@ -91,6 +91,8 @@ class AscendOpsBackend(DlinferOpsBackend):
     enable_graph = False
     half_negative_inf = torch.finfo(torch.float16).min
     total_slots = None
+    # Compiled ATB Transdata operation to convert tensors from ACL_FORMAT_ND to ACL_FORMAT_FRACTAL_NZ format.
+    transdata_func = None
 
     @staticmethod
     def get_name() -> str:
@@ -216,11 +218,13 @@ def get_total_slots():
                     single_attention_mask = torch.triu(single_attention_mask, diagonal=1)
                     attention_mask.append(single_attention_mask)
                 else:
+                    # Transdata needs dtype to be float16 or int8
                     single_attention_mask = torch.triu(
-                        torch.ones(max_q_seq_len, max_kv_seq_len).fill_(-float('inf')).cuda(),
+                        torch.ones(max_q_seq_len, max_kv_seq_len, dtype=torch.float16).fill_(-float('inf')).cuda(),
                         diagonal=max_kv_seq_len - max_q_seq_len + 1,
                     )
-                    attention_mask.append(single_attention_mask)
+                    # Convert to NZ format
+                    attention_mask.append(cls.get_transdata_func()(single_attention_mask, 2))
             else:
                 raise ValueError(f"dlinfer doesn't support {SocVersion.device_name()} device currently.")
         else:
```
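On the 310P branch the additive causal mask is built directly in float16 because the ATB Transdata op only accepts float16 or int8 inputs. A CPU-only sketch of just the mask construction, with toy lengths and the `.cuda()` call and NZ conversion dropped:

```python
import torch

max_q_seq_len, max_kv_seq_len = 4, 6

# Fill with -inf, then keep only the strictly-future positions above the band;
# everything else becomes 0, i.e. "allowed". float16 keeps it Transdata-legal.
mask = torch.triu(
    torch.ones(max_q_seq_len, max_kv_seq_len, dtype=torch.float16).fill_(-float('inf')),
    diagonal=max_kv_seq_len - max_q_seq_len + 1,
)
print(mask)
# tensor([[0., 0., 0., -inf, -inf, -inf],
#         [0., 0., 0., 0., -inf, -inf],
#         [0., 0., 0., 0., 0., -inf],
#         [0., 0., 0., 0., 0., 0.]], dtype=torch.float16)
```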
```diff
@@ -240,13 +244,21 @@ def get_total_slots():
         kv_seqlens = step_context.kv_seqlens.to(torch.int32)
         if not step_context.is_decoding:
             if is_unpaged_prefill:
-                attention_mask = [mask.half() for mask in attention_mask]
-                if SocVersion.is_Ascend310P():
-                    attention_mask = [torch.cat([mask.unsqueeze(0) for mask in attention_mask])]
+                if SocVersion.is_Ascend910B():
+                    attention_mask = [mask.half() for mask in attention_mask]
             else:
-                attention_mask = [
-                    torch.cat([mask.half() * cls.half_negative_inf for mask in attention_mask]).unsqueeze(1)
-                ]
+                if SocVersion.is_Ascend910B():
+                    attention_mask = [
+                        torch.cat([mask.half() * cls.half_negative_inf for mask in attention_mask]).unsqueeze(1)
+                    ]
+                elif SocVersion.is_Ascend310P():
+                    # Convert mask to NZ format.
+                    attention_mask = [
+                        cls.get_transdata_func()(torch.cat(
+                            [mask.half() * cls.half_negative_inf for mask in attention_mask]).unsqueeze(1), 2)
+                    ]
+                else:
+                    raise ValueError(f"dlinfer doesn't support {SocVersion.device_name()} device currently.")
             kv_seqlens = kv_seqlens.repeat_interleave(step_context.q_seqlens, 0)
         else:
             if step_context.is_decoding:
```
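`cls.half_negative_inf` is `torch.finfo(torch.float16).min`, so multiplying a 0/1 mask by it turns the boolean masks into additive attention biases before they are concatenated and, on 310P, passed through Transdata. A stand-alone sketch of that scaling and batching step, with toy masks:

```python
import torch

half_negative_inf = torch.finfo(torch.float16).min  # ~ -65504, float16's "minus infinity"

# Toy per-sequence masks: 1 marks a position the query must not attend to.
masks = [torch.triu(torch.ones(3, 3), diagonal=1) for _ in range(2)]

# 0 stays 0 (allowed); 1 becomes a huge negative bias (masked). Concatenate the
# sequences along dim 0 and insert a singleton head dim, as the 910B branch does.
attention_mask = torch.cat([m.half() * half_negative_inf for m in masks]).unsqueeze(1)
print(attention_mask.shape)  # torch.Size([6, 1, 3])
```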
```diff
@@ -302,6 +314,21 @@ def build_graph_runner(model: torch.nn.Module, model_config: ModelConfig, cache_
         AscendOpsBackend.enable_graph = ascend_graph_runner.enable_graph
         return ascend_graph_runner
 
+    @staticmethod
+    def get_transdata_func():
+        """get transdata function."""
+        if AscendOpsBackend.transdata_func is None:
+            import dlinfer
+            from dlinfer.ops import transdata
+            dlinfer.graph.config.enable_graph_mode = True
+            if torch.distributed.is_initialized():
+                torch._inductor.config.compile_threads = 1
+            AscendOpsBackend.transdata_func = torch.compile(transdata,
+                                                            fullgraph=True,
+                                                            dynamic=False,
+                                                            backend='atbgraph')
+        return AscendOpsBackend.transdata_func
+
     @staticmethod
     def init():
         """Initialize Ascend backend."""
```

lmdeploy/pytorch/backends/dlinfer/linear.py (9 additions, 1 deletion)

```diff
@@ -15,8 +15,16 @@ class DlinferLinearImpl(LinearImpl):
 
     def update_weights(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
         """update weights."""
-        if os.getenv('DLINER_LINEAR_USE_NN_LAYOUT', '0') == '1':
+        if os.getenv('DLINFER_LINEAR_USE_NN_LAYOUT', '0') == '1':
             weight = weight.data.t().contiguous()
+        if weight.device.type == 'npu':
+            from .ascend import SocVersion
+            if SocVersion.is_Ascend310P() and not os.getenv('DLINFER_DISABLE_LINEAR_NZ_FORMAT', '0') == '1':
+                # The Ascend 310P device needs the weight in NZ format, so Transdata it at init time.
+                # Linear weights are converted by default; if an error occurs, set
+                # DLINFER_DISABLE_LINEAR_NZ_FORMAT=1 to disable the conversion.
+                from .ascend import AscendOpsBackend
+                weight = AscendOpsBackend.get_transdata_func()(weight, 2)
         return weight, bias
 
     def forward(self,
```
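The conversion is on by default and opt-out through an environment variable (this hunk also fixes the `DLINER_` → `DLINFER_` typo in the NN-layout flag). A sketch of the same gating logic outside the class, where `to_nz` is a hypothetical stand-in for the compiled Transdata call and the SoC check is reduced to the device type:

```python
import os

import torch


def to_nz(weight: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for AscendOpsBackend.get_transdata_func()(weight, 2)."""
    return weight  # the real op rewrites the layout to ACL_FORMAT_FRACTAL_NZ


def update_weights(weight: torch.Tensor, bias=None):
    """Mirror the gating: convert by default, allow opting out via env var."""
    if os.getenv('DLINFER_LINEAR_USE_NN_LAYOUT', '0') == '1':
        weight = weight.data.t().contiguous()
    on_ascend = weight.device.type == 'npu'  # plus SocVersion.is_Ascend310P() in the real code
    if on_ascend and os.getenv('DLINFER_DISABLE_LINEAR_NZ_FORMAT', '0') != '1':
        weight = to_nz(weight)
    return weight, bias
```

To fall back to the plain ND weight layout when debugging, export DLINFER_DISABLE_LINEAR_NZ_FORMAT=1 before launching lmdeploy.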
