3 changes: 3 additions & 0 deletions .gitignore
@@ -27,6 +27,9 @@ data/
output_models
adapter_model/

# output data
output_data/

# Distribution / packaging
.Python
build/
2 changes: 1 addition & 1 deletion README.md
@@ -290,7 +290,7 @@ bash ./scripts/run_chatbot.sh output_models/finetuned_gpt2
>```bash
>bash ./scripts/run_sglang_inference.sh
>```
>
> Note: If you encounter the error `ModuleNotFoundError: No module named 'common_ops'` when using SGLang, try running `apt-get update` followed by `apt install numactl`.
> </details>

### Deployment
4 changes: 2 additions & 2 deletions examples/rm_inference.py
@@ -42,8 +42,8 @@ def main():
dataset,
)

if pipeline_args.save_results:
res.save(pipeline_args.results_path)
if pipeline_args.save_inference_results:
res.save(pipeline_args.inference_results_path)


if __name__ == "__main__":
3 changes: 2 additions & 1 deletion requirements.txt
@@ -12,4 +12,5 @@ evaluate==0.4.0
bitsandbytes>=0.40.0
pydantic
accelerate>=0.27.2
einops>=0.6.1
einops>=0.6.1
tensordict
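
The new `tensordict` dependency presumably backs the `DataProto` container this PR introduces (imported from `lmflow.utils.protocol` in the files below). A minimal sketch of the underlying `TensorDict` API, under that assumption:

```python
# Not part of this diff: minimal TensorDict usage, assuming DataProto stores its
# tensor batch in a tensordict.TensorDict (this backing choice is an assumption).
import torch
from tensordict import TensorDict

batch = TensorDict(
    {"input_ids": torch.zeros(4, 8, dtype=torch.long)},  # tensors share a leading batch dimension
    batch_size=[4],
)
print(batch.batch_size)          # torch.Size([4])
print(batch["input_ids"].shape)  # torch.Size([4, 8])
```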
4 changes: 2 additions & 2 deletions scripts/archive/run_rm_inference.sh
@@ -67,6 +67,6 @@ accelerate launch --config_file configs/accelerator_multigpu_config.yaml \
--overwrite_cache True \
--conversation_template ${conversation_template} \
--preprocessing_num_workers 16 \
--save_results True \
--results_path ${output_file_path} \
--save_inference_results True \
--inference_results_path ${output_file_path} \
2>&1 | tee ${log_dir}/rm_inference.log
4 changes: 2 additions & 2 deletions scripts/archive/run_vllm_inference.sh
@@ -69,8 +69,8 @@ python examples/vllm_inference.py \
--temperature 1.0 \
--top_p 0.9 \
--max_new_tokens 1024 \
--save_results True \
--results_path ${output_file_path} \
--save_inference_results True \
--inference_results_path ${output_file_path} \
--enable_decode_inference_result False \
--vllm_gpu_memory_utilization 0.95 \
--vllm_tensor_parallel_size 2 \
8 changes: 3 additions & 5 deletions scripts/run_sglang_inference.sh
@@ -1,14 +1,12 @@
python examples/sglang_inference.py \
--model_name_or_path Qwen/Qwen3-4B-Instruct-2507 \
--dataset_path data/alpaca/test_conversation \
--output_dir output_data/sglang_inference_results \
--output_file_name results.json \
--dataset_path data/alpaca/prompt_only \
--inference_engine sglang \
--inference_gpu_memory_utilization 0.8 \
--num_output_sequences 2 \
--temperature 1.0 \
--max_new_tokens 2048 \
--top_p 0.95 \
--random_seed 42 \
--save_results True \
--results_path output_data/sglang_inference_results/results.json
--save_inference_results True \
--inference_results_path output_data/sglang_inference_results/results.json
28 changes: 20 additions & 8 deletions src/lmflow/args.py
@@ -1053,8 +1053,11 @@ class InferencerArguments:
)

# Args for result saving
save_results: Optional[bool] = field(default=False, metadata={"help": "Whether to save inference results."})
results_path: Optional[str] = field(default=None, metadata={"help": "The path of inference results."})
save_results: Optional[bool] = field(default=None, metadata={"help": "Whether to save results."})
results_path: Optional[str] = field(default=None, metadata={"help": "The path of results."})

save_inference_results: Optional[bool] = field(default=False, metadata={"help": "Whether to save inference results."})
inference_results_path: Optional[str] = field(default=None, metadata={"help": "The path of inference results."})

def __post_init__(self):
if self.use_accelerator is not None:
@@ -1063,15 +1066,23 @@ def __post_init__(self):
"It will not take effect and will be removed in a future version, "
"since LMFlow now can automatically detect whether is in Accelerate or Deepspeed environment."
)

if self.save_results:
if self.results_path is None:
raise ValueError("Need to specify results_path when save_results is True.")
logger.warning("`save_results` is deprecated and will be removed in a future version. Please use `save_inference_results` instead.")
self.save_inference_results = self.save_results

if self.results_path:
logger.warning("`results_path` is deprecated and will be removed in a future version. Please use `inference_results_path` instead.")
self.inference_results_path = self.results_path

if self.save_inference_results:
if self.inference_results_path is None:
raise ValueError("Need to specify inference_results_path when save_inference_results is True.")
else:
if not self.results_path.endswith(".json"):
raise ValueError("The results_path must be a json file.")
if not self.inference_results_path.endswith(".json"):
raise ValueError("The inference_results_path must be a json file.")
else:
Path(self.results_path).parent.mkdir(parents=True, exist_ok=True)
Path(self.inference_results_path).parent.mkdir(parents=True, exist_ok=True)

if self.use_vllm is True:
logger.warning(
@@ -1352,6 +1363,7 @@ class IterativeDPOAlignerArguments(IterativeAlignerArguments, DPOv2AlignerArgume
"evaluator": EvaluatorArguments,
"inferencer": InferencerArguments,
"vllm_inferencer": InferencerArguments,
"sglang_inferencer": InferencerArguments,
"rm_inferencer": InferencerArguments,
"raft_aligner": RaftAlignerArguments,
"dpo_aligner": DPOAlignerArguments,
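
To make the backward-compatibility shim above concrete, here is a short sketch of how the deprecated flags are forwarded (illustrative only; it assumes the remaining `InferencerArguments` fields keep their defaults):

```python
# Hypothetical usage, not from the diff: old flag names still work but emit a
# deprecation warning and are copied onto the new fields in __post_init__.
from lmflow.args import InferencerArguments

args = InferencerArguments(save_results=True, results_path="output_data/results.json")

assert args.save_inference_results is True
assert args.inference_results_path == "output_data/results.json"
```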
49 changes: 30 additions & 19 deletions src/lmflow/models/hf_decoder_model.py
@@ -18,6 +18,7 @@
import os
from typing import Literal, Optional, Union

import numpy as np
import torch
from peft import PeftModel
from transformers import (
@@ -40,6 +41,7 @@
from lmflow.utils.deprecated import deprecated_args
from lmflow.utils.envs import is_accelerate_env
from lmflow.utils.versioning import is_flash_attn_available, is_ray_available, is_vllm_available
from lmflow.utils.protocol import DataProto

logger = logging.getLogger(__name__)

@@ -286,7 +288,7 @@ def decode(self, input, **kwargs) -> Union[str, list[str]]:
)
def inference(
self,
inputs: Union[str, list[str], torch.Tensor],
inputs: Union[str, list[str], torch.Tensor, DataProto],
sampling_params: Optional[Union[dict, "SamplingParams"]] = None,
return_logprob: bool = False,
release_gpu: bool = False,
@@ -296,16 +298,17 @@ def inference(
enable_deterministic_inference: bool = False,
attention_backend: Optional[str] = None,
**kwargs,
):
) -> Union[list[VLLMInferenceResultWithInput] | DataProto]:
"""
Perform generation process of the model.

Parameters
------------
inputs : Union[str, list[str], torch.Tensor]
inputs : Union[str, list[str], torch.Tensor, DataProto]
The sequence used as a prompt for the generation or as model inputs to the model.
When the inference engine is "vllm" or "sglang", this should be a string or a list of strings.
When the inference engine is "vllm", this should be a string or a list of strings.
When the inference engine is "huggingface", this should be a tensor.
When the inference engine is "sglang", this should be a DataProto.
sampling_params : Optional[Union[dict, "SamplingParams"]], optional
The sampling parameters to use, by default None.
return_logprob : bool, optional
@@ -345,7 +348,6 @@ def inference(
elif inference_engine == "sglang":
res = self.__sglang_inference(
inputs=inputs,
sampling_params=sampling_params,
return_logprob=return_logprob,
)
else:
@@ -439,21 +441,18 @@ def __vllm_inference(

def __sglang_inference(
self,
inputs: list[str],
sampling_params: Optional[dict] = None,
inputs: DataProto,
return_logprob: bool = False,
):
) -> DataProto:
"""Perform SGLang inference process of the model."""
sglang_outputs = self.backend_model_for_inference.generate(
prompt=inputs,
sampling_params=sampling_params,
prompt=inputs.non_tensor_batch["inputs"].tolist(), # use tensor instead of str later
sampling_params=inputs.meta_info["sampling_params"],
return_logprob=return_logprob,
)
# TODO: unified lmflow sample format
for idx, output in enumerate(sglang_outputs):
output["input"] = inputs[idx]
output["output"] = output.pop("text")
return sglang_outputs
inputs.non_tensor_batch["outputs"] = [output["text"] for output in sglang_outputs]
# TODO: padding for batching the output ids; generation details
return inputs

@deprecated_args(
use_vllm={
@@ -471,7 +470,8 @@ def prepare_inputs_for_inference(
apply_chat_template: bool = True,
inference_engine: Literal["huggingface", "vllm", "sglang"] = "huggingface",
enable_distributed_inference: bool = False,
) -> Union[list[str], "ray.data.Dataset"]:
sampling_params: Optional[dict] = None,
) -> Union[list[str], "ray.data.Dataset", DataProto]:
if dataset.get_type() == "text_only":
if apply_chat_template:
dataset = dataset.map(
@@ -551,9 +551,20 @@ def preprocess_conversation(sample):
inference_inputs
) # -> dict[str, np.ndarray], {"item": array(['...', '...', '...'])}

if inference_engine == "sglang" and self.tokenizer.bos_token:
# in consistent with sglang bench_serving.py demo
inference_inputs = [sentence.replace(self.tokenizer.bos_token, "") for sentence in inference_inputs]
if inference_engine == "sglang":
if self.tokenizer.bos_token:
# consistent with the sglang bench_serving.py demo
inference_inputs = [sentence.replace(self.tokenizer.bos_token, "") for sentence in inference_inputs]

# currently DataProto is only exercised by sglang inference
inference_inputs = np.array(inference_inputs)
inference_inputs = DataProto.from_single_dict(
data={"inputs": inference_inputs},
meta_info={"sampling_params": {**sampling_params, "n": 1}, "actual_n_rollouts": sampling_params["n"]}
)

# handling n>1 since we don't want one-to-many mapping. Later this will be applied to all inference engines.
inference_inputs = inference_inputs.repeat(sampling_params["n"])

return inference_inputs

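
A hedged sketch of the new sglang input path in `prepare_inputs_for_inference`: prompts are wrapped in a `DataProto`, the per-call `n` is forced to 1, and the prompts themselves are repeated `n` times so the input/output mapping stays one-to-one. Only the methods visible in the diff are assumed; the prompts and sampling parameters are invented.

```python
# Sketch only (not code from this PR); mirrors the DataProto construction above.
import numpy as np
from lmflow.utils.protocol import DataProto

prompts = np.array(["Who are you?", "Summarize LMFlow in one sentence."])
sampling_params = {"temperature": 1.0, "top_p": 0.95, "max_new_tokens": 2048, "n": 2}

inference_inputs = DataProto.from_single_dict(
    data={"inputs": prompts},
    meta_info={
        "sampling_params": {**sampling_params, "n": 1},  # engine generates one sample per prompt
        "actual_n_rollouts": sampling_params["n"],       # original n is kept for bookkeeping
    },
)
inference_inputs = inference_inputs.repeat(sampling_params["n"])  # duplicate prompts instead

print(inference_inputs.non_tensor_batch["inputs"])  # 4 prompts: each original prompt appears twice
```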
33 changes: 15 additions & 18 deletions src/lmflow/pipeline/sglang_inferencer.py
@@ -15,6 +15,7 @@
from lmflow.models.hf_decoder_model import HFDecoderModel
from lmflow.pipeline.base_pipeline import BasePipeline
from lmflow.utils.versioning import is_sglang_available
from lmflow.utils.protocol import DataProto

logger = logging.getLogger(__name__)

@@ -64,7 +65,7 @@ def inference(
dataset: Dataset,
release_gpu: bool = False,
inference_args: Optional[InferencerArguments] = None,
):
) -> DataProto:
if inference_args:
logger.warning("Overriding the default inference arguments with the provided arguments in .inference()")
sampling_params = self._parse_args_to_sampling_params(inference_args)
@@ -76,13 +77,11 @@
dataset=dataset,
apply_chat_template=self.inferencer_args.apply_chat_template,
inference_engine="sglang",
sampling_params=sampling_params,
)
# handling n>1 since we don't want one-to-many mapping
model_input = [sample for sample in model_input for _ in range(sampling_params["n"])]

outputs = model.inference(
inputs=model_input,
sampling_params=sampling_params.copy().update({"n": 1}),
return_logprob=self.inferencer_args.return_logprob,
release_gpu=release_gpu,
inference_engine="sglang",
@@ -92,26 +91,24 @@
attention_backend=self.inferencer_args.attention_backend,
)

if self.inferencer_args.save_results:
self.save_inference_results(outputs, self.inferencer_args.results_path)
if self.inferencer_args.save_inference_results:
self.save_inference_results(outputs, self.inferencer_args.inference_results_path)

return outputs

def save_inference_results(
self,
outputs: Union[list[list[str]], list[list[list[int]]]],
save_file_path: str,
outputs: DataProto,
inference_results_path: str,
):
with open(save_file_path, "w", encoding="utf-8") as f:
json.dump(outputs, f, ensure_ascii=False, indent=4)

logger.info(f"Inference results are saved to {save_file_path}.")
if not inference_results_path.endswith(".pkl"):
logger.warning(f"The inference results path must be a pickle file. Change the path to {inference_results_path}.pkl")
inference_results_path = inference_results_path + ".pkl"
outputs.save_to_disk(inference_results_path)
logger.info(f"Inference results are saved to {inference_results_path}.")

def load_inference_results(
self,
results_path: str,
) -> Union[list[list[str]], list[list[list[int]]]]:
with open(results_path) as f:
results = json.load(f)

return results
inference_results_path: str,
) -> DataProto:
return DataProto.load_from_disk(inference_results_path)
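
A round-trip sketch of the new pickle-based persistence (the `DataProto` contents are invented; only `save_to_disk`, `load_from_disk`, and `non_tensor_batch` as used above are assumed):

```python
# Not from the diff: save and reload inference results the way the new helpers do.
from pathlib import Path

import numpy as np
from lmflow.utils.protocol import DataProto

results = DataProto.from_single_dict(
    data={
        "inputs": np.array(["Who are you?"]),
        "outputs": np.array(["I am an assistant trained with LMFlow."]),
    },
    meta_info={"sampling_params": {"temperature": 1.0, "n": 1}},
)

Path("output_data/sglang_inference_results").mkdir(parents=True, exist_ok=True)
results.save_to_disk("output_data/sglang_inference_results/results.pkl")

restored = DataProto.load_from_disk("output_data/sglang_inference_results/results.pkl")
print(restored.non_tensor_batch["outputs"][0])
```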
35 changes: 35 additions & 0 deletions src/lmflow/utils/envs.py
@@ -1,8 +1,43 @@
"""
ref: https://github.com/pytorch/torchtune/blob/main/torchtune/utils/_device.py
"""

import os
import logging

import torch


logger = logging.getLogger(__name__)
is_cuda_available = torch.cuda.is_available()


def is_accelerate_env():
for key, _ in os.environ.items():
if key.startswith("ACCELERATE_"):
return True
return False


def get_device_name() -> str:
"""
Get the device name based on the current machine.
"""
if is_cuda_available:
device = "cuda"
else:
device = "cpu"
return device


def get_torch_device() -> any:
"""Return the corresponding torch attribute based on the device type string.
Returns:
module: The corresponding torch device namespace, or torch.cuda if not found.
"""
device_name = get_device_name()
try:
return getattr(torch, device_name)
except AttributeError:
logger.warning(f"Device namespace '{device_name}' not found in torch, try to load torch.cuda.")
return torch.cuda
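
A small usage sketch for the two new helpers (my addition; it relies only on the functions defined above):

```python
# Usage sketch, not part of the diff.
import torch
from lmflow.utils.envs import get_device_name, get_torch_device

device = get_device_name()             # "cuda" when a GPU is visible, otherwise "cpu"
x = torch.zeros(2, 2, device=device)   # place tensors on the detected device

if device == "cuda":
    get_torch_device().empty_cache()   # resolves to torch.cuda.empty_cache() here
```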