diff --git a/torchtune/data/_messages.py b/torchtune/data/_messages.py index 6b5049a020..5ff9206547 100644 --- a/torchtune/data/_messages.py +++ b/torchtune/data/_messages.py @@ -737,6 +737,7 @@ def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: role=message["role"], content=content, eot=eot, + ipython=message["ipython"] if "ipython" in message else False, ), ) mask_messages(updated_messages, self.masking_strategy)