142 changes: 140 additions & 2 deletions kartothek/io/dask/delayed.py
@@ -1,8 +1,9 @@
# -*- coding: utf-8 -*-


from collections import defaultdict
from copy import copy
from functools import partial
from typing import Dict, List, Optional

import dask
from dask import delayed
@@ -19,7 +20,7 @@
delete_top_level_metadata,
)
from kartothek.io_components.gc import delete_files, dispatch_files_to_gc
from kartothek.io_components.merge import align_datasets
from kartothek.io_components.merge import align_datasets, align_datasets_many
from kartothek.io_components.metapartition import (
SINGLE_TABLE,
MetaPartition,
@@ -38,6 +39,7 @@
raise_if_dataset_exists,
store_dataset_from_partitions,
)
from kartothek.serialization import filter_df_from_predicates

from ._update import _update_dask_partitions_one_to_one
from ._utils import (
Expand Down Expand Up @@ -142,6 +144,57 @@ def _load_and_merge_mps(mp_list, store, label_merger, metadata_merger, merge_tas
return mp


def _load_and_merge_many_mps(
mp_list,
store,
label_merger,
metadata_merger,
merge_tasks,
is_dispatched: bool,
predicates=None,
columns: Optional[List[Dict[str, List[str]]]] = None,
):
if is_dispatched:
if columns:
mp_list = [
[mp.load_dataframes(store=store, columns=cols) for mp in mps]
for mps, cols in zip(mp_list, columns)
]
else:
mp_list = [
[mp.load_dataframes(store=store) for mp in mps] for mps in mp_list
]
mp_list = [
MetaPartition.merge_metapartitions(mps).concat_dataframes()
for mps in mp_list
]
else:
if columns:
mp_list = [
mp.load_dataframes(store=store, columns=cols)
for mp, cols in zip(mp_list, columns)
]
else:
mp_list = [mp.load_dataframes(store=store) for mp in mp_list]

mp = MetaPartition.merge_metapartitions(
mp_list, label_merger=label_merger, metadata_merger=metadata_merger
)

for task in merge_tasks:
mp = mp.merge_many_dataframes(**task)

if predicates:
new_data = copy(mp.data)
new_data = {
key: filter_df_from_predicates(df, predicates=predicates)
for key, df in new_data.items()
}
mp = mp.copy(data=new_data)

return mp
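
For orientation, a minimal standalone sketch (not part of this diff) of how the DNF-style predicates used above are applied per table via kartothek's filter_df_from_predicates; the DataFrame and column names are made up:

import pandas as pd
from kartothek.serialization import filter_df_from_predicates

# Hypothetical table content; "part" and "value" are illustrative column names.
df = pd.DataFrame({"part": ["a", "a", "b"], "value": [1, 2, 3]})

# Predicates come in disjunctive normal form: a list of AND-groups of
# (column, operator, value) tuples.
predicates = [[("part", "==", "a"), ("value", ">", 1)]]

filtered = filter_df_from_predicates(df, predicates=predicates)
# Only the row with part == "a" and value > 1 remains.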


@default_docs
def merge_datasets_as_delayed(
left_dataset_uuid,
@@ -230,6 +283,91 @@ def merge_datasets_as_delayed(
return list(mps)


@default_docs
def merge_many_datasets_as_delayed(
dataset_uuids: List[str],
store,
merge_tasks,
match_how="exact",
dispatch_by=None,
label_merger=None,
metadata_merger=None,
predicates=None,
columns: Optional[List[Dict[str, List[str]]]] = None,
):
"""
A dask.delayed graph to perform the merge of multiple full kartothek datasets.

Parameters
----------
dataset_uuids : List[str]
The UUIDs of the datasets to be merged.
match_how : Union[str, Callable]
Define the partition label matching scheme.
Comment (Contributor): Why is the whole thing label-based and not index-based?

Available implementations are:

* first : The partitions of the first dataset are considered to be the base
partitions and **all** partitions of the remaining datasets are
joined to the partitions of the first dataset. This should only be
used if all but the first dataset contain very few partitions.

Comment (Contributor): related to the question above: do we really need different string-based join modes?

Comment (Contributor): What does "few" mean? What happens if this is not the case? Please give the user more guidance and try to provide a more failure-proof API.

* prefix_first : The labels of the partitions of the first dataset are
considered to be the prefixes to the other datasets.
* exact : All partition labels of each dataset need to be an exact match.
* callable : A callable with signature func(labels: List[str]) which
returns a boolean to determine if the partitions match.

merge_tasks : List[Dict]
A list of merge tasks. Each item in this list is a dictionary giving
explicit instructions for a specific merge.

Comment (Contributor): does a merge task drop/consume its input tables? if not, I think this might be a memory issue.

Each dict should contain key/values:

* 'output_label' : The table for the merged dataframe

Comment (Contributor): what about the tables key from the example below?

* `merge_func`: A callable with signature
`merge_func(dfs, merge_kwargs)` to
handle the data preprocessing and merging.
* 'merge_kwargs' : The kwargs to be passed to the `merge_func`

Comment (Contributor): not required, use a partial instead.


Example:

.. code::

>>> merge_tasks = [
... {
... "tables": ["first_table", "second_table"],
... "merge_func": func,
... "merge_kwargs": {"kwargs of merge_func": ''},
... "output_label": 'merged_core_data'
... },
... ]

"""
_check_callable(store)

mps = align_datasets_many(
dataset_uuids=dataset_uuids,
store=store,
match_how=match_how,
dispatch_by=dispatch_by,
predicates=predicates,
)
mps = map_delayed(
_load_and_merge_many_mps,
mps,
store=store,
label_merger=label_merger,
metadata_merger=metadata_merger,
merge_tasks=merge_tasks,
is_dispatched=dispatch_by is not None,
predicates=predicates,
columns=columns,
)

return list(mps)
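
For readers of the diff, a hedged usage sketch of the new entry point. It assumes merge_func is called with the signature merge_func(dfs, merge_kwargs) described in the docstring above; the store factory, dataset UUIDs, and the join key are placeholders, not part of the change:

from functools import reduce

import dask
import pandas as pd
from storefact import get_store_from_url

from kartothek.io.dask.delayed import merge_many_datasets_as_delayed


def merge_func(dfs, merge_kwargs):
    # Stand-in merge: left-to-right pd.merge over all input frames.
    return reduce(lambda left, right: pd.merge(left, right, **merge_kwargs), dfs)


def store_factory():
    # Placeholder store location; any simplekv-compatible store factory works.
    return get_store_from_url("hfs:///tmp/kartothek_store")


merge_tasks = [
    {
        "tables": ["first_table", "second_table"],
        "merge_func": merge_func,
        "merge_kwargs": {"on": "key", "how": "outer"},  # "key" is a made-up join column
        "output_label": "merged_core_data",
    }
]

delayed_mps = merge_many_datasets_as_delayed(
    dataset_uuids=["dataset_a", "dataset_b"],  # placeholder UUIDs
    store=store_factory,
    merge_tasks=merge_tasks,
    match_how="prefix_first",
)
merged_metapartitions = dask.compute(*delayed_mps)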


def _load_and_concat_metapartitions_inner(mps, args, kwargs):
return MetaPartition.concat_metapartitions(
[mp.load_dataframes(*args, **kwargs) for mp in mps]
118 changes: 117 additions & 1 deletion kartothek/io_components/merge.py
@@ -1,8 +1,18 @@
import logging
from functools import partial, reduce
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, cast

import pandas as pd

from kartothek.core.dataset import DatasetMetadata
from kartothek.core.factory import DatasetFactory
from kartothek.io_components.metapartition import MetaPartition
from kartothek.io_components.utils import _instantiate_store
from kartothek.io_components.read import dispatch_metapartitions_from_factory
from kartothek.io_components.utils import _instantiate_store, _make_callable

if TYPE_CHECKING:
from simplekv import KeyValueStore

Comment (Collaborator): linting fails: kartothek/io_components/merge.py:14:5: F401 'simplekv.KeyValueStore' imported but unused

Suggested change:
-    from simplekv import KeyValueStore
+    from simplekv import KeyValueStore  # noqa: F401


LOGGER = logging.getLogger(__name__)

@@ -104,3 +114,109 @@ def align_datasets(left_dataset_uuid, right_dataset_uuid, store, match_how="exac
"found".format(p_1, first_dataset)
)
yield res


def align_datasets_many(
dataset_uuids: List[str],
store,
match_how: str = "exact",
dispatch_by: Optional[List[str]] = None,
predicates=None,
):
"""
Determine dataset partition alignment

Parameters
----------
dataset_uuids : List[str]
store : KeyValueStore or callable
match_how : str, {"first", "prefix_first", "exact", "dispatch_by"}
dispatch_by : List[str], optional
predicates : List[List[Tuple]], optional

Yields
------
list
"""
if len(dataset_uuids) < 2:
raise ValueError("Need at least two datasets for merging.")
dataset_factories = [
DatasetFactory(
dataset_uuid=dataset_uuid,
store_factory=cast(Callable[[], "KeyValueStore"], _make_callable(store)),
load_schema=True,
load_all_indices=False,
load_dataset_metadata=True,
).load_partition_indices()
for dataset_uuid in dataset_uuids
]

store = _instantiate_store(store)
mps = [
# TODO: Add predicates
# We don't pass dispatch_by here as we will do the dispatching later
list(
dispatch_metapartitions_from_factory(
dataset_factory=dataset_factory, predicates=predicates
)
)
for dataset_factory in dataset_factories
]

if match_how == "first":
if len(set(len(x) for x in mps)) != 1:
raise RuntimeError("All datasets must have the same number of partitions")
for mp_0 in mps[0]:
for other_mps in zip(*mps[1:]):
yield [mp_0] + list(other_mps)
elif match_how == "prefix_first":
# TODO: write a test which protects against the following scenario!!
# Sort the partition labels by length of the labels, starting with the
# labels which are the longest. This way we prevent label matching for
# similar partitions, e.g. cluster_100 and cluster_1. This, of course,
# works only as long as the internal loop removes elements which were
# matched already (here improperly called stack)
for mp_0 in mps[0]:
res = [mp_0]
label_0 = mp_0.label
for dataset_i in range(1, len(mps)):
for j, mp_i in enumerate(mps[dataset_i]):
if mp_i.label.startswith(label_0):
res.append(mp_i)
del mps[dataset_i][j]
break
else:
raise RuntimeError(
f"Did not find a matching partition in dataset {dataset_uuids[dataset_i]} for partition {label_0}"
)
yield res
elif match_how == "exact":
raise NotImplementedError("exact")
elif match_how == "dispatch_by":
index_dfs = []
for i, factory in enumerate(dataset_factories):
df = factory.get_indices_as_dataframe(dispatch_by, predicates=predicates)
index_dfs.append(
df.reset_index().rename(
columns={"partition": f"partition_{i}"}, copy=False
)
)
index_df = reduce(partial(pd.merge, on=dispatch_by), index_dfs)

mps_by_label: List[Dict[str, MetaPartition]] = []
for mpx in mps:
mps_by_label.append({})
for mp in mpx:
mps_by_label[-1][mp.label] = mp

for _, group in index_df.groupby(dispatch_by):
res_nested: List[List[MetaPartition]] = []
for i in range(len(dataset_uuids)):
res_nested.append(
[
mps_by_label[i][label]
for label in group[f"partition_{i}"].unique()
]
)
yield res_nested
else:
raise NotImplementedError(f"matching with '{match_how}' is not supported")
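
A small self-contained sketch (illustrative, not part of the diff) of the join-then-groupby mechanism the dispatch_by branch relies on, using made-up partition labels and a "location" dispatch column:

from functools import partial, reduce

import pandas as pd

# One index DataFrame per dataset: which partition holds which dispatch value.
index_dfs = [
    pd.DataFrame({"location": ["eu", "us"], "partition_0": ["part_a", "part_b"]}),
    pd.DataFrame({"location": ["eu", "eu", "us"], "partition_1": ["p_1", "p_2", "p_3"]}),
]
dispatch_by = ["location"]

# Inner-join all index frames on the dispatch columns, then group: each group
# lists, per dataset, the partitions that have to be co-scheduled.
index_df = reduce(partial(pd.merge, on=dispatch_by), index_dfs)
for _, group in index_df.groupby(dispatch_by):
    aligned = [list(group[f"partition_{i}"].unique()) for i in range(len(index_dfs))]
    print(aligned)  # e.g. [['part_a'], ['p_1', 'p_2']] for location == "eu"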