Skip to content

Commit 736569d

Browse files
authored
[Platform] Custom ops support for LMhead and LogitsProcessor (vllm-project#23564)
Signed-off-by: zzhx1 <[email protected]>
1 parent 2eb9986 commit 736569d

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

vllm/model_executor/layers/logits_processor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66
from typing import Optional
77

88
import torch
9-
import torch.nn as nn
109

1110
import vllm.envs as envs
1211
from vllm.distributed import (tensor_model_parallel_all_gather,
1312
tensor_model_parallel_gather)
13+
from vllm.model_executor.custom_op import CustomOp
1414
from vllm.model_executor.layers.vocab_parallel_embedding import (
1515
VocabParallelEmbedding)
1616
from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -22,7 +22,8 @@
2222
envs.VLLM_LOGITS_PROCESSOR_THREADS)
2323

2424

25-
class LogitsProcessor(nn.Module):
25+
@CustomOp.register("logits_processor")
26+
class LogitsProcessor(CustomOp):
2627
"""Process logits and apply logits processors from sampling metadata.
2728
2829
This layer does the following:

vllm/model_executor/layers/vocab_parallel_embedding.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,7 @@ def extra_repr(self) -> str:
429429
return s
430430

431431

432+
@CustomOp.register("parallel_lm_head")
432433
class ParallelLMHead(VocabParallelEmbedding):
433434
"""Parallelized LM head.
434435

0 commit comments

Comments
 (0)