Add merge_many_datasets_as_delayed #243
base: master
@@ -1,8 +1,9 @@
# -*- coding: utf-8 -*-

from collections import defaultdict
from copy import copy
from functools import partial
from typing import Dict, List, Optional

import dask
from dask import delayed

@@ -19,7 +20,7 @@
    delete_top_level_metadata,
)
from kartothek.io_components.gc import delete_files, dispatch_files_to_gc
from kartothek.io_components.merge import align_datasets
from kartothek.io_components.merge import align_datasets, align_datasets_many
from kartothek.io_components.metapartition import (
    SINGLE_TABLE,
    MetaPartition,

@@ -38,6 +39,7 @@
    raise_if_dataset_exists,
    store_dataset_from_partitions,
)
from kartothek.serialization import filter_df_from_predicates

from ._update import _update_dask_partitions_one_to_one
from ._utils import (

@@ -142,6 +144,57 @@ def _load_and_merge_mps(mp_list, store, label_merger, metadata_merger, merge_tas
    return mp


def _load_and_merge_many_mps(
    mp_list,
    store,
    label_merger,
    metadata_merger,
    merge_tasks,
    is_dispatched: bool,
    predicates=None,
    columns: Optional[List[Dict[str, List[str]]]] = None,
):
    if is_dispatched:
        if columns:
            mp_list = [
                [mp.load_dataframes(store=store, columns=cols) for mp in mps]
                for mps, cols in zip(mp_list, columns)
            ]
        else:
            mp_list = [
                [mp.load_dataframes(store=store) for mp in mps] for mps in mp_list
            ]
        mp_list = [
            MetaPartition.merge_metapartitions(mps).concat_dataframes()
            for mps in mp_list
        ]
    else:
        if columns:
            mp_list = [
                mp.load_dataframes(store=store, columns=cols)
                for mp, cols in zip(mp_list, columns)
            ]
        else:
            mp_list = [mp.load_dataframes(store=store) for mp in mp_list]

    mp = MetaPartition.merge_metapartitions(
        mp_list, label_merger=label_merger, metadata_merger=metadata_merger
    )

    for task in merge_tasks:
        mp = mp.merge_many_dataframes(**task)

    if predicates:
        new_data = copy(mp.data)
        new_data = {
            key: filter_df_from_predicates(df, predicates=predicates)
            for key, df in new_data.items()
        }
        mp = mp.copy(data=new_data)

    return mp
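For context, a minimal sketch of what the predicate filtering step above does to each loaded table, using kartothek's list-of-conjunctions predicate format (column names and values are illustrative, not from this PR):

import pandas as pd

from kartothek.serialization import filter_df_from_predicates

df = pd.DataFrame({"country": ["DE", "US", "DE"], "value": [1, 2, 3]})

# keep rows where country == "DE" AND value > 1; a second inner list would be OR-ed
predicates = [[("country", "==", "DE"), ("value", ">", 1)]]

filtered = filter_df_from_predicates(df, predicates=predicates)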
@default_docs
def merge_datasets_as_delayed(
    left_dataset_uuid,

@@ -230,6 +283,91 @@ def merge_datasets_as_delayed(
    return list(mps)


@default_docs
def merge_many_datasets_as_delayed(
    dataset_uuids: List[str],
    store,
    merge_tasks,
    match_how="exact",
    dispatch_by=None,
    label_merger=None,
    metadata_merger=None,
    predicates=None,
    columns: Optional[List[Dict[str, List[str]]]] = None,
):
    """
    A dask.delayed graph to perform the merge of multiple full kartothek datasets.

    Parameters
    ----------
    dataset_uuids : List[str]
    match_how : Union[str, Callable]
        Define the partition label matching scheme.
        Available implementations are:

        * first : The partitions of the first dataset are considered to be the base
          partitions and **all** partitions of the remaining datasets are
          joined to the partitions of the first dataset. This should only be
          used if all but the first dataset contain very few partitions.

Review comment: related to the question above: do we really need different string-based join modes?

Review comment: What does "few" mean? What happens if this is not the case? Please give the user more guidance and try to provide a more failure-proof API.

        * prefix_first : The labels of the first dataset's partitions are
          treated as prefixes of the other datasets' partition labels.
        * exact : All partition labels of each dataset need to be an exact match.
        * callable : A callable with signature func(labels: List[str]) which
          returns a boolean to determine if the partitions match.
    merge_tasks : List[Dict]
        A list of merge tasks. Each item in this list is a dictionary giving
        explicit instructions for a specific merge.

Review comment: does a merge task drop/consume its input tables? if not, I think this might be a memory issue.

        Each dict should contain the following key/values:

        * 'tables' : The input tables to be merged by this task (cf. the example below)
        * 'output_label' : The table name for the merged dataframe

Review comment: what about the …

        * 'merge_func' : A callable with signature
          `merge_func(dfs, merge_kwargs)` that
          handles the data preprocessing and merging.
        * 'merge_kwargs' : The kwargs to be passed to `merge_func`

Review comment: not required, use a partial instead.
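A hypothetical sketch of that suggestion: bind the keyword arguments into `merge_func` with `functools.partial` instead of passing a separate 'merge_kwargs' entry. Table names and the key column are illustrative; whether a task dict may omit 'merge_kwargs' depends on `merge_many_dataframes`, which is not part of this diff.

from functools import partial

import pandas as pd


def merge_on_key(dfs, on, how="inner"):
    # illustrative merge function: successively join all input dataframes
    merged = dfs[0]
    for df in dfs[1:]:
        merged = merged.merge(df, on=on, how=how)
    return merged


merge_tasks = [
    {
        "tables": ["first_table", "second_table"],
        # kwargs are pre-bound here, so no separate 'merge_kwargs' entry is needed
        "merge_func": partial(merge_on_key, on="primary_key", how="left"),
        "output_label": "merged_core_data",
    }
]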

    Example:

    .. code::

        >>> merge_tasks = [
        ...     {
        ...         "tables": ["first_table", "second_table"],
        ...         "merge_func": func,
        ...         "merge_kwargs": {"kwargs of merge_func": ''},
        ...         "output_label": 'merged_core_data'
        ...     },
        ... ]

    """
    _check_callable(store)

    mps = align_datasets_many(
        dataset_uuids=dataset_uuids,
        store=store,
        match_how=match_how,
        dispatch_by=dispatch_by,
        predicates=predicates,
    )
    mps = map_delayed(
        _load_and_merge_many_mps,
        mps,
        store=store,
        label_merger=label_merger,
        metadata_merger=metadata_merger,
        merge_tasks=merge_tasks,
        is_dispatched=dispatch_by is not None,
        predicates=predicates,
        columns=columns,
    )

    return list(mps)
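To show how the pieces fit together, here is a hedged end-to-end sketch of using the new function. Dataset UUIDs, the store URL, table and column names are made up; the import path assumes the function is exported next to `merge_datasets_as_delayed`, and the `merge_func` signature follows the docstring above.

from functools import partial

import dask
from storefact import get_store_from_url  # assumption: any simplekv-compatible store factory works

from kartothek.io.dask.delayed import merge_many_datasets_as_delayed

store_factory = partial(get_store_from_url, "hfs:///tmp/kartothek_data")


def merge_func(dfs, merge_kwargs):
    # join all input dataframes on the key columns given in merge_kwargs
    merged = dfs[0]
    for df in dfs[1:]:
        merged = merged.merge(df, **merge_kwargs)
    return merged


merge_tasks = [
    {
        "tables": ["first_table", "second_table"],
        "merge_func": merge_func,
        "merge_kwargs": {"on": ["primary_key"], "how": "left"},
        "output_label": "merged_core_data",
    }
]

delayed_mps = merge_many_datasets_as_delayed(
    dataset_uuids=["dataset_a", "dataset_b"],
    store=store_factory,
    merge_tasks=merge_tasks,
    match_how="prefix_first",
    predicates=[[("country", "==", "DE")]],  # optional predicate pushdown
)
merged_metapartitions = dask.compute(*delayed_mps)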


def _load_and_concat_metapartitions_inner(mps, args, kwargs):
    return MetaPartition.concat_metapartitions(
        [mp.load_dataframes(*args, **kwargs) for mp in mps]

@@ -1,8 +1,18 @@
import logging
from functools import partial, reduce
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, cast

import pandas as pd

from kartothek.core.dataset import DatasetMetadata
from kartothek.core.factory import DatasetFactory
from kartothek.io_components.metapartition import MetaPartition
from kartothek.io_components.utils import _instantiate_store
from kartothek.io_components.read import dispatch_metapartitions_from_factory
from kartothek.io_components.utils import _instantiate_store, _make_callable

if TYPE_CHECKING:
    from simplekv import KeyValueStore

Review comment (suggested change): linting fails:


LOGGER = logging.getLogger(__name__)

@@ -104,3 +114,109 @@ def align_datasets(left_dataset_uuid, right_dataset_uuid, store, match_how="exac
                    "found".format(p_1, first_dataset)
                )
        yield res


def align_datasets_many(
    dataset_uuids: List[str],
    store,
    match_how: str = "exact",
    dispatch_by: Optional[List[str]] = None,
    predicates=None,
):
    """
    Determine dataset partition alignment.

    Parameters
    ----------
    dataset_uuids : List[str]
    store : KeyValueStore or callable
    match_how : str or callable, {first, prefix_first, exact, dispatch_by}
    dispatch_by : List[str], optional
    predicates

    Yields
    ------
    list
    """
    if len(dataset_uuids) < 2:
        raise ValueError("Need at least two datasets for merging.")
    dataset_factories = [
        DatasetFactory(
            dataset_uuid=dataset_uuid,
            store_factory=cast(Callable[[], "KeyValueStore"], _make_callable(store)),
            load_schema=True,
            load_all_indices=False,
            load_dataset_metadata=True,
        ).load_partition_indices()
        for dataset_uuid in dataset_uuids
    ]

    store = _instantiate_store(store)
    mps = [
        # TODO: Add predicates
        # We don't pass dispatch_by here as we will do the dispatching later
        list(
            dispatch_metapartitions_from_factory(
                dataset_factory=dataset_factory, predicates=predicates
            )
        )
        for dataset_factory in dataset_factories
    ]

    if match_how == "first":
        if len(set(len(x) for x in mps)) != 1:
            raise RuntimeError("All datasets must have the same number of partitions")
        for mp_0 in mps[0]:
            for other_mps in zip(*mps[1:]):
                yield [mp_0] + list(other_mps)
    elif match_how == "prefix_first":
        # TODO: write a test which protects against the following scenario!!
        # Sort the partition labels by length of the labels, starting with the
        # labels which are the longest. This way we prevent label matching for
        # similar partitions, e.g. cluster_100 and cluster_1. This, of course,
        # works only as long as the internal loop removes elements which were
        # matched already (here improperly called stack)
        for mp_0 in mps[0]:
            res = [mp_0]
            label_0 = mp_0.label
            for dataset_i in range(1, len(mps)):
                for j, mp_i in enumerate(mps[dataset_i]):
                    if mp_i.label.startswith(label_0):
                        res.append(mp_i)
                        del mps[dataset_i][j]
                        break
                else:
                    raise RuntimeError(
                        f"Did not find a matching partition in dataset {dataset_uuids[dataset_i]} for partition {label_0}"
                    )
            yield res
    elif match_how == "exact":
        raise NotImplementedError("exact")
    elif match_how == "dispatch_by":
        index_dfs = []
        for i, factory in enumerate(dataset_factories):
            df = factory.get_indices_as_dataframe(dispatch_by, predicates=predicates)
            index_dfs.append(
                df.reset_index().rename(
                    columns={"partition": f"partition_{i}"}, copy=False
                )
            )
        index_df = reduce(partial(pd.merge, on=dispatch_by), index_dfs)

        mps_by_label: List[Dict[str, MetaPartition]] = []
        for mpx in mps:
            mps_by_label.append({})
            for mp in mpx:
                mps_by_label[-1][mp.label] = mp

        for _, group in index_df.groupby(dispatch_by):
            res_nested: List[List[MetaPartition]] = []
            for i in range(len(dataset_uuids)):
                res_nested.append(
                    [
                        mps_by_label[i][label]
                        for label in group[f"partition_{i}"].unique()
                    ]
                )
            yield res_nested
    else:
        raise NotImplementedError(f"matching with '{match_how}' is not supported")

Review comment: Why is the whole thing label-based and not index-based?
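To make the label-based grouping in the dispatch_by branch above concrete, here is a self-contained pandas sketch with made-up partition labels and an illustrative dispatch column:

from functools import partial, reduce

import pandas as pd

# one index frame per dataset: which partition label covers which dispatch value
index_dfs = [
    pd.DataFrame({"country": ["DE", "US"], "partition_0": ["ds0/p1", "ds0/p2"]}),
    pd.DataFrame({"country": ["DE", "US"], "partition_1": ["ds1/pA", "ds1/pB"]}),
]

# inner-join on the dispatch columns, then group: each group lists, per dataset,
# the partition labels that must be merged together for that dispatch value
index_df = reduce(partial(pd.merge, on=["country"]), index_dfs)
for key, group in index_df.groupby(["country"]):
    labels_per_dataset = [group[f"partition_{i}"].unique() for i in range(2)]
    print(key, labels_per_dataset)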