Basic sample sort #4

Open · wants to merge 1 commit into master

15 changes: 14 additions & 1 deletion charmpandas/dataframe.py
@@ -209,4 +209,17 @@ def merge(self, other, on=None, left_on=None, right_on=None, how='inner'):
        return result

    def groupby(self, by):
        return DataFrameGroupBy(self, by)

    def sort_values(self, by, ascending=True):
        interface = get_interface()
        result = DataFrame(None)

        if isinstance(by, list) and len(by) > 1:
            print("Sorting by multiple columns not supported yet. Sorting by first column only.")

        if isinstance(by, str):  # if `by` is a single field, convert to list
            by = [by]

        interface.sort_table(self.name, by, ascending, result.name)
        return result
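
For context, a minimal usage sketch of the new API (the column name is hypothetical; `df` stands for any existing charmpandas DataFrame):

    sorted_df = df.sort_values('price')                      # single key, ascending
    sorted_df = df.sort_values(['price'], ascending=False)   # list form, descending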
21 changes: 21 additions & 0 deletions charmpandas/interface.py
@@ -65,6 +65,7 @@ class Operations(object):
    fetch_size = 10
    barrier = 11
    reduction = 12
    sort_values = 13


class GroupByOperations(object):
@@ -316,6 +317,26 @@ def concat_tables(self, tables, res):
        cmd += gcmd
        self.send_command_async(Handlers.async_handler, cmd)

    def sort_table(self, table_name, by, ascending, result_name):
        self.activity_handler()
        cmd = self.get_header(self.group_epoch)

        gcmd = self.get_deletion_header()
        gcmd += to_bytes(Operations.sort_values, 'i')
        gcmd += to_bytes(table_name, 'i')
        gcmd += to_bytes(result_name, 'i')

        gcmd += to_bytes(len(by), 'i')
        for b in by:
            gcmd += string_bytes(b)

        gcmd += to_bytes(ascending, 'B')

        cmd += to_bytes(len(gcmd), 'i')
        cmd += gcmd
        self.send_command_async(Handlers.async_group_handler, cmd)
        self.group_epoch += 1

    def reduction(self, name, field, op):
        self.activity_handler()
        cmd = self.get_header(self.epoch)
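
For reference, the byte layout that sort_table assembles (field order as written above; 'i' is a 32-bit int and 'B' an unsigned byte, per the existing to_bytes convention):

    [header][len(gcmd): i]
      [deletion header]
      [Operations.sort_values = 13: i]
      [table_name: i][result_name: i]
      [len(by): i][string_bytes(key) for each key in by]
      [ascending: B]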
5 changes: 5 additions & 0 deletions src/messaging.ci
@@ -11,6 +11,11 @@ module messaging
char data[];
};

message SortTableMsg
{
char data[];
};

message RemoteTableMsg
{
char data[];
12 changes: 11 additions & 1 deletion src/messaging.hpp
@@ -110,14 +110,24 @@ class RemoteTableMsg : public BaseTableDataMsg, public CMessage_RemoteTableMsg
class GatherTableDataMsg : public BaseTableDataMsg, public CMessage_GatherTableDataMsg
{
public:
int chareIdx;
int num_partitions;

GatherTableDataMsg(int epoch_, int size_, int chareIdx_, int num_partitions_)
: BaseTableDataMsg(epoch_, size_)
, chareIdx(chareIdx_)
, num_partitions(num_partitions_)
{}
};

class SortTableMsg : public BaseTableDataMsg, public CMessage_SortTableMsg
{
public:
SortTableMsg(int epoch_, int size_)
: BaseTableDataMsg(epoch_, size_)
{}
};

class RemoteJoinMsg : public BaseTableDataMsg, public CMessage_RemoteJoinMsg
{
public:
6 changes: 6 additions & 0 deletions src/partition.ci
@@ -30,6 +30,12 @@ module partition
}
};

entry void collect_samples(int num_samples, int64_t samples[num_samples]);

entry void receive_splitters(int num_splitters, int64_t splitters[num_splitters]);

entry void receive_sort_tables(SortTableMsg* msg);
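
// Note: Charm++ marshals the array parameters in the entries above
// (num_samples / num_splitters elements are copied into the message), so
// the matching C++ definitions receive plain pointers.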

entry [reductiontarget] void assign_keys(int num_elements, int global_hist[num_elements]);

entry void shuffle_data(std::vector<int> pe_map, std::vector<int> expected_loads);
181 changes: 166 additions & 15 deletions src/partition.cpp
@@ -345,13 +345,13 @@ void Partition::operation_fetch(char* cmd)
auto table = clean_metadata(tables[table_name]);
BufferPtr out;
serialize(table, out);
msg = new (out->size()) GatherTableDataMsg(EPOCH, out->size(), thisIndex, num_partitions);
std::memcpy(msg->data, out->data(), out->size());
}
else
{
//CkPrintf("Table not found on chare %i\n", thisIndex);
msg = new (0) GatherTableDataMsg(EPOCH, 0, thisIndex, num_partitions);
}
agg_proxy[0].gather_table(msg);
complete_operation();
@@ -632,7 +632,7 @@ void Partition::execute_command(int epoch, int size, char* cmd)
operation_reduction(cmd);
break;
}

default:
break;
}
@@ -822,6 +822,8 @@ Aggregator::Aggregator(CProxy_Main main_proxy_)
, expected_rows(0)
, EPOCH(0)
, next_temp_name(0)
, agg_samples_collected(0)
, sort_tables_collected(0)
{
if (MEM_LOGGING)
{
@@ -895,10 +897,16 @@ void Aggregator::gather_table(GatherTableDataMsg* msg)
//TablePtr table = deserialize(data, size);
auto it = gather_count.find(msg->epoch);
if (it == gather_count.end())
{
gather_count[msg->epoch] = 1;
gather_buffer[msg->epoch].resize(msg->num_partitions);
}
else
{
gather_count[msg->epoch]++;
}

gather_buffer[msg->epoch][msg->chareIdx] = msg;

if (gather_count[msg->epoch] == msg->num_partitions)
{
@@ -1062,6 +1070,137 @@ void Aggregator::operation_join(char* cmd)
start_join();
}

void Aggregator::operation_sort_values(char* cmd)
{
int table_name = extract<int>(cmd);
int result_name = extract<int>(cmd);
int nkeys = extract<int>(cmd);
std::vector<std::string> keys;

for (int i = 0; i < nkeys; i++)
{
int key_size = extract<int>(cmd);
keys.push_back(std::string(cmd, key_size));
cmd += key_size;
}

bool ascending = extract<bool>(cmd);

std::vector<arrow::compute::SortKey> sort_keys;
for (const std::string& key : keys)
{
sort_keys.push_back(
arrow::compute::SortKey(
arrow::FieldRef(key),
ascending ? arrow::compute::SortOrder::Ascending : arrow::compute::SortOrder::Descending
)
);
}

sort_values_opts = new SortValuesOptions(table_name, result_name, sort_keys);

// Sort local data.
tables[table_name] = get_local_table(table_name);
auto indices_result = arrow::compute::SortIndices(arrow::Datum(tables[table_name]), arrow::compute::SortOptions(sort_keys)).ValueOrDie();
auto sorted_table = arrow::compute::Take(tables[table_name], indices_result).ValueOrDie().table();

// Get samples from sorted table.
auto column = sorted_table->GetColumnByName(keys[0]);
int num_samples = CkNumPes() - 1;
std::vector<int64_t> samples(num_samples);
for (int i = 0; i < num_samples; i++)
{
int64_t index = (i + 1) * column->length() / CkNumPes();
int64_t sample = std::dynamic_pointer_cast<arrow::Int64Scalar>(column->GetScalar(index).ValueOrDie())->value;
samples[i] = sample;
}

thisProxy[0].collect_samples(num_samples, samples.data());
}
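
// Worked example of the sampling above (hypothetical sizes): with
// CkNumPes() == 4 and a sorted local column of length 100, each PE picks
// num_samples == 3 regular samples, at indices 25, 50 and 75.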

void Aggregator::collect_samples(int num_samples, int64_t* samples)
{
CkAssert(CkMyPe() == 0);

for (int i = 0; i < num_samples; i++)
{
all_samples.push_back(samples[i]);
}

if (++agg_samples_collected == CkNumPes())
{
// Sort samples to find splitters.
std::sort(all_samples.begin(), all_samples.end());

std::vector<int64_t> splitters(CkNumPes() - 1);
for (int i = 0; i < CkNumPes() - 1; i++)
{
splitters[i] = all_samples[(i + 1) * all_samples.size() / CkNumPes()];
}

// Send splitters to all aggregators.
thisProxy.receive_splitters(splitters.size(), splitters.data());
}
}
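
// Continuing the example: PE 0 gathers 4 * 3 == 12 samples in total and,
// after sorting them, picks splitters at indices 3, 6 and 9.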

void Aggregator::receive_splitters(int num_splitters, int64_t* splitters)
{
CkAssert(num_splitters == CkNumPes() - 1);

auto table = tables[sort_values_opts->table_name];

// For each PE, create a filter for the table based on the splitters.
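// Bucket i covers [splitters[i-1], splitters[i]); the first bucket is
// unbounded below and the last unbounded above, so together the buckets
// partition the key range without overlap.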
for (int i = 0; i < CkNumPes(); i++) {
arrow::compute::Expression mask;
arrow::compute::Expression column_ref = arrow::compute::field_ref(sort_values_opts->sort_keys[0].target);
if (i == 0)
mask = arrow::compute::less(column_ref, arrow::compute::literal(splitters[0]));
else if (i == CkNumPes() - 1)
mask = arrow::compute::greater_equal(column_ref, arrow::compute::literal(splitters[i - 1]));
else
mask = arrow::compute::and_(
arrow::compute::greater_equal(column_ref, arrow::compute::literal(splitters[i - 1])),
arrow::compute::less(column_ref, arrow::compute::literal(splitters[i]))
);

arrow::acero::Declaration source{"table_source", arrow::acero::TableSourceNodeOptions{table}};
arrow::acero::Declaration filter{"filter", {source}, arrow::acero::FilterNodeOptions{mask}};
auto filtered_table = arrow::acero::DeclarationToTable(std::move(filter)).ValueOrDie();

BufferPtr out;
serialize(filtered_table, out);
SortTableMsg* msg = new (out->size()) SortTableMsg(EPOCH, out->size());
std::memcpy(msg->data, out->data(), out->size());

// If sorting descending order, send to the last PE first.
int receiver_idx = (sort_values_opts->sort_keys[0].order == arrow::compute::SortOrder::Descending)
? (CkNumPes() - 1 - i)
: i;

thisProxy[receiver_idx].receive_sort_tables(msg);
}
}

void Aggregator::receive_sort_tables(SortTableMsg* msg)
{
// TODO: the received table data is stored in the message buffer and never freed.
auto received_table = deserialize(msg->data, msg->size);

auto it = tables.find(TEMP_TABLE_OFFSET + next_temp_name);
if (it == std::end(tables))
tables[TEMP_TABLE_OFFSET + next_temp_name] = received_table;
else
it->second = arrow::ConcatenateTables({it->second, received_table}).ValueOrDie();
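
// Every PE sends one (possibly empty) bucket to every PE, so this PE is
// done once CkNumPes() messages have arrived; all received rows lie in
// this PE's splitter range, so a final local sort yields a globally
// sorted distribution across PEs.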

if (++sort_tables_collected == CkNumPes())
{
auto indices_result = arrow::compute::SortIndices(arrow::Datum(tables[TEMP_TABLE_OFFSET + next_temp_name]), arrow::compute::SortOptions(sort_values_opts->sort_keys)).ValueOrDie();
tables[TEMP_TABLE_OFFSET + next_temp_name] = arrow::compute::Take(tables[TEMP_TABLE_OFFSET + next_temp_name], indices_result).ValueOrDie().table();
partition_table(tables[TEMP_TABLE_OFFSET + next_temp_name], sort_values_opts->result_name);
complete_sort_values();
}
}

void Aggregator::barrier_handler(int epoch)
{
CcsSendDelayedReply(fetch_reply[epoch], 0, NULL);
@@ -1099,6 +1238,12 @@ void Aggregator::execute_command(int epoch, int size, char* cmd)
break;
}

case Operation::SortValues:
{
operation_sort_values(cmd);
break;
}

default:
break;
}
@@ -1449,6 +1594,12 @@ void Aggregator::receive_shuffle_data(RedistTableMsg* msg)

void Aggregator::complete_operation()
{
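// Notify the partition chares local to this PE that the collective
// operation has finished.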
for (int i = 0; i < num_local_chares; i++)
{
int index = local_chares[i];
partition_proxy[index].ckLocal()->complete_operation();
}

redist_table_names.clear();
tables.clear();
redist_tables.clear();
@@ -1467,12 +1618,6 @@ void Aggregator::complete_groupby()
//EPOCH++;
tables.erase(TEMP_TABLE_OFFSET + next_temp_name++);

complete_operation();
}

@@ -1486,11 +1631,17 @@ void Aggregator::complete_join()
join_opts = nullptr;
//EPOCH++;

complete_operation();
}

void Aggregator::complete_sort_values()
{
tables.erase(sort_values_opts->table_name);
tables.erase(TEMP_TABLE_OFFSET + next_temp_name++);
delete sort_values_opts;
sort_values_opts = nullptr;
agg_samples_collected = 0;
sort_tables_collected = 0;

complete_operation();
}