diff --git a/.gitignore b/.gitignore index 1e2c5f31d..eb1e14451 100644 --- a/.gitignore +++ b/.gitignore @@ -27,6 +27,9 @@ data/ output_models adapter_model/ +# output data +output_data/ + # Distribution / packaging .Python build/ diff --git a/README.md b/README.md index c8b52f22d..a02fd653c 100644 --- a/README.md +++ b/README.md @@ -290,7 +290,7 @@ bash ./scripts/run_chatbot.sh output_models/finetuned_gpt2 >```bash >bash ./scripts/run_sglang_inference.sh >``` -> +> Note: If you encounter the error `ModuleNotFoundError: No module named 'common_ops'` when using SGLang, try running `apt-get update` and then `apt install numactl`. > ### Deployment diff --git a/examples/rm_inference.py b/examples/rm_inference.py index 7de1e727e..66263a361 100644 --- a/examples/rm_inference.py +++ b/examples/rm_inference.py @@ -42,8 +42,8 @@ def main(): dataset, ) - if pipeline_args.save_results: - res.save(pipeline_args.results_path) + if pipeline_args.save_inference_results: + res.save(pipeline_args.inference_results_path) if __name__ == "__main__": diff --git a/requirements.txt b/requirements.txt index caaaa3cb2..ec7c0e0d4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,4 +12,5 @@ evaluate==0.4.0 bitsandbytes>=0.40.0 pydantic accelerate>=0.27.2 -einops>=0.6.1 \ No newline at end of file +einops>=0.6.1 +tensordict \ No newline at end of file diff --git a/scripts/archive/run_rm_inference.sh b/scripts/archive/run_rm_inference.sh index 32d68dff1..ccce4c34c 100644 --- a/scripts/archive/run_rm_inference.sh +++ b/scripts/archive/run_rm_inference.sh @@ -67,6 +67,6 @@ accelerate launch --config_file configs/accelerator_multigpu_config.yaml \ --overwrite_cache True \ --conversation_template ${conversation_template} \ --preprocessing_num_workers 16 \ - --save_results True \ - --results_path ${output_file_path} \ + --save_inference_results True \ + --inference_results_path ${output_file_path} \ 2>&1 | tee ${log_dir}/rm_inference.log \ No newline at end of file diff --git a/scripts/archive/run_vllm_inference.sh b/scripts/archive/run_vllm_inference.sh index 8e9598dc3..de5f7ab99 100644 --- a/scripts/archive/run_vllm_inference.sh +++ b/scripts/archive/run_vllm_inference.sh @@ -69,8 +69,8 @@ python examples/vllm_inference.py \ --temperature 1.0 \ --top_p 0.9 \ --max_new_tokens 1024 \ - --save_results True \ - --results_path ${output_file_path} \ + --save_inference_results True \ + --inference_results_path ${output_file_path} \ --enable_decode_inference_result False \ --vllm_gpu_memory_utilization 0.95 \ --vllm_tensor_parallel_size 2 \ diff --git a/scripts/run_sglang_inference.sh b/scripts/run_sglang_inference.sh index 5bf259121..c89c3c86e 100644 --- a/scripts/run_sglang_inference.sh +++ b/scripts/run_sglang_inference.sh @@ -1,8 +1,6 @@ python examples/sglang_inference.py \ --model_name_or_path Qwen/Qwen3-4B-Instruct-2507 \ - --dataset_path data/alpaca/test_conversation \ - --output_dir output_data/sglang_inference_results \ - --output_file_name results.json \ + --dataset_path data/alpaca/prompt_only \ --inference_engine sglang \ --inference_gpu_memory_utilization 0.8 \ --num_output_sequences 2 \ @@ -10,5 +8,5 @@ python examples/sglang_inference.py \ --max_new_tokens 2048 \ --top_p 0.95 \ --random_seed 42 \ - --save_results True \ - --results_path output_data/sglang_inference_results/results.json \ No newline at end of file + --save_inference_results True \ + --inference_results_path output_data/sglang_inference_results/results.json \ No newline at end of file diff --git a/src/lmflow/args.py b/src/lmflow/args.py index
162f85b8a..5b09c075a 100644 --- a/src/lmflow/args.py +++ b/src/lmflow/args.py @@ -1053,8 +1053,11 @@ class InferencerArguments: ) # Args for result saving - save_results: Optional[bool] = field(default=False, metadata={"help": "Whether to save inference results."}) - results_path: Optional[str] = field(default=None, metadata={"help": "The path of inference results."}) + save_results: Optional[bool] = field(default=None, metadata={"help": "Whether to save results."}) + results_path: Optional[str] = field(default=None, metadata={"help": "The path of results."}) + + save_inference_results: Optional[bool] = field(default=False, metadata={"help": "Whether to save inference results."}) + inference_results_path: Optional[str] = field(default=None, metadata={"help": "The path of inference results."}) def __post_init__(self): if self.use_accelerator is not None: @@ -1063,15 +1066,23 @@ def __post_init__(self): "It will not take effect and will be removed in a future version, " "since LMFlow now can automatically detect whether is in Accelerate or Deepspeed environment." ) - + if self.save_results: - if self.results_path is None: - raise ValueError("Need to specify results_path when save_results is True.") + logger.warning("`save_results` is deprecated and will be removed in a future version. Please use `save_inference_results` instead.") + self.save_inference_results = self.save_results + + if self.results_path: + logger.warning("`results_path` is deprecated and will be removed in a future version. Please use `inference_results_path` instead.") + self.inference_results_path = self.results_path + + if self.save_inference_results: + if self.inference_results_path is None: + raise ValueError("Need to specify inference_results_path when save_inference_results is True.") else: - if not self.results_path.endswith(".json"): - raise ValueError("The results_path must be a json file.") + if not self.inference_results_path.endswith(".json"): + raise ValueError("The inference_results_path must be a json file.") else: - Path(self.results_path).parent.mkdir(parents=True, exist_ok=True) + Path(self.inference_results_path).parent.mkdir(parents=True, exist_ok=True) if self.use_vllm is True: logger.warning( @@ -1352,6 +1363,7 @@ class IterativeDPOAlignerArguments(IterativeAlignerArguments, DPOv2AlignerArgume "evaluator": EvaluatorArguments, "inferencer": InferencerArguments, "vllm_inferencer": InferencerArguments, + "sglang_inferencer": InferencerArguments, "rm_inferencer": InferencerArguments, "raft_aligner": RaftAlignerArguments, "dpo_aligner": DPOAlignerArguments, diff --git a/src/lmflow/models/hf_decoder_model.py b/src/lmflow/models/hf_decoder_model.py index 3c105b4e6..bb899fd04 100644 --- a/src/lmflow/models/hf_decoder_model.py +++ b/src/lmflow/models/hf_decoder_model.py @@ -18,6 +18,7 @@ import os from typing import Literal, Optional, Union +import numpy as np import torch from peft import PeftModel from transformers import ( @@ -40,6 +41,7 @@ from lmflow.utils.deprecated import deprecated_args from lmflow.utils.envs import is_accelerate_env from lmflow.utils.versioning import is_flash_attn_available, is_ray_available, is_vllm_available +from lmflow.utils.protocol import DataProto logger = logging.getLogger(__name__) @@ -286,7 +288,7 @@ def decode(self, input, **kwargs) -> Union[str, list[str]]: ) def inference( self, - inputs: Union[str, list[str], torch.Tensor], + inputs: Union[str, list[str], torch.Tensor, DataProto], sampling_params: Optional[Union[dict, "SamplingParams"]] = None, return_logprob: bool = 
False, release_gpu: bool = False, @@ -296,16 +298,17 @@ def inference( enable_deterministic_inference: bool = False, attention_backend: Optional[str] = None, **kwargs, - ): + ) -> Union[list[VLLMInferenceResultWithInput], DataProto]: """ Perform generation process of the model. Parameters ------------ - inputs : Union[str, list[str], torch.Tensor] + inputs : Union[str, list[str], torch.Tensor, DataProto] The sequence used as a prompt for the generation or as model inputs to the model. - When the inference engine is "vllm" or "sglang", this should be a string or a list of strings. + When the inference engine is "vllm", this should be a string or a list of strings. When the inference engine is "huggingface", this should be a tensor. + When the inference engine is "sglang", this should be a DataProto. sampling_params : Optional[Union[dict, "SamplingParams"]], optional The sampling parameters to use, by default None. return_logprob : bool, optional @@ -345,7 +348,6 @@ def inference( elif inference_engine == "sglang": res = self.__sglang_inference( inputs=inputs, - sampling_params=sampling_params, return_logprob=return_logprob, ) else: @@ -439,21 +441,18 @@ def __vllm_inference( def __sglang_inference( self, - inputs: list[str], - sampling_params: Optional[dict] = None, + inputs: DataProto, return_logprob: bool = False, - ): + ) -> DataProto: """Perform SGLang inference process of the model.""" sglang_outputs = self.backend_model_for_inference.generate( - prompt=inputs, - sampling_params=sampling_params, + prompt=inputs.non_tensor_batch["inputs"].tolist(), # TODO: pass tensors instead of strings later + sampling_params=inputs.meta_info["sampling_params"], return_logprob=return_logprob, ) - # TODO: unified lmflow sample format - for idx, output in enumerate(sglang_outputs): - output["input"] = inputs[idx] - output["output"] = output.pop("text") - return sglang_outputs + inputs.non_tensor_batch["outputs"] = [output["text"] for output in sglang_outputs] + # TODO: padding for batching the output ids; generation details + return inputs @deprecated_args( use_vllm={ @@ -471,7 +470,8 @@ def prepare_inputs_for_inference( apply_chat_template: bool = True, inference_engine: Literal["huggingface", "vllm", "sglang"] = "huggingface", enable_distributed_inference: bool = False, - ) -> Union[list[str], "ray.data.Dataset"]: + sampling_params: Optional[dict] = None, + ) -> Union[list[str], "ray.data.Dataset", DataProto]: if dataset.get_type() == "text_only": if apply_chat_template: dataset = dataset.map( @@ -551,9 +551,20 @@ def preprocess_conversation(sample): inference_inputs ) # -> dict[str, np.ndarray], {"item": array(['...', '...', '...'])} - if inference_engine == "sglang" and self.tokenizer.bos_token: - # in consistent with sglang bench_serving.py demo - inference_inputs = [sentence.replace(self.tokenizer.bos_token, "") for sentence in inference_inputs] + if inference_engine == "sglang": + if self.tokenizer.bos_token: + # consistent with the sglang bench_serving.py demo + inference_inputs = [sentence.replace(self.tokenizer.bos_token, "") for sentence in inference_inputs] + + # currently DataProto is only used for sglang inference + inference_inputs = np.array(inference_inputs) + inference_inputs = DataProto.from_single_dict( + data={"inputs": inference_inputs}, + meta_info={"sampling_params": {**sampling_params, "n": 1}, "actual_n_rollouts": sampling_params["n"]} + ) + + # handling n>1 here since we don't want a one-to-many mapping; later this will be applied to all inference engines.
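+            # e.g. with n=3, prompts [p0, p1] become [p0, p0, p0, p1, p1, p1] (interleaved); downstream code can regroup outputs via meta_info["actual_n_rollouts"]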
+ inference_inputs = inference_inputs.repeat(sampling_params["n"]) return inference_inputs diff --git a/src/lmflow/pipeline/sglang_inferencer.py b/src/lmflow/pipeline/sglang_inferencer.py index c2380cb9e..59bcfcce0 100644 --- a/src/lmflow/pipeline/sglang_inferencer.py +++ b/src/lmflow/pipeline/sglang_inferencer.py @@ -15,6 +15,7 @@ from lmflow.models.hf_decoder_model import HFDecoderModel from lmflow.pipeline.base_pipeline import BasePipeline from lmflow.utils.versioning import is_sglang_available +from lmflow.utils.protocol import DataProto logger = logging.getLogger(__name__) @@ -64,7 +65,7 @@ def inference( dataset: Dataset, release_gpu: bool = False, inference_args: Optional[InferencerArguments] = None, - ): + ) -> DataProto: if inference_args: logger.warning("Overriding the default inference arguments with the provided arguments in .inference()") sampling_params = self._parse_args_to_sampling_params(inference_args) @@ -76,13 +77,11 @@ def inference( dataset=dataset, apply_chat_template=self.inferencer_args.apply_chat_template, inference_engine="sglang", + sampling_params=sampling_params, ) - # handling n>1 since we don't want one-to-many mapping - model_input = [sample for sample in model_input for _ in range(sampling_params["n"])] outputs = model.inference( inputs=model_input, - sampling_params=sampling_params.copy().update({"n": 1}), return_logprob=self.inferencer_args.return_logprob, release_gpu=release_gpu, inference_engine="sglang", @@ -92,26 +91,24 @@ def inference( attention_backend=self.inferencer_args.attention_backend, ) - if self.inferencer_args.save_results: - self.save_inference_results(outputs, self.inferencer_args.results_path) + if self.inferencer_args.save_inference_results: + self.save_inference_results(outputs, self.inferencer_args.inference_results_path) return outputs def save_inference_results( self, - outputs: Union[list[list[str]], list[list[list[int]]]], - save_file_path: str, + outputs: DataProto, + inference_results_path: str, ): - with open(save_file_path, "w", encoding="utf-8") as f: - json.dump(outputs, f, ensure_ascii=False, indent=4) - - logger.info(f"Inference results are saved to {save_file_path}.") + if not inference_results_path.endswith(".pkl"): + logger.warning(f"The inference results path must be a pickle file. Change the path to {inference_results_path}.pkl") + inference_results_path = inference_results_path + ".pkl" + outputs.save_to_disk(inference_results_path) + logger.info(f"Inference results are saved to {inference_results_path}.") def load_inference_results( self, - results_path: str, - ) -> Union[list[list[str]], list[list[list[int]]]]: - with open(results_path) as f: - results = json.load(f) - - return results + inference_results_path: str, + ) -> DataProto: + return DataProto.load_from_disk(inference_results_path) diff --git a/src/lmflow/utils/envs.py b/src/lmflow/utils/envs.py index 1740d154a..ff1a40d8c 100644 --- a/src/lmflow/utils/envs.py +++ b/src/lmflow/utils/envs.py @@ -1,4 +1,15 @@ +""" +ref: https://github.com/pytorch/torchtune/blob/main/torchtune/utils/_device.py +""" + import os +import logging + +import torch + + +logger = logging.getLogger(__name__) +is_cuda_available = torch.cuda.is_available() def is_accelerate_env(): @@ -6,3 +17,27 @@ def is_accelerate_env(): if key.startswith("ACCELERATE_"): return True return False + + +def get_device_name() -> str: + """ + Get the device name based on the current machine. 
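+    Currently returns "cuda" when CUDA is available and "cpu" otherwise.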
+ """ + if is_cuda_available: + device = "cuda" + else: + device = "cpu" + return device + + +def get_torch_device() -> any: + """Return the corresponding torch attribute based on the device type string. + Returns: + module: The corresponding torch device namespace, or torch.cuda if not found. + """ + device_name = get_device_name() + try: + return getattr(torch, device_name) + except AttributeError: + logger.warning(f"Device namespace '{device_name}' not found in torch, try to load torch.cuda.") + return torch.cuda \ No newline at end of file diff --git a/src/lmflow/utils/protocol.py b/src/lmflow/utils/protocol.py new file mode 100644 index 000000000..c8779df2d --- /dev/null +++ b/src/lmflow/utils/protocol.py @@ -0,0 +1,1067 @@ +""" +ref: https://github.com/volcengine/verl/blob/main/verl/protocol.py +Implement base data transfer protocol between any two functions, modules. +We can subclass Protocol to define more detailed batch info with specific keys +""" + +import contextlib +import copy +import logging +import math +import pickle +from dataclasses import dataclass, field +from typing import Any, Optional + +import numpy as np +import tensordict +import torch +from packaging import version +from packaging.version import parse as parse_version +from tensordict import TensorDict +from tensordict.tensorclass import NonTensorData, NonTensorStack +from torch.utils.data import DataLoader + +from lmflow.utils.envs import get_torch_device + +logger = logging.getLogger(__name__) + +with contextlib.suppress(Exception): + tensordict.set_lazy_legacy(False).set() + if parse_version(tensordict.__version__) < parse_version("0.10.0"): + tensordict.set_list_to_stack(True).set() + + +def union_python_dict(dict1: dict, dict2: dict): + """Union two dict. Will throw an error if there is an item not the same object with the same key. + + Args: + dict1: + dict2: + + Returns: + + """ + for key, val in dict2.items(): + if key in dict1: + assert dict2[key] == dict1[key], f"{key} in meta_dict1 and meta_dict2 are not the same object" + dict1[key] = val + + return dict1 + + +def union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict) -> TensorDict: + """Union two tensordicts.""" + assert tensor_dict1.batch_size == tensor_dict2.batch_size, ( + f"Two tensor dict must have identical batch size. Got {tensor_dict1.batch_size} and {tensor_dict2.batch_size}" + ) + for key in tensor_dict2.keys(): + if key not in tensor_dict1.keys(): + tensor_dict1[key] = tensor_dict2[key] + else: + assert tensor_dict1[key].equal(tensor_dict2[key]), ( + f"{key} in tensor_dict1 and tensor_dict2 are not the same object" + ) + + return tensor_dict1 + + +def _array_equal(array1: np.ndarray, array2: np.ndarray, visited: set[int]) -> bool: + """ + Recursively compares two NumPy arrays for strict equality, with special + handling for object-dtype arrays, NaN values, and circular references. + This function assumes that the two arguments provided are NumPy arrays. + + Args: + array1: The first NumPy array. + array2: The second NumPy array. + + Returns: + True if the arrays' dtypes, shapes, and all elements are equal. + """ + # Check dtype and shape first, as this is the fastest failure path. + if array1.dtype != array2.dtype or array1.shape != array2.shape: + return False + + # For non-object dtypes, use NumPy's implementation with equal_nan=True. + if array1.dtype != "object": + return np.array_equal(array1, array2, equal_nan=True) + + # For object-dtype arrays, we must recursively compare each element. 
+ # We delegate to _deep_equal to handle elements, as they could be any + # type, including other nested arrays or NaNs. + return all(_deep_equal(x, y, visited) for x, y in zip(array1.flat, array2.flat, strict=False)) + + +def _deep_equal(a: Any, b: Any, visited: set[int]) -> bool: + """ + Recursively performs a deep comparison between two Python objects. + - Handles NaN values correctly (NaN == NaN evaluates to True). + - Handling circular references. + - Dispatches to _array_equal if both objects are NumPy arrays. + - Otherwise, uses standard '==' comparison. + """ + if type(a) is not type(b): + return False + + # If we have seen this object ID before on this path, it's a cycle. + # Since we already know the types match, we can safely assume this part + # of the structure is equal. + obj_id = id(a) + if obj_id in visited: + return True + + visited.add(obj_id) + + # Perform the specific comparison based on type + result = False + if isinstance(a, float) and math.isnan(a) and math.isnan(b): + result = True + elif isinstance(a, np.ndarray): + # We know b is also an ndarray due to the initial type check + result = _array_equal(a, b, visited) + else: + # Standard equality for all other types + result = a == b + + # Clean up the visited set on the way out of the recursion + visited.remove(obj_id) + return result + + +def union_numpy_dict(tensor_dict1: dict[str, np.ndarray], tensor_dict2: dict[str, np.ndarray]) -> dict[str, np.ndarray]: + for key, val in tensor_dict2.items(): + if key in tensor_dict1: + assert isinstance(tensor_dict2[key], np.ndarray) + assert isinstance(tensor_dict1[key], np.ndarray) + # to properly deal with nan and object type + assert _deep_equal(tensor_dict1[key], tensor_dict2[key], visited=set()), ( + f"`{key}` in tensor_dict1 and tensor_dict2 are not the same object." + ) + tensor_dict1[key] = val + + return tensor_dict1 + + +def list_of_dict_to_dict_of_list(list_of_dict: list[dict]): + if len(list_of_dict) == 0: + return {} + keys = list_of_dict[0].keys() + output = {key: [] for key in keys} + for data in list_of_dict: + for key, item in data.items(): + assert key in output + output[key].append(item) + return output + + +def collate_fn(x: list["DataProtoItem"]): + batch = [] + non_tensor_batch = [] + for data in x: + batch.append(data.batch) + non_tensor_batch.append(data.non_tensor_batch) + batch = torch.stack(batch).contiguous() + non_tensor_batch = list_of_dict_to_dict_of_list(non_tensor_batch) + for key, val in non_tensor_batch.items(): + non_tensor_batch[key] = np.array(val, dtype=object) + return DataProto(batch=batch, non_tensor_batch=non_tensor_batch) + + +def get_tensordict(tensor_dict: dict[str, torch.Tensor | list], non_tensor_dict: dict = None) -> TensorDict: + """Create a TensorDict from tensors and non-tensor data. + + Automatically handles nested structures in lists by converting them to NonTensorStack. + This enables support for: + - Lists of lists: [[], [0.5, 0.8], [0.9]] + - Lists of dicts: [{"acc": 1.0}, {"acc": 0.0}] + - Lists of lists of dicts: [[{"content": "...", "role": "user"}]] + + Args: + tensor_dict: Dictionary of tensors and lists to include in the TensorDict + non_tensor_dict: Dictionary of metadata to store as NonTensorData + + Returns: + TensorDict with proper handling of nested structures + + Example: + >>> td = get_tensordict( + ... tensor_dict={ + ... "obs": torch.randn(3, 4), + ... "turn_scores": [[], [0.5, 0.8], [0.9]] # Nested list + ... }, + ... non_tensor_dict={"experiment": "test"} + ... 
) + """ + tensor_dict = tensor_dict.copy() + if non_tensor_dict is None: + non_tensor_dict = {} + + batch_size = None + + for key, val in tensor_dict.items(): + if isinstance(val, torch.Tensor) and val.is_nested: + assert val.is_contiguous(), "Nested tensors must be contiguous. Try setting layout=torch.jagged" + assert val.layout == torch.jagged, "Nested tensors must be jagged." + + # Skip validation for NonTensorStack as it's already properly formatted + if isinstance(val, NonTensorStack): + if batch_size is None: + batch_size = len(val) + else: + assert len(val) == batch_size, ( + f"Batch size of NonTensorStack {key} is not consistent with other tensors. " + f"Expected {batch_size}, got {len(val)}" + ) + continue + + if isinstance(val, list): + for v in val: + assert not isinstance(v, torch.Tensor), ( + "Passing a list makes the data NonTensorStack, " + "which doesn't support torch.Tensor. Please convert to numpy first" + ) + # Convert to NonTensorStack to handle nested structures + tensor_dict[key] = NonTensorStack.from_list([NonTensorData(item) for item in val]) + + assert isinstance(val, torch.Tensor | list) + + if batch_size is None: + batch_size = val.size(0) if isinstance(val, torch.Tensor) else len(val) + else: + val_batch_size = val.size(0) if isinstance(val, torch.Tensor) else len(val) + assert val_batch_size == batch_size, ( + f"Batch size of tensor {key} is not consistent with other tensors. " + f"Expected {batch_size}, got {val_batch_size}" + ) + + if batch_size is None: + batch_size = [] + else: + batch_size = [batch_size] + + for key, val in non_tensor_dict.items(): + assert key not in tensor_dict + tensor_dict[key] = NonTensorData(val) + + return TensorDict(source=tensor_dict, batch_size=batch_size) + + +@dataclass +class DataProtoItem: + batch: TensorDict = None + non_tensor_batch: dict = field(default_factory=dict) + meta_info: dict = field(default_factory=dict) + + +@dataclass +class DataProto: + """ + A DataProto is a data structure that aims to provide a standard protocol for data exchange between functions. + It contains a batch (TensorDict) and a meta_info (Dict). The batch is a TensorDict https://pytorch.org/tensordict/. + TensorDict allows you to manipulate a dictionary of Tensors like a single Tensor. Ideally, the tensors with the + same batch size should be put inside batch. + """ + + batch: TensorDict = None + non_tensor_batch: dict = field(default_factory=dict) + meta_info: dict = field(default_factory=dict) + + def __post_init__(self): + # perform necessary checking + self.check_consistency() + + def __len__(self): + if self.batch is not None: + return self.batch.batch_size[0] + elif self.non_tensor_batch is not None and len(self.non_tensor_batch) > 0: + random_key = list(self.non_tensor_batch.keys())[0] + return self.non_tensor_batch[random_key].shape[0] + else: + return 0 + + def __getitem__(self, item): + """ + Enhanced indexing for DataProto objects. 
+ + Args: + item: Can be one of: + - int: A single index + - slice: A slice object (start:stop:step) + - list: A list of indices + - numpy.ndarray: An array of indices + - torch.Tensor: A tensor of indices + + Returns: + DataProto: For all indexing types except single integers + DataProtoItem: Only for single integer indices + """ + # Case 1: Slice object - use the slice method + if isinstance(item, slice): + return self.slice(item.start, item.stop, item.step) + + # Case 2: List, numpy array, or torch tensor - use select_idxs + elif isinstance(item, list | np.ndarray | torch.Tensor): + return self.select_idxs(item) + + # Case 3: Single integer - return DataProtoItem for backward compatibility + elif isinstance(item, int | np.integer): + tensor_data = self.batch[item] if self.batch is not None else None + non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()} + return DataProtoItem(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info) + + # Case 4: Unsupported type + else: + raise TypeError(f"Indexing with {type(item)} is not supported") + + def __getstate__(self): + import io + + buffer = io.BytesIO() + if parse_version(tensordict.__version__) >= parse_version("0.5.0") and self.batch is not None: + self.batch = self.batch.contiguous() + self.batch = self.batch.consolidate() + torch.save(self.batch, buffer) + # return the raw bytes rather than the BytesIO object, since file-like objects cannot be pickled + return buffer.getvalue(), self.non_tensor_batch, self.meta_info + + def __setstate__(self, data): + import io + + batch_deserialized_bytes, non_tensor_batch, meta_info = data + batch_deserialized = io.BytesIO(batch_deserialized_bytes) + batch = torch.load( + batch_deserialized, weights_only=False, map_location="cpu" if not get_torch_device().is_available() else None, + ) + self.batch = batch + self.non_tensor_batch = non_tensor_batch + self.meta_info = meta_info + + def save_to_disk(self, filepath): + with open(filepath, "wb") as f: + pickle.dump(self, f) + + @staticmethod + def load_from_disk(filepath) -> "DataProto": + with open(filepath, "rb") as f: + data = pickle.load(f) + return data + + def print_size(self, prefix=""): + size_of_tensordict = 0 + if self.batch is not None: + for _, tensor in self.batch.items(): + size_of_tensordict += tensor.element_size() * tensor.numel() + size_of_numpy_array = 0 + for _, numpy_array in self.non_tensor_batch.items(): + size_of_numpy_array += numpy_array.nbytes + + size_of_numpy_array /= 1024**3 + size_of_tensordict /= 1024**3 + + message = f"Size of tensordict: {size_of_tensordict} GB, size of non_tensor_batch: {size_of_numpy_array} GB" + + if prefix: + message = f"{prefix}, " + message + print(message) + + def check_consistency(self): + """Check the consistency of the DataProto, mainly for batch and non_tensor_batch. + We expose this as a public function so that users can call it directly. + """ + if self.batch is not None: + assert len(self.batch.batch_size) == 1, "only support num_batch_dims=1" + + if self.non_tensor_batch is not None: + for key, val in self.non_tensor_batch.items(): + assert isinstance(val, np.ndarray) + + if self.batch is not None and self.non_tensor_batch is not None and len(self.non_tensor_batch) != 0: + # TODO: we can actually lift this restriction if needed + assert len(self.batch.batch_size) == 1, "only support num_batch_dims=1 when non_tensor_batch is not empty."
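+            # every non-tensor entry must be a numpy array whose leading dimension equals the TensorDict batch size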
+ + batch_size = self.batch.batch_size[0] + for key, val in self.non_tensor_batch.items(): + assert isinstance(val, np.ndarray), ( + f"data in the non_tensor_batch must be a numpy.array with dtype=object, but for " + f"{key=}, got {type(val)=}" + ) + assert val.shape[0] == batch_size, ( + f"key {key} length {len(val)} is not equal to batch size {batch_size}" + ) + + @classmethod + def from_single_dict(cls, data: dict[str, torch.Tensor | np.ndarray], meta_info=None): + """Create a DataProto from a dict of tensors and non_tensors""" + tensors = {} + non_tensors = {} + + for key, val in data.items(): + if isinstance(val, torch.Tensor): + tensors[key] = val + elif isinstance(val, np.ndarray): + non_tensors[key] = val + else: + raise ValueError(f"Unsupported type in data {type(val)}") + + return cls.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info) + + @classmethod + def from_dict( + cls, + tensors: Optional[dict[str, torch.Tensor]] = None, + non_tensors=None, + meta_info=None, + num_batch_dims=1, + ): + """Create a DataProto from a dict of tensors. This assumes that + 1. All the tensor in tensors have the same dim0 + 2. Only dim0 is the batch dim + """ + + assert num_batch_dims > 0, "num_batch_dims must be greater than zero" + if non_tensors is not None: + assert num_batch_dims == 1, "only support num_batch_dims=1 when non_tensors is not None." + + if tensors is None: + tensors = {} + if meta_info is None: + meta_info = {} + if non_tensors is None: + non_tensors = {} + + assert isinstance(non_tensors, dict) + + # get and check batch size + batch_size = None + pivot_key = None + for key, tensor in tensors.items(): + if batch_size is None: + batch_size = tensor.shape[:num_batch_dims] + pivot_key = key + else: + current_batch = tensor.shape[:num_batch_dims] + assert batch_size == current_batch, ( + f"Not all the tensor in tensors have the same batch size with batch_dims={num_batch_dims}. " + f"Got {pivot_key} has {batch_size}, {key} has {current_batch}" + ) + + for key, val in non_tensors.items(): + if not isinstance(val, np.ndarray): + non_tensors[key] = np.array(val, dtype=object) + + tensor_dict = TensorDict(source=tensors, batch_size=batch_size) if tensors else None + return cls(batch=tensor_dict, non_tensor_batch=non_tensors, meta_info=meta_info) + + @classmethod + def from_tensordict( + cls, + tensor_dict: TensorDict = None, + meta_info=None, + num_batch_dims=1, + ): + """Create a DataProto from a TensorDict. This assumes that + 1. All the tensor in tensor_dict have the same dim0 + 2. Only dim0 is the batch dim + """ + assert version.parse(tensordict.__version__) >= version.parse("0.10.0"), ( + "Build DataProto from TensorDict at least requires tensordict version 0.10.0" + ) + from tensordict import NonTensorData, NonTensorStack + + assert num_batch_dims > 0, "num_batch_dims must be greater than zero" + if not all(isinstance(val, torch.Tensor) for val in tensor_dict.values()): + assert num_batch_dims == 1, "only support num_batch_dims=1 when tensor_dict contains non tensor data." 
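+        # tensors go into batch, NonTensorStack entries become per-sample numpy arrays in non_tensor_batch, and NonTensorData entries are promoted to meta_info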
+ + if meta_info is None: + meta_info = {} + batch = {} + non_tensor_batch = {} + batch_size = None + for key, val in tensor_dict.items(): + if isinstance(val, torch.Tensor): + batch[key] = val + if batch_size is None: + batch_size = val.shape[:num_batch_dims] + elif isinstance(val, NonTensorStack): + non_tensor_batch[key] = np.array([elem.data for elem in val], dtype=object) + elif isinstance(val, NonTensorData): + meta_info[key] = val.data + + return cls( + batch=TensorDict(batch, batch_size=batch_size), + non_tensor_batch=non_tensor_batch, + meta_info=meta_info, + ) + + def to(self, device) -> "DataProto": + """move the batch to device + + Args: + device (torch.device, str): torch device + + Returns: + DataProto: the current DataProto + + """ + if self.batch is not None: + self.batch = self.batch.to(device) + return self + + def select(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None, deepcopy=False) -> "DataProto": + """Select a subset of the DataProto via batch_keys and meta_info_keys + + Args: + batch_keys (list, optional): a list of strings indicating the keys in batch to select + meta_info_keys (list, optional): a list of keys indicating the meta info to select + + Returns: + DataProto: the DataProto with the selected batch_keys and meta_info_keys + """ + # TODO (zhangchi.usc1992) whether to copy + if batch_keys is not None: + batch_keys = tuple(batch_keys) + sub_batch = self.batch.select(*batch_keys) + else: + sub_batch = self.batch + + if non_tensor_batch_keys is not None: + non_tensor_batch = {key: val for key, val in self.non_tensor_batch.items() if key in non_tensor_batch_keys} + else: + non_tensor_batch = self.non_tensor_batch + + if deepcopy: + non_tensor_batch = copy.deepcopy(non_tensor_batch) + + if meta_info_keys is not None: + sub_meta_info = {key: val for key, val in self.meta_info.items() if key in meta_info_keys} + else: + sub_meta_info = self.meta_info + + if deepcopy: + sub_meta_info = copy.deepcopy(sub_meta_info) + + return type(self)(batch=sub_batch, non_tensor_batch=non_tensor_batch, meta_info=sub_meta_info) + + def select_idxs(self, idxs): + """ + Select specific indices from the DataProto. + + Args: + idxs (torch.Tensor or numpy.ndarray or list): Indices to select + + Returns: + DataProto: A new DataProto containing only the selected indices + """ + if isinstance(idxs, list): + idxs = torch.tensor(idxs) + if idxs.dtype != torch.bool: + idxs = idxs.type(torch.int32) + + if isinstance(idxs, np.ndarray): + idxs_np = idxs + idxs_torch = torch.from_numpy(idxs) + else: # torch.Tensor + idxs_torch = idxs + idxs_np = idxs.detach().cpu().numpy() + + batch_size = int(idxs_np.sum()) if idxs_np.dtype == bool else idxs_np.shape[0] + + if self.batch is not None: + # Use TensorDict's built-in indexing capabilities + selected_batch = TensorDict( + source={key: tensor[idxs_torch] for key, tensor in self.batch.items()}, + batch_size=(batch_size,), + device=self.batch.device, + ) + else: + selected_batch = None + + selected_non_tensor = {} + for key, val in self.non_tensor_batch.items(): + selected_non_tensor[key] = val[idxs_np] + + return type(self)(batch=selected_batch, non_tensor_batch=selected_non_tensor, meta_info=self.meta_info) + + def slice(self, start=None, end=None, step=None): + """ + Slice the DataProto and return a new DataProto object. + This is an improved version of direct slicing which returns a DataProtoItem. + + Args: + start (int, optional): Start index. Defaults to None (start from beginning). 
+ end (int, optional): End index (exclusive). Defaults to None (go to end). + step (int, optional): Step size. Defaults to None (step=1). + + Returns: + DataProto: A new DataProto containing the sliced data + + Examples: + # Using the slice method directly + sliced_data = data_proto.slice(10, 20) + + # Using enhanced indexing (returns DataProto) + sliced_data = data_proto[10:20] + sliced_data = data_proto[::2] # Every other element + + # Using list indexing (returns DataProto) + indices = [1, 5, 10] + selected_data = data_proto[indices] + + # Single index still returns DataProtoItem + single_item = data_proto[5] + """ + # Create a slice object + slice_obj = slice(start, end, step) + + # Handle the batch data + if self.batch is not None: + # Use TensorDict's built-in slicing capabilities + sliced_batch = self.batch[slice_obj] + else: + sliced_batch = None + + # Handle the non-tensor batch data + sliced_non_tensor = {} + for key, val in self.non_tensor_batch.items(): + sliced_non_tensor[key] = val[slice_obj] + + # Return a new DataProto object + return type(self)(batch=sliced_batch, non_tensor_batch=sliced_non_tensor, meta_info=self.meta_info) + + def pop(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None) -> "DataProto": + """Pop a subset of the DataProto via `batch_keys` and `meta_info_keys` + + Args: + batch_keys (list, optional): a list of strings indicating the keys in batch to pop + meta_info_keys (list, optional): a list of keys indicating the meta info to pop + + Returns: + DataProto: the DataProto with the poped batch_keys and meta_info_keys + """ + if batch_keys is None: + batch_keys = [] + if meta_info_keys is None: + meta_info_keys = [] + if non_tensor_batch_keys is None: + non_tensor_batch_keys = [] + + tensors = {} + # tensor batch + for key in batch_keys: + assert key in self.batch.keys() + tensors[key] = self.batch.pop(key) + non_tensors = {} + # non tensor batch + for key in non_tensor_batch_keys: + assert key in self.non_tensor_batch.keys() + non_tensors[key] = self.non_tensor_batch.pop(key) + meta_info = {} + for key in meta_info_keys: + assert key in self.meta_info.keys() + meta_info[key] = self.meta_info.pop(key) + return DataProto.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info) + + def rename(self, old_keys=None, new_keys=None) -> "DataProto": + """ + Note that this function only rename the key in the batch + """ + + def validate_input(keys): + if keys is not None: + if isinstance(keys, str): + keys = [keys] + elif isinstance(keys, list): + pass + else: + raise TypeError(f"keys must be a list or a string, but got {type(keys)}") + return keys + + old_keys = validate_input(old_keys) + new_keys = validate_input(new_keys) + + if len(new_keys) != len(old_keys): + raise ValueError( + f"new_keys and old_keys must have the same length, but got {len(new_keys)} and {len(old_keys)}" + ) + + self.batch.rename_key_(tuple(old_keys), tuple(new_keys)) + + return self + + def union(self, other: "DataProto") -> "DataProto": + """Union with another DataProto. Union batch and meta_info separately. + Throw an error if + + - there are conflict keys in batch and they are not equal + - the batch size of two data batch is not the same + - there are conflict keys in meta_info and they are not the same. 
+ + Args: + other (DataProto): another DataProto to union + + Returns: + DataProto: the DataProto after union + """ + self.batch = union_tensor_dict(self.batch, other.batch) + self.non_tensor_batch = union_numpy_dict(self.non_tensor_batch, other.non_tensor_batch) + self.meta_info = union_python_dict(self.meta_info, other.meta_info) + return self + + def make_iterator(self, mini_batch_size, epochs, seed=None, dataloader_kwargs=None): + r"""Make an iterator from the DataProto. This is built upon that TensorDict can be used as a normal Pytorch + dataset. See https://pytorch.org/tensordict/tutorials/data_fashion for more details. + + + Args: + mini_batch_size (int): mini-batch size when iterating the dataset. We require that + ``batch.batch_size[0] % mini_batch_size == 0``. + epochs (int): number of epochs when iterating the dataset. + dataloader_kwargs (Any): internally, it returns a DataLoader over the batch. The + dataloader_kwargs is the kwargs passed to the DataLoader. + + Returns: + Iterator: an iterator that yields a mini-batch data at a time. The total number of iteration + steps is ``self.batch.batch_size * epochs // mini_batch_size`` + """ + assert self.batch.batch_size[0] % mini_batch_size == 0, f"{self.batch.batch_size[0]} % {mini_batch_size} != 0" + # we can directly create a dataloader from TensorDict + if dataloader_kwargs is None: + dataloader_kwargs = {} + + if seed is not None: + generator = torch.Generator() + generator.manual_seed(seed) + else: + generator = None + + assert isinstance(dataloader_kwargs, dict) + train_dataloader = DataLoader( + dataset=self, batch_size=mini_batch_size, collate_fn=collate_fn, generator=generator, **dataloader_kwargs + ) + + def get_data(): + for _ in range(epochs): + for d in train_dataloader: + d.meta_info = self.meta_info + yield d + + return iter(get_data()) + + def padding(self, padding_size, padding_candidate=""): + """Pad the DataProto by concating with padding_candidate.repeat(padding_size) + + Args: + padding_size (int): the number of repeated padding_candidate + padding_candidate: the item to be repeated and appended to the DataProto, only supporting ["first", "last"] + """ + if padding_size == 0: + return + padding_candidate = self.select_idxs([0 if padding_candidate == "first" else len(self) - 1]) + padding_part = padding_candidate.repeat(padding_size) + padded_dp = DataProto.concat([self, padding_part]) + self.batch = padded_dp.batch + self.non_tensor_batch = padded_dp.non_tensor_batch + + def chunk(self, chunks: int) -> list["DataProto"]: + """Split the batch among dim=0 into chunks. The meta_info is passed to each DataProto after split. + + Args: + chunks (int): the number of chunks to split on dim=0 + + Returns: + List[DataProto]: a list of DataProto after splitting + """ + if not self.is_padding_enabled(): + assert len(self) % chunks == 0, ( + f"only support equal chunk. Got size of DataProto {len(self)} and chunk {chunks}." 
+ ) + + bsz_in_batch = None + if self.batch is not None: + batch_lst = self.batch.chunk(chunks=chunks, dim=0) + bsz_in_batch = np.array([batch.batch_size[0] for batch in batch_lst]) + chunk_indices = np.cumsum(bsz_in_batch)[:-1] + else: + batch_lst = [None for _ in range(chunks)] + + non_tensor_batch_lst = [{} for _ in range(chunks)] + for key, val in self.non_tensor_batch.items(): + assert isinstance(val, np.ndarray) + if bsz_in_batch is not None: + non_tensor_lst = np.array_split(val, chunk_indices.tolist()) + else: + non_tensor_lst = np.array_split(val, chunks) + assert len(non_tensor_lst) == chunks + for i in range(chunks): + non_tensor_batch_lst[i][key] = non_tensor_lst[i] + + output = [] + for i in range(chunks): + output.append( + type(self)(batch=batch_lst[i], non_tensor_batch=non_tensor_batch_lst[i], meta_info=self.meta_info) + ) + + return output + + def split(self, split_size: int) -> list["DataProto"]: + """Split the batch among dim=0 into chunks. The meta_info is passed to each DataProto after split. + + Args: + split_size (int): the size of each split + + Returns: + List[DataProto]: a list of DataProto after splitting + """ + return [self[i : i + split_size] for i in range(0, len(self), split_size)] + + @staticmethod + def concat(data: list["DataProto"]) -> "DataProto": + """Concat a list of DataProto. The batch is concatenated among dim=0. + The meta_info is merged, with special handling for metrics from different workers. + + Args: + data (List[DataProto]): list of DataProto + + Returns: + DataProto: concatenated DataProto + """ + batch_lst = [] + for batch in data: + batch_lst.append(batch.batch) + new_batch = torch.cat(batch_lst, dim=0) if batch_lst[0] is not None else None + + non_tensor_batch = list_of_dict_to_dict_of_list(list_of_dict=[d.non_tensor_batch for d in data]) + for key, val in non_tensor_batch.items(): + non_tensor_batch[key] = np.concatenate(val, axis=0) + + # Merge meta_info with special handling for metrics + merged_meta_info = {} + if data: + # Merge non-metric meta_info and aggregate metrics from all workers. + all_metrics = [] + for d in data: + for k, v in d.meta_info.items(): + if k == "metrics": + if v is not None: + if isinstance(v, list): + all_metrics.extend(v) + else: + all_metrics.append(v) + else: + if k in merged_meta_info: + # Ensure consistency for overlapping non-metric keys + assert merged_meta_info[k] == v, f"Conflicting values for meta_info key '{k}'" + else: + merged_meta_info[k] = v + + # Flatten list of dicts to dict of lists for consistent metrics structure + if all_metrics: + merged_meta_info["metrics"] = list_of_dict_to_dict_of_list(all_metrics) + + cls = type(data[0]) if len(data) > 0 else DataProto + return cls(batch=new_batch, non_tensor_batch=non_tensor_batch, meta_info=merged_meta_info) + + def reorder(self, indices): + """ + Note that this operation is in-place + """ + indices_np = indices.detach().numpy() + self.batch = self.batch[indices] + self.non_tensor_batch = {key: val[indices_np] for key, val in self.non_tensor_batch.items()} + + def repeat(self, repeat_times=2, interleave=True): + """ + Repeat the batch data a specified number of times. + + Args: + repeat_times (int): Number of times to repeat the data. + interleave (bool): Whether to interleave the repeated data. + + Returns: + DataProto: A new DataProto with repeated data. 
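+
+        Example:
+            With rows [r0, r1] and repeat_times=2, interleave=True yields
+            [r0, r0, r1, r1], while interleave=False yields [r0, r1, r0, r1].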
+ """ + if self.batch is not None: + if interleave: + # Interleave the data + repeated_tensors = { + key: tensor.repeat_interleave(repeat_times, dim=0) for key, tensor in self.batch.items() + } + else: + # Stack the data + repeated_tensors = { + key: tensor.unsqueeze(0).expand(repeat_times, *tensor.shape).reshape(-1, *tensor.shape[1:]) + for key, tensor in self.batch.items() + } + + repeated_batch = TensorDict( + source=repeated_tensors, + batch_size=(self.batch.batch_size[0] * repeat_times,), + ) + else: + repeated_batch = None + + repeated_non_tensor_batch = {} + for key, val in self.non_tensor_batch.items(): + if interleave: + repeated_non_tensor_batch[key] = np.repeat(val, repeat_times, axis=0) + else: + repeated_non_tensor_batch[key] = np.tile(val, (repeat_times,) + (1,) * (val.ndim - 1)) + + return type(self)( + batch=repeated_batch, + non_tensor_batch=repeated_non_tensor_batch, + meta_info=self.meta_info, + ) + + def unfold_column_chunks(self, n_split: int, split_keys: Optional[list[str]] = None): + """Split along the second dim into `n_split`, unfold it to the first dim (batch dim) + Useful in passing grouped tensors that doesn't want to be shuffled in dataset. + keys not in split_keys are repeated to match the shape + Note that if the `split_keys` is not provided, it will repeat all the keys in the second dim. + """ + if self.batch is not None: + unfolded_batch = {} + for key in self.batch.keys(): + if key in split_keys if split_keys is not None else False: + shape = list(self.batch[key].shape) + shape[0] = self.batch[key].shape[0] * n_split + shape[1] = self.batch[key].shape[1] // n_split + unfolded_batch[key] = self.batch[key].reshape(*shape) + else: + unfolded_batch[key] = torch.repeat_interleave(self.batch[key], n_split, dim=0) + # locate the `unfolded_batch` as a TensorDict on the same device as the original batch + unfolded_batch = TensorDict( + source=unfolded_batch, batch_size=(self.batch.batch_size[0] * n_split,), device=self.batch.device + ) + else: + unfolded_batch = None + + repeated_non_tensor_batch = {} + for key, val in self.non_tensor_batch.items(): + if key in split_keys: + shape = list(val.shape) + shape[0] = val.shape[0] * n_split + shape[1] = val.shape[1] // n_split + repeated_non_tensor_batch[key] = val.reshape(*shape) + else: + repeated_non_tensor_batch[key] = np.repeat(val, n_split, axis=0) + + return type(self)( + batch=unfolded_batch, + non_tensor_batch=repeated_non_tensor_batch, + meta_info=self.meta_info, + ) + + def sample_level_repeat(self, repeat_times): + """ + Repeat each row of the batch data a specified number of times. + + Args: + repeat_times (torch.tensor, list, tuple, ndarray): Number of times to repeat the data. + + Returns: + DataProto: A new DataProto with repeated data. 
+ """ + if isinstance(repeat_times, tuple): + repeat_times = list(repeat_times) + elif isinstance(repeat_times, torch.Tensor): + assert len(repeat_times.shape) == 1 + repeat_times = repeat_times.tolist() + elif isinstance(repeat_times, np.ndarray): + assert len(repeat_times.shape) == 1 + repeat_times = repeat_times.tolist() + else: + assert isinstance(repeat_times, list), ( + f"repeat_times type must be in [list, torch.Tensor, np.ndarray, tuple], got {type(repeat_times)}" + ) + repeat_times = torch.tensor(repeat_times) + + if self.batch is not None: + # Interleave the data + repeated_tensors = { + key: tensor.repeat_interleave(repeat_times, dim=0) for key, tensor in self.batch.items() + } + + repeated_batch = TensorDict( + source=repeated_tensors, + batch_size=(repeat_times.sum().item(),), + device=self.batch.device, + ) + else: + repeated_batch = None + + repeated_non_tensor_batch = {} + for key, val in self.non_tensor_batch.items(): + repeated_non_tensor_batch[key] = np.repeat(val, repeat_times, axis=0) + + return type(self)( + batch=repeated_batch, + non_tensor_batch=repeated_non_tensor_batch, + meta_info=self.meta_info, + ) + + def to_tensordict(self) -> TensorDict: + """Convert this DataProto to TensorDict. Note that this requires tensordict version at least 0.10 + + Returns: + + """ + assert parse_version(tensordict.__version__) >= parse_version("0.10"), ( + "Convert DataProto to TensorDict at least requires tensordict version 0.10" + ) + tensor_batch = self.batch.to_dict() + non_tensor_batch = self.non_tensor_batch + + from tensordict.tensorclass import NonTensorData, NonTensorStack + + common_keys = set(tensor_batch.keys()) & set(non_tensor_batch.keys()) + assert len(common_keys) == 0, f"tensor_batch and non_tensor_batch have common keys {common_keys}" + + for key, val in non_tensor_batch.items(): + assert isinstance(val, np.ndarray) + # Convert to NonTensorStack instead of plain list to handle nested structures + tensor_batch[key] = NonTensorStack.from_list([NonTensorData(item) for item in val]) + output = get_tensordict(tensor_dict=tensor_batch, non_tensor_dict=self.meta_info) + return output + + def get_data_info(self) -> str: + """Return formatted information about stored data with nested type details. 
+ + Returns: + str: Formatted string showing tensor details and recursive metadata types + """ + info = ["batch"] + + for key, tensor in self.batch.items(): + if hasattr(tensor, "shape") and hasattr(tensor, "dtype") and hasattr(tensor, "device"): + info.append(f" {key}: {tuple(tensor.shape)} ({tensor.dtype}) {tensor.device}") + elif hasattr(tensor, "shape") and hasattr(tensor, "dtype"): + info.append(f" {key}: {tuple(tensor.shape)} ({tensor.dtype})") + else: + info.append(f" {key}: {type(tensor).__name__}") + + info.append("non_tensor_batch") + for key, array in self.non_tensor_batch.items(): + info.append(f" {key}: ndarray{array.shape} ({array.dtype})") + + info.append("meta_info") + for k, v in self.meta_info.items(): + type_info = self._get_type_info(v) + info.append(f" {k}: {type_info}") + + return "\n".join(info) + + def _get_type_info(self, value): + """Recursively get type information for nested structures""" + if isinstance(value, list): + elem_types = {self._get_type_info(v) for v in value[:3]} + return f"list[{'|'.join(elem_types) if elem_types else '...'}]" + if isinstance(value, tuple): + elem_types = [self._get_type_info(v) for v in value] + return f"tuple({', '.join(elem_types)})" + if isinstance(value, dict): + if not value: + return "dict" + k, v = next(iter(value.items())) + return f"dict[{self._get_type_info(k)}: {self._get_type_info(v)}]" + if isinstance(value, np.ndarray): + return f"ndarray{value.shape} ({value.dtype})" + return type(value).__name__ +
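
Below is a minimal usage sketch (not part of the diff) of the DataProto flow this PR introduces for SGLang inference: prompts are stored as a non-tensor field, the sampling parameters and the true `n` travel in `meta_info`, prompts are repeated for `n > 1`, generations are attached per sample, and the result is pickled to disk. The field names (`inputs`, `outputs`, `sampling_params`, `actual_n_rollouts`) mirror `hf_decoder_model.py` and `sglang_inferencer.py`; the prompts and output path are illustrative.

```python
import numpy as np

from lmflow.utils.protocol import DataProto

prompts = np.array(["What is RLHF?", "Explain LoRA in one sentence."])
sampling_params = {"temperature": 1.0, "top_p": 0.95, "max_new_tokens": 128, "n": 2}

# prepare_inputs_for_inference builds the DataProto like this for sglang:
# the engine-facing sampling_params get n=1, and the true n is kept as actual_n_rollouts.
data = DataProto.from_single_dict(
    data={"inputs": prompts},
    meta_info={"sampling_params": {**sampling_params, "n": 1}, "actual_n_rollouts": sampling_params["n"]},
)

# Each prompt is duplicated n times (interleaved) before the engine is called.
data = data.repeat(sampling_params["n"])
assert len(data) == len(prompts) * sampling_params["n"]

# __sglang_inference attaches one generation per repeated prompt; faked here so the
# sketch runs without a model or GPU.
data.non_tensor_batch["outputs"] = np.array(
    [f"<generation for: {prompt}>" for prompt in data.non_tensor_batch["inputs"]], dtype=object
)

# SGLangInferencer.save_inference_results / load_inference_results round-trip via pickle.
data.save_to_disk("sglang_inference_results.pkl")
restored = DataProto.load_from_disk("sglang_inference_results.pkl")
assert len(restored) == len(data)
print(restored.non_tensor_batch["outputs"][:2])
```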