From 921f549354601dc47a25490c9a73c7045525e46f Mon Sep 17 00:00:00 2001
From: Andy Min
Date: Fri, 9 May 2025 10:47:07 -0500
Subject: [PATCH] Basic sample sort

---
 charmpandas/dataframe.py |  15 +++-
 charmpandas/interface.py |  21 +++++
 src/messaging.ci         |   5 ++
 src/messaging.hpp        |  12 ++-
 src/partition.ci         |   6 ++
 src/partition.cpp        | 181 +++++++++++++++++++++++++++++++++++----
 src/partition.hpp        |  30 +++++++
 src/utils.hpp            |   3 +-
 8 files changed, 255 insertions(+), 18 deletions(-)

diff --git a/charmpandas/dataframe.py b/charmpandas/dataframe.py
index 0d680c3..cf9652f 100644
--- a/charmpandas/dataframe.py
+++ b/charmpandas/dataframe.py
@@ -209,4 +209,17 @@ def merge(self, other, on=None, left_on=None, right_on=None, how='inner'):
         return result
 
     def groupby(self, by):
-        return DataFrameGroupBy(self, by)
\ No newline at end of file
+        return DataFrameGroupBy(self, by)
+
+    def sort_values(self, by, ascending=True):
+        interface = get_interface()
+        result = DataFrame(None)
+
+        if isinstance(by, list) and len(by) > 1:
+            print("Sorting by multiple columns is not supported yet; sorting by the first column only.")
+
+        if isinstance(by, str):  # if `by` is a single field, convert it to a list
+            by = [by]
+
+        interface.sort_table(self.name, by, ascending, result.name)
+        return result
\ No newline at end of file
diff --git a/charmpandas/interface.py b/charmpandas/interface.py
index 3a66cfb..594ba43 100644
--- a/charmpandas/interface.py
+++ b/charmpandas/interface.py
@@ -65,6 +65,7 @@ class Operations(object):
     fetch_size = 10
     barrier = 11
     reduction = 12
+    sort_values = 13
 
 
 class GroupByOperations(object):
@@ -316,6 +317,26 @@ def concat_tables(self, tables, res):
         cmd += gcmd
         self.send_command_async(Handlers.async_handler, cmd)
 
+    def sort_table(self, table_name, by, ascending, result_name):
+        self.activity_handler()
+        cmd = self.get_header(self.group_epoch)
+
+        gcmd = self.get_deletion_header()
+        gcmd += to_bytes(Operations.sort_values, 'i')
+        gcmd += to_bytes(table_name, 'i')
+        gcmd += to_bytes(result_name, 'i')
+
+        gcmd += to_bytes(len(by), 'i')
+        for b in by:
+            gcmd += string_bytes(b)
+
+        gcmd += to_bytes(ascending, 'B')
+
+        cmd += to_bytes(len(gcmd), 'i')
+        cmd += gcmd
+        self.send_command_async(Handlers.async_group_handler, cmd)
+        self.group_epoch += 1
+
     def reduction(self, name, field, op):
         self.activity_handler()
         cmd = self.get_header(self.epoch)
diff --git a/src/messaging.ci b/src/messaging.ci
index 8c75edb..1832b3b 100644
--- a/src/messaging.ci
+++ b/src/messaging.ci
@@ -11,6 +11,11 @@ module messaging
         char data[];
     };
 
+    message SortTableMsg
+    {
+        char data[];
+    };
+
     message RemoteTableMsg
     {
         char data[];
diff --git a/src/messaging.hpp b/src/messaging.hpp
index 3d68ba4..4a2a9ab 100644
--- a/src/messaging.hpp
+++ b/src/messaging.hpp
@@ -110,14 +110,24 @@ class RemoteTableMsg : public BaseTableDataMsg, public CMessage_RemoteTableMsg
 class GatherTableDataMsg : public BaseTableDataMsg, public CMessage_GatherTableDataMsg
 {
 public:
+    int chareIdx;
     int num_partitions;
 
-    GatherTableDataMsg(int epoch_, int size_, int num_partitions_)
+    GatherTableDataMsg(int epoch_, int size_, int chareIdx_, int num_partitions_)
         : BaseTableDataMsg(epoch_, size_)
+        , chareIdx(chareIdx_)
         , num_partitions(num_partitions_)
     {}
 };
 
+class SortTableMsg : public BaseTableDataMsg, public CMessage_SortTableMsg
+{
+public:
+    SortTableMsg(int epoch_, int size_)
+        : BaseTableDataMsg(epoch_, size_)
+    {}
+};
+
 class RemoteJoinMsg : public BaseTableDataMsg, public CMessage_RemoteJoinMsg
 {
 public:
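Note: SortTableMsg, like GatherTableDataMsg above, is a Charm++ varsize message: the trailing `char data[]` declared in messaging.ci is sized by the argument passed to placement new, and it carries a serialized Arrow table. A minimal sketch of that allocate-and-fill idiom, assuming the project's `serialize` helper and `BufferPtr` alias used in partition.cpp (the proxy and destination names here are illustrative only):

```cpp
// Sketch of the varsize-message idiom this patch uses for table payloads.
BufferPtr out;              // assumed: buffer type filled by serialize()
serialize(table, out);      // project helper: Arrow table -> bytes
// The placement-new argument sizes the trailing data[] array of the message.
SortTableMsg* msg = new (out->size()) SortTableMsg(EPOCH, out->size());
std::memcpy(msg->data, out->data(), out->size());
agg_proxy[dest].receive_sort_tables(msg);  // illustrative send; the runtime takes ownership
```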
diff --git a/src/partition.ci b/src/partition.ci
index 8b3a77a..3ac525d 100644
--- a/src/partition.ci
+++ b/src/partition.ci
@@ -30,6 +30,12 @@ module partition
         }
     };
 
+    entry void collect_samples(int num_samples, int64_t samples[num_samples]);
+
+    entry void receive_splitters(int num_splitters, int64_t splitters[num_splitters]);
+
+    entry void receive_sort_tables(SortTableMsg* msg);
+
     entry [reductiontarget] void assign_keys(int num_elements, int global_hist[num_elements]);
 
     entry void shuffle_data(std::vector pe_map, std::vector expected_loads);
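These three entries implement the classic sample-sort phases: after sorting locally, every PE sends p - 1 regular samples to PE 0 (`collect_samples`); PE 0 pools and sorts the p * (p - 1) samples and broadcasts p - 1 splitters (`receive_splitters`); each PE then routes splitter-delimited buckets to their destination PEs (`receive_sort_tables`). The index arithmetic is easiest to see on plain vectors; a self-contained sketch of the sampling and splitter selection used in partition.cpp below:

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Regular sampling: a locally sorted block of n rows contributes p - 1
// samples taken at positions (i + 1) * n / p, matching operation_sort_values.
std::vector<int64_t> take_samples(const std::vector<int64_t>& sorted_local, int p)
{
    std::vector<int64_t> samples(p - 1);
    for (int i = 0; i < p - 1; i++)
        samples[i] = sorted_local[(i + 1) * sorted_local.size() / p];
    return samples;
}

// Splitter selection on PE 0: sort the pooled p * (p - 1) samples and take
// every p-th quantile, matching collect_samples.
std::vector<int64_t> pick_splitters(std::vector<int64_t> all_samples, int p)
{
    std::sort(all_samples.begin(), all_samples.end());
    std::vector<int64_t> splitters(p - 1);
    for (int i = 0; i < p - 1; i++)
        splitters[i] = all_samples[(i + 1) * all_samples.size() / p];
    return splitters;
}
```

How evenly the buckets balance depends entirely on these samples; with regular sampling of sorted blocks, no destination receives more than roughly 2n/p rows in the worst case.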
diff --git a/src/partition.cpp b/src/partition.cpp
index 784bcd9..9f275e6 100644
--- a/src/partition.cpp
+++ b/src/partition.cpp
@@ -345,13 +345,13 @@ void Partition::operation_fetch(char* cmd)
         auto table = clean_metadata(tables[table_name]);
         BufferPtr out;
         serialize(table, out);
-        msg = new (out->size()) GatherTableDataMsg(EPOCH, out->size(), num_partitions);
+        msg = new (out->size()) GatherTableDataMsg(EPOCH, out->size(), thisIndex, num_partitions);
         std::memcpy(msg->data, out->data(), out->size());
     }
     else
     {
         //CkPrintf("Table not found on chare %i\n", thisIndex);
-        msg = new (0) GatherTableDataMsg(EPOCH, 0, num_partitions);
+        msg = new (0) GatherTableDataMsg(EPOCH, 0, thisIndex, num_partitions);
     }
     agg_proxy[0].gather_table(msg);
     complete_operation();
@@ -632,7 +632,7 @@ void Partition::execute_command(int epoch, int size, char* cmd)
             operation_reduction(cmd);
             break;
         }
-
+
         default:
             break;
     }
@@ -822,6 +822,8 @@ Aggregator::Aggregator(CProxy_Main main_proxy_)
     , expected_rows(0)
    , EPOCH(0)
     , next_temp_name(0)
+    , agg_samples_collected(0)
+    , sort_tables_collected(0)
 {
     if (MEM_LOGGING)
     {
@@ -895,10 +897,16 @@ void Aggregator::gather_table(GatherTableDataMsg* msg)
     //TablePtr table = deserialize(data, size);
     auto it = gather_count.find(msg->epoch);
     if (it == gather_count.end())
+    {
         gather_count[msg->epoch] = 1;
+        gather_buffer[msg->epoch].resize(msg->num_partitions);
+    }
     else
+    {
         gather_count[msg->epoch]++;
-    gather_buffer[msg->epoch].push_back(msg);
+    }
+
+    gather_buffer[msg->epoch][msg->chareIdx] = msg;
 
     if (gather_count[msg->epoch] == msg->num_partitions)
     {
@@ -1062,6 +1070,137 @@ void Aggregator::operation_join(char* cmd)
     start_join();
 }
 
+void Aggregator::operation_sort_values(char* cmd)
+{
+    int table_name = extract<int>(cmd);
+    int result_name = extract<int>(cmd);
+    int nkeys = extract<int>(cmd);
+    std::vector<std::string> keys;
+
+    for (int i = 0; i < nkeys; i++)
+    {
+        int key_size = extract<int>(cmd);
+        keys.push_back(std::string(cmd, key_size));
+        cmd += key_size;
+    }
+
+    bool ascending = extract<bool>(cmd);
+
+    std::vector<arrow::compute::SortKey> sort_keys;
+    for (const std::string& key : keys)
+    {
+        sort_keys.push_back(
+            arrow::compute::SortKey(
+                arrow::FieldRef(key),
+                ascending ? arrow::compute::SortOrder::Ascending : arrow::compute::SortOrder::Descending
+            )
+        );
+    }
+
+    sort_values_opts = new SortValuesOptions(table_name, result_name, sort_keys);
+
+    // Sort local data.
+    tables[table_name] = get_local_table(table_name);
+    auto indices_result = arrow::compute::SortIndices(arrow::Datum(tables[table_name]), arrow::compute::SortOptions(sort_keys)).ValueOrDie();
+    auto sorted_table = arrow::compute::Take(tables[table_name], indices_result).ValueOrDie().table();
+
+    // Get regular samples from the sorted table.
+    auto column = sorted_table->GetColumnByName(keys[0]);
+    int num_samples = CkNumPes() - 1;
+    std::vector<int64_t> samples(num_samples);
+    for (int i = 0; i < num_samples; i++)
+    {
+        int index = (i + 1) * column->length() / CkNumPes();
+        int64_t sample = std::dynamic_pointer_cast<arrow::Int64Scalar>(column->GetScalar(index).ValueOrDie())->value;
+        samples[i] = sample;
+    }
+
+    thisProxy[0].collect_samples(num_samples, samples.data());
+}
+
+void Aggregator::collect_samples(int num_samples, int64_t samples[num_samples])
+{
+    CkAssert(CkMyPe() == 0);
+
+    for (int i = 0; i < num_samples; i++)
+    {
+        all_samples.push_back(samples[i]);
+    }
+
+    if (++agg_samples_collected == CkNumPes())
+    {
+        // Sort the pooled samples to find the splitters.
+        std::sort(all_samples.begin(), all_samples.end());
+
+        std::vector<int64_t> splitters(CkNumPes() - 1);
+        for (int i = 0; i < CkNumPes() - 1; i++)
+        {
+            splitters[i] = all_samples[(i + 1) * all_samples.size() / CkNumPes()];
+        }
+
+        // Send splitters to all aggregators.
+        thisProxy.receive_splitters(splitters.size(), splitters.data());
+    }
+}
+
+void Aggregator::receive_splitters(int num_splitters, int64_t splitters[num_splitters])
+{
+    CkAssert(num_splitters == CkNumPes() - 1);
+
+    auto table = tables[sort_values_opts->table_name];
+
+    // For each PE, create a filter for the table based on the splitters.
+    for (int i = 0; i < CkNumPes(); i++)
+    {
+        arrow::compute::Expression mask;
+        arrow::compute::Expression column_ref = arrow::compute::field_ref(sort_values_opts->sort_keys[0].target);
+        if (i == 0)
+            mask = arrow::compute::less(column_ref, arrow::compute::literal(splitters[0]));
+        else if (i == CkNumPes() - 1)
+            mask = arrow::compute::greater_equal(column_ref, arrow::compute::literal(splitters[i - 1]));
+        else
+            mask = arrow::compute::and_(
+                arrow::compute::greater_equal(column_ref, arrow::compute::literal(splitters[i - 1])),
+                arrow::compute::less(column_ref, arrow::compute::literal(splitters[i]))
+            );
+        arrow::acero::Declaration source{"table_source", arrow::acero::TableSourceNodeOptions{table}};
+        arrow::acero::Declaration filter{"filter", {source}, arrow::acero::FilterNodeOptions{mask}};
+        auto filtered_table = arrow::acero::DeclarationToTable(std::move(filter)).ValueOrDie();
+
+        BufferPtr out;
+        serialize(filtered_table, out);
+        SortTableMsg* msg = new (out->size()) SortTableMsg(EPOCH, out->size());
+        std::memcpy(msg->data, out->data(), out->size());
+
+        // If sorting in descending order, reverse the destination PE order.
+        int receiver_idx = (sort_values_opts->sort_keys[0].order == arrow::compute::SortOrder::Descending)
+            ? (CkNumPes() - 1 - i)
+            : i;
+
+        thisProxy[receiver_idx].receive_sort_tables(msg);
+    }
+}
+
+void Aggregator::receive_sort_tables(SortTableMsg* msg)
+{
+    // TODO: the received table data is stored in the message buffer and never freed.
+    auto received_table = deserialize(msg->data, msg->size);
+
+    auto it = tables.find(TEMP_TABLE_OFFSET + next_temp_name);
+    if (it == std::end(tables))
+        tables[TEMP_TABLE_OFFSET + next_temp_name] = received_table;
+    else
+        it->second = arrow::ConcatenateTables({it->second, received_table}).ValueOrDie();
+
+    if (++sort_tables_collected == CkNumPes())
+    {
+        auto indices_result = arrow::compute::SortIndices(arrow::Datum(tables[TEMP_TABLE_OFFSET + next_temp_name]), arrow::compute::SortOptions(sort_values_opts->sort_keys)).ValueOrDie();
+        tables[TEMP_TABLE_OFFSET + next_temp_name] = arrow::compute::Take(tables[TEMP_TABLE_OFFSET + next_temp_name], indices_result).ValueOrDie().table();
+        partition_table(tables[TEMP_TABLE_OFFSET + next_temp_name], sort_values_opts->result_name);
+        complete_sort_values();
+    }
+}
+
 void Aggregator::barrier_handler(int epoch)
 {
     CcsSendDelayedReply(fetch_reply[epoch], 0, NULL);
@@ -1099,6 +1238,12 @@ void Aggregator::execute_command(int epoch, int size, char* cmd)
             break;
         }
 
+        case Operation::SortValues:
+        {
+            operation_sort_values(cmd);
+            break;
+        }
+
         default:
            break;
     }
@@ -1449,6 +1594,12 @@ void Aggregator::receive_shuffle_data(RedistTableMsg* msg)
 
 void Aggregator::complete_operation()
 {
+    for (int i = 0; i < num_local_chares; i++)
+    {
+        int index = local_chares[i];
+        partition_proxy[index].ckLocal()->complete_operation();
+    }
+
     redist_table_names.clear();
     tables.clear();
     redist_tables.clear();
@@ -1467,12 +1618,6 @@ void Aggregator::complete_groupby()
     //EPOCH++;
     tables.erase(TEMP_TABLE_OFFSET + next_temp_name++);
 
-    for (int i = 0; i < num_local_chares; i++)
-    {
-        int index = local_chares[i];
-        partition_proxy[index].ckLocal()->complete_operation();
-    }
-
     complete_operation();
 }
 
@@ -1486,11 +1631,17 @@ void Aggregator::complete_join()
     join_opts = nullptr;
     //EPOCH++;
 
-    for (int i = 0; i < num_local_chares; i++)
-    {
-        int index = local_chares[i];
-        partition_proxy[index].ckLocal()->complete_operation();
-    }
+    complete_operation();
+}
+
+void Aggregator::complete_sort_values()
+{
+    tables.erase(sort_values_opts->table_name);
+    tables.erase(TEMP_TABLE_OFFSET + next_temp_name++);
+    delete sort_values_opts;
+    sort_values_opts = nullptr;
+    agg_samples_collected = 0;
+    sort_tables_collected = 0;
 
     complete_operation();
 }
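A note on the bucket boundaries in `receive_splitters`: the masks follow a half-open convention (bucket 0 takes key < splitters[0]; bucket i takes splitters[i-1] <= key < splitters[i]; the last bucket takes the rest), so rows with equal keys always land in the same bucket and concatenating the per-bucket sorts yields a globally sorted table. A self-contained restatement of that logic, assuming only Arrow's public compute-expression API (the header path may differ across Arrow versions):

```cpp
#include <arrow/compute/expression.h>

#include <cstdint>
#include <string>
#include <vector>

namespace cp = arrow::compute;

// Half-open range mask for bucket i of p = splitters.size() + 1 buckets:
//   bucket 0:      key <  splitters[0]
//   bucket i:      splitters[i-1] <= key < splitters[i]
//   bucket p - 1:  splitters[p-2] <= key
cp::Expression bucket_mask(const std::string& key,
                           const std::vector<int64_t>& splitters, int i)
{
    const int p = static_cast<int>(splitters.size()) + 1;
    auto col = cp::field_ref(key);
    if (i == 0)
        return cp::less(col, cp::literal(splitters[0]));
    if (i == p - 1)
        return cp::greater_equal(col, cp::literal(splitters[i - 1]));
    return cp::and_(cp::greater_equal(col, cp::literal(splitters[i - 1])),
                    cp::less(col, cp::literal(splitters[i])));
}
```

Descending sorts reuse the same ascending-ordered splitters and simply reverse which PE receives which bucket, as the `receiver_idx` computation shows.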
diff --git a/src/partition.hpp b/src/partition.hpp
index 4da7074..5ca52d5 100644
--- a/src/partition.hpp
+++ b/src/partition.hpp
@@ -57,6 +57,20 @@ class GroupByOptions
     {}
 };
 
+class SortValuesOptions
+{
+public:
+    int table_name;
+    int result_name;
+    std::vector<arrow::compute::SortKey> sort_keys;
+
+    SortValuesOptions(int table_name_, int result_name_, std::vector<arrow::compute::SortKey> sort_keys_)
+        : table_name(table_name_)
+        , result_name(result_name_)
+        , sort_keys(sort_keys_)
+    {}
+};
+
 class PELoad
 {
 public:
@@ -119,6 +133,12 @@ class Aggregator : public CBase_Aggregator
     JoinOptions* join_opts;
     GroupByOptions* groupby_opts;
 
+    // for sorting
+    int agg_samples_collected;
+    int sort_tables_collected;
+    std::vector<int64_t> all_samples;
+    SortValuesOptions* sort_values_opts;
+
     int EPOCH;
 
 public:
@@ -168,6 +188,14 @@ class Aggregator : public CBase_Aggregator
     void operation_groupby(char* cmd);
 
+    void operation_sort_values(char* cmd);
+
+    void collect_samples(int num_samples, int64_t samples[num_samples]);
+
+    void receive_splitters(int num_splitters, int64_t splitters[num_splitters]);
+
+    void receive_sort_tables(SortTableMsg* msg);
+
     //void operation_barrier(char* cmd);
 
     void execute_command(int epoch, int size, char* cmd);
@@ -190,6 +218,8 @@ class Aggregator : public CBase_Aggregator
 
     void complete_join();
 
+    void complete_sort_values();
+
     TablePtr local_join(TablePtr &t1, TablePtr &t2, arrow::acero::HashJoinNodeOptions &opts);
 
     void partition_table(TablePtr table, int result_name);
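The `chareIdx` added to GatherTableDataMsg complements this state: when a sorted result is later fetched, `gather_table` slots each partition into `gather_buffer` by sender index instead of appending in arrival order, so message races cannot scramble the globally sorted output. A sketch of the reassembly this enables, assuming the buffer layout set up in gather_table and the project's `deserialize`/`TablePtr` helpers (the actual assembly code lives elsewhere in partition.cpp):

```cpp
// Sketch: concatenate gathered partitions in chare order, not arrival order.
std::vector<TablePtr> parts;
for (GatherTableDataMsg* m : gather_buffer[epoch])  // resized to num_partitions
{
    if (m != nullptr && m->size > 0)  // chares without the table send size 0
        parts.push_back(deserialize(m->data, m->size));
}
TablePtr assembled = arrow::ConcatenateTables(parts).ValueOrDie();
```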
diff --git a/src/utils.hpp b/src/utils.hpp
index 8bc9a0e..416c78c 100644
--- a/src/utils.hpp
+++ b/src/utils.hpp
@@ -29,7 +29,8 @@ enum class Operation : int
     Skip = 9,
     FetchSize = 10,
     Barrier = 11,
-    Reduction = 12
+    Reduction = 12,
+    SortValues = 13
 };
 
 template
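Taken together, the pipeline can be sanity-checked without Charm++ or Arrow. A stand-alone simulation (not part of the patch) of the four phases over plain int64 vectors; `std::upper_bound` assigns buckets exactly as the >=/< masks above do:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main()
{
    const int p = 4;  // simulated PE count
    std::vector<std::vector<int64_t>> data(p);
    for (int pe = 0; pe < p; pe++)  // arbitrary unsorted input
        for (int i = 0; i < 8; i++)
            data[pe].push_back((pe * 37 + i * 13) % 50);

    // Phase 1: sort locally; each PE contributes p - 1 regular samples.
    std::vector<int64_t> samples;
    for (auto& local : data)
    {
        std::sort(local.begin(), local.end());
        for (int i = 0; i < p - 1; i++)
            samples.push_back(local[(i + 1) * local.size() / p]);
    }

    // Phase 2: "PE 0" sorts the pooled samples and picks p - 1 splitters.
    std::sort(samples.begin(), samples.end());
    std::vector<int64_t> splitters(p - 1);
    for (int i = 0; i < p - 1; i++)
        splitters[i] = samples[(i + 1) * samples.size() / p];

    // Phase 3: route every row to its splitter-delimited bucket.
    std::vector<std::vector<int64_t>> buckets(p);
    for (const auto& local : data)
        for (int64_t v : local)
        {
            int b = std::upper_bound(splitters.begin(), splitters.end(), v)
                    - splitters.begin();
            buckets[b].push_back(v);
        }

    // Phase 4: sort each bucket; reading the buckets in PE order is then
    // globally sorted.
    for (auto& b : buckets)
        std::sort(b.begin(), b.end());
    for (const auto& b : buckets)
        for (int64_t v : b)
            std::cout << v << ' ';
    std::cout << '\n';
    return 0;
}
```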