Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 30 additions & 34 deletions grain/_src/python/dataset/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,52 +300,48 @@ def pretty_format_summary(
return table.get_pretty_wrapped_summary()


# Input pipeline stage categories.
IPL_CAT_PREPROCESSING = "preprocessing"
IPL_CAT_READ = "read"
IPL_CAT_ENQUEUE = "enqueue"
IPL_CAT_UNKNOWN = "unknown"


def record_next_duration_if_output(stage_category: str = IPL_CAT_ENQUEUE):
  """Records the duration of the `__next__` call on the output iterator node.

  Expected to be used as follows:
  ```
  class MyMapDatasetIterator(DatasetIterator):
    ...
    @stats.record_next_duration_if_output(stage_category=stats.IPL_CAT_ENQUEUE)
    def __next__(self):
      ...
  ```

  Args:
    stage_category: The category of the input pipeline stage (one of the
      `IPL_CAT_*` constants above); attached to the trace annotation as
      `_ipl_stage_cat`.

  Returns:
    A decorator that wraps a `__next__` method, tracing the call when trace
    annotations are enabled and recording its duration in the histogram when
    the iterator is the pipeline's output node.
  """

  def inner_wrapper(next_fn):

    @functools.wraps(next_fn)
    def wrapper(iterator):
      if _TRACE_ANNOTATION and _TRACE_ANNOTATION.is_enabled():
        with _TRACE_ANNOTATION(
            f"{iterator.__class__.__name__}.{next_fn.__name__}",
            _ipl_stage_name=str(iterator),
            _ipl_stage_id=id(iterator),
            _ipl_stage_cat=stage_category,
        ):
          start_time = time.perf_counter_ns()
          result = next_fn(iterator)
      else:
        start_time = time.perf_counter_ns()
        result = next_fn(iterator)

      # Only the pipeline's output node reports to the histogram so the
      # recorded durations reflect end-to-end element production time.
      if iterator._stats._is_output:  # pylint: disable=protected-access
        next_duration_ns = time.perf_counter_ns() - start_time
        _next_duration_ns_histogram.Record(next_duration_ns)
      return result

    return wrapper

  return inner_wrapper


def trace_input_pipeline(stage_category: str = IPL_CAT_UNKNOWN, **trace_kwargs):
Expand Down
2 changes: 1 addition & 1 deletion grain/_src/python/dataset/transformations/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def __init__(
self._drop_remainder = drop_remainder
self._batch_fn = batch_fn

@stats.record_next_duration_if_output
@stats.record_next_duration_if_output(stage_category=stats.IPL_CAT_ENQUEUE)
def __next__(self) -> T:
values = []
for _ in range(self._batch_size):
Expand Down
4 changes: 3 additions & 1 deletion grain/_src/python/dataset/transformations/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,9 @@ def _threshold_checker(self):
raise_threshold=self._ctx.dataset_options.filter_raise_threshold_ratio,
)

@dataset_stats.record_next_duration_if_output
@dataset_stats.record_next_duration_if_output(
stage_category=dataset_stats.IPL_CAT_ENQUEUE
)
def __next__(self):
value = None
passed_filter = False
Expand Down
4 changes: 3 additions & 1 deletion grain/_src/python/dataset/transformations/flatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ def __init__(
def _has_consumed_all_buffer_elements(self):
return self._next_index_in_buffer >= len(self._buffer)

@dataset_stats.record_next_duration_if_output
@dataset_stats.record_next_duration_if_output(
stage_category=dataset_stats.IPL_CAT_ENQUEUE
)
def __next__(self):
timer = dataset_stats.Timer()
while self._has_consumed_all_buffer_elements():
Expand Down
2 changes: 1 addition & 1 deletion grain/_src/python/dataset/transformations/interleave.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def __init__(
None
] * self._cycle_length

@stats.record_next_duration_if_output
@stats.record_next_duration_if_output(stage_category=stats.IPL_CAT_ENQUEUE)
def __next__(self) -> T:
while True:
if iterator_to_use := self._iterators_in_use[self._next_index_in_cycle]:
Expand Down
2 changes: 1 addition & 1 deletion grain/_src/python/dataset/transformations/limit.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(
self._count = count
self._count_elements_read = 0

@stats.record_next_duration_if_output
@stats.record_next_duration_if_output(stage_category=stats.IPL_CAT_ENQUEUE)
def __next__(self):
if self._count_elements_read >= self._count:
raise StopIteration
Expand Down
6 changes: 3 additions & 3 deletions grain/_src/python/dataset/transformations/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def __init__(
self._map_fn = map_fn
self._transform_name = transform_name

@stats.record_next_duration_if_output
@stats.record_next_duration_if_output(stage_category=stats.IPL_CAT_ENQUEUE)
def __next__(self):
element = next(self._parent)
with self._stats.record_self_time():
Expand Down Expand Up @@ -245,7 +245,7 @@ def __init__(
self._rng = np.random.Generator(np.random.Philox(seed))
self._transform_name = transform_name

@stats.record_next_duration_if_output
@stats.record_next_duration_if_output(stage_category=stats.IPL_CAT_ENQUEUE)
def __next__(self):
element = next(self._parent)
with self._stats.record_self_time():
Expand Down Expand Up @@ -291,7 +291,7 @@ def __init__(
self._transform_name = transform_name
self._counter = 0

@stats.record_next_duration_if_output
@stats.record_next_duration_if_output(stage_category=stats.IPL_CAT_ENQUEUE)
def __next__(self):
element = next(self._parent)
with self._stats.record_self_time():
Expand Down
2 changes: 1 addition & 1 deletion grain/_src/python/dataset/transformations/mix.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def __init__(
self._index = 0
self._stop = False

@stats.record_next_duration_if_output
@stats.record_next_duration_if_output(stage_category=stats.IPL_CAT_ENQUEUE)
def __next__(self):
if self._stop:
# Although there may be elements available in some parent datasets, do not
Expand Down
4 changes: 3 additions & 1 deletion grain/_src/python/dataset/transformations/packing.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,9 @@ def _finalize_current_batch(self, element_for_shapes):
padding_struct=self._padding_struct,
)

@dataset_stats.record_next_duration_if_output
@dataset_stats.record_next_duration_if_output(
stage_category=dataset_stats.IPL_CAT_ENQUEUE
)
def __next__(self):
timer = dataset_stats.Timer()
if self._packed_batch is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ def _maybe_add_to_buffer(
tokens_in_buffer[k] += v
return remainder

@stats.record_next_duration_if_output
@stats.record_next_duration_if_output(stage_category=stats.IPL_CAT_ENQUEUE)
def __next__(self):
if self._packed_elements:
self._state.elements_from_buffer_after_checkpoint += 1
Expand Down
12 changes: 9 additions & 3 deletions grain/_src/python/dataset/transformations/prefetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,9 @@ def _threshold_checker(self):
raise_threshold=self._ctx.dataset_options.filter_raise_threshold_ratio,
)

@dataset_stats.record_next_duration_if_output
@dataset_stats.record_next_duration_if_output(
stage_category=dataset_stats.IPL_CAT_ENQUEUE
)
def __next__(self) -> T:
# The time recorded here is the time spent in prefetch node to return an
# element, including the time spent in parent node.
Expand Down Expand Up @@ -670,7 +672,9 @@ def _stats(self):
def __iter__(self) -> dataset.DatasetIterator[T]:
return self

@dataset_stats.record_next_duration_if_output
@dataset_stats.record_next_duration_if_output(
stage_category=dataset_stats.IPL_CAT_ENQUEUE
)
def __next__(self) -> T:
self._ensure_iterator_initialized()
# The time recorded here is the time spent in prefetch node to return an
Expand Down Expand Up @@ -881,7 +885,9 @@ def start_prefetch(self):
)
self._prefetch_thread.start()

@dataset_stats.record_next_duration_if_output
@dataset_stats.record_next_duration_if_output(
stage_category=dataset_stats.IPL_CAT_ENQUEUE
)
def __next__(self):
self.start_prefetch()
element, state, err = self._buffer.get()
Expand Down
2 changes: 1 addition & 1 deletion grain/_src/python/dataset/transformations/rebatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def _flatten(self, pytree_element):
self._treedef = pytree_element
return tree_lib.flatten(pytree_element)

@stats.record_next_duration_if_output
@stats.record_next_duration_if_output(stage_category=stats.IPL_CAT_ENQUEUE)
def __next__(self):
timer = stats.Timer()

Expand Down
2 changes: 1 addition & 1 deletion grain/_src/python/dataset/transformations/shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def _fill_and_shuffle_window(self):
seed=self._global_seed + self._window_index, window=self._window
)

@stats.record_next_duration_if_output
@stats.record_next_duration_if_output(stage_category=stats.IPL_CAT_ENQUEUE)
def __next__(self):
# Window is empty, fill up the next window.
if not self._window:
Expand Down
4 changes: 3 additions & 1 deletion grain/_src/python/dataset/transformations/zip.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ def __init__(
super().__init__([p.__iter__() for p in parents])
self._strict = strict

@dataset_stats.record_next_duration_if_output
@dataset_stats.record_next_duration_if_output(
stage_category=dataset_stats.IPL_CAT_ENQUEUE
)
def __next__(self) -> tuple[T, ...]:
with self._stats.record_self_time():
# Can't use for a `for` loop because we need to raise StopIteration from
Expand Down