diff --git a/pkg/sql/colexec/group/exec.go b/pkg/sql/colexec/group/exec.go index 700d5aa1ccadd..8c0539a56d334 100644 --- a/pkg/sql/colexec/group/exec.go +++ b/pkg/sql/colexec/group/exec.go @@ -17,9 +17,11 @@ package group import ( "bytes" "fmt" + "slices" "github.com/matrixorigin/matrixone/pkg/common/hashmap" "github.com/matrixorigin/matrixone/pkg/container/batch" + "github.com/matrixorigin/matrixone/pkg/container/hashtable" "github.com/matrixorigin/matrixone/pkg/container/types" "github.com/matrixorigin/matrixone/pkg/sql/colexec/aggexec" "github.com/matrixorigin/matrixone/pkg/sql/plan/function" @@ -63,6 +65,9 @@ func (group *Group) Prepare(proc *process.Process) (err error) { if err = group.prepareGroup(proc); err != nil { return err } + if err := group.initSpiller(proc); err != nil { + return err + } return group.PrepareProjection(proc) } @@ -181,18 +186,26 @@ func (group *Group) getInputBatch(proc *process.Process) (*batch.Batch, error) { // To avoid a single batch being too large, // we split the result as many part of vector, and send them in order. func (group *Group) callToGetFinalResult(proc *process.Process) (*batch.Batch, error) { - group.ctr.result1.CleanLastPopped(proc.Mp()) + group.ctr.finalResults.CleanLastPopped(proc.Mp()) if group.ctr.state == vm.End { return nil, nil } + // If spilling occurred, recall and merge all spilled data first. + if group.ctr.spilled { + if err := group.recallAndMergeSpilledData(proc); err != nil { + return nil, err + } + group.ctr.spilled = false + } + for { if group.ctr.state == vm.Eval { - if group.ctr.result1.IsEmpty() { + if group.ctr.finalResults.IsEmpty() { group.ctr.state = vm.End return nil, nil } - return group.ctr.result1.PopResult(proc.Mp()) + return group.ctr.finalResults.PopResult(proc.Mp()) } res, err := group.getInputBatch(proc) @@ -206,7 +219,7 @@ func (group *Group) callToGetFinalResult(proc *process.Process) (*batch.Batch, e if err = group.generateInitialResult1WithoutGroupBy(proc); err != nil { return nil, err } - group.ctr.result1.ToPopped[0].SetRowCount(1) + group.ctr.finalResults.ToPopped[0].SetRowCount(1) } continue } @@ -218,6 +231,15 @@ func (group *Group) callToGetFinalResult(proc *process.Process) (*batch.Batch, e if err = group.consumeBatchToGetFinalResult(proc, res); err != nil { return nil, err } + + // check for spill + currentSize := group.getSize() + if currentSize > group.ctr.spillThreshold { + if err := group.spillCurrentState(proc); err != nil { + return nil, err + } + } + } } @@ -227,9 +249,9 @@ func (group *Group) generateInitialResult1WithoutGroupBy(proc *process.Process) return err } - group.ctr.result1.InitOnlyAgg(aggexec.GetMinAggregatorsChunkSize(nil, aggs), aggs) - for i := range group.ctr.result1.AggList { - if err = group.ctr.result1.AggList[i].GroupGrow(1); err != nil { + group.ctr.finalResults.InitOnlyAgg(aggexec.GetMinAggregatorsChunkSize(nil, aggs), aggs) + for i := range group.ctr.finalResults.AggList { + if err = group.ctr.finalResults.AggList[i].GroupGrow(1); err != nil { return err } } @@ -246,23 +268,23 @@ func (group *Group) consumeBatchToGetFinalResult( switch group.ctr.mtyp { case H0: // without group by. 
- if group.ctr.result1.IsEmpty() { + if group.ctr.finalResults.IsEmpty() { if err := group.generateInitialResult1WithoutGroupBy(proc); err != nil { return err } } - group.ctr.result1.ToPopped[0].SetRowCount(1) - for i := range group.ctr.result1.AggList { - if err := group.ctr.result1.AggList[i].BulkFill(0, group.ctr.aggregateEvaluate[i].Vec); err != nil { + group.ctr.finalResults.ToPopped[0].SetRowCount(1) + for i := range group.ctr.finalResults.AggList { + if err := group.ctr.finalResults.AggList[i].BulkFill(0, group.ctr.aggregateEvaluate[i].Vec); err != nil { return err } } default: // with group by. - if group.ctr.result1.IsEmpty() { - err := group.ctr.hr.BuildHashTable(false, group.ctr.mtyp == HStr, group.ctr.keyNullable, group.PreAllocSize) + if group.ctr.finalResults.IsEmpty() { + err := group.ctr.hashMap.BuildHashTable(false, group.ctr.mtyp == HStr, group.ctr.keyNullable, group.PreAllocSize) if err != nil { return err } @@ -271,7 +293,7 @@ func (group *Group) consumeBatchToGetFinalResult( if err != nil { return err } - if err = group.ctr.result1.InitWithGroupBy( + if err = group.ctr.finalResults.InitWithGroupBy( proc.Mp(), aggexec.GetMinAggregatorsChunkSize(group.ctr.groupByEvaluate.Vec, aggs), aggs, group.ctr.groupByEvaluate.Vec, group.PreAllocSize); err != nil { return err @@ -280,21 +302,21 @@ func (group *Group) consumeBatchToGetFinalResult( count := bat.RowCount() more := 0 - aggList := group.ctr.result1.GetAggList() + aggList := group.ctr.finalResults.GetAggList() for i := 0; i < count; i += hashmap.UnitLimit { n := count - i if n > hashmap.UnitLimit { n = hashmap.UnitLimit } - originGroupCount := group.ctr.hr.Hash.GroupCount() - vals, _, err := group.ctr.hr.Itr.Insert(i, n, group.ctr.groupByEvaluate.Vec) + originGroupCount := group.ctr.hashMap.Hash.GroupCount() + vals, _, err := group.ctr.hashMap.Iter.Insert(i, n, group.ctr.groupByEvaluate.Vec) if err != nil { return err } - insertList, _ := group.ctr.hr.GetBinaryInsertList(vals, originGroupCount) + insertList, _ := group.ctr.hashMap.GetBinaryInsertList(vals, originGroupCount) - more, err = group.ctr.result1.AppendBatch(proc.Mp(), group.ctr.groupByEvaluate.Vec, i, insertList) + more, err = group.ctr.finalResults.AppendBatch(proc.Mp(), group.ctr.groupByEvaluate.Vec, i, insertList) if err != nil { return err } @@ -361,7 +383,7 @@ func preExtendAggExecs(execs []aggexec.AggFuncExec, preAllocated uint64) (err er // // this function will be only called when there is one MergeGroup operator was behind. func (group *Group) callToGetIntermediateResult(proc *process.Process) (*batch.Batch, error) { - group.ctr.result2.resetLastPopped() + group.ctr.intermediateResults.resetLastPopped() if group.ctr.state == vm.End { return nil, nil } @@ -371,9 +393,17 @@ func (group *Group) callToGetIntermediateResult(proc *process.Process) (*batch.B return nil, err } + // If spilling occurred, recall and merge all spilled data first. 
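+	// Merging must finish before any result row is produced, because one group's state may be split across several spill files.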
+	if group.ctr.spilled {
+		if err := group.recallAndMergeSpilledData(proc); err != nil {
+			return nil, err
+		}
+		group.ctr.spilled = false // All spilled data merged
+	}
+
 	var input *batch.Batch
 	for {
-		if group.ctr.state == vm.End {
+		if group.ctr.state == vm.End && !group.ctr.recalling {
 			return nil, nil
 		}
@@ -387,8 +417,7 @@ func (group *Group) callToGetIntermediateResult(proc *process.Process) (*batch.B
 			if group.ctr.isDataSourceEmpty() && len(group.Exprs) == 0 {
 				r.SetRowCount(1)
 				return r, nil
-			}
-			if r.RowCount() > 0 {
+			} else if r.RowCount() > 0 {
 				return r, nil
 			}
 			continue
@@ -401,10 +430,10 @@
 		if next, er := group.consumeBatchToRes(proc, input, r); er != nil {
 			return nil, er
 		} else {
-			if next {
-				continue
+			// A spill inside consumeBatchToRes replaces the working result batch,
+			// so refresh r before deciding what to return.
+			if group.ctr.intermediateResults.res != nil && group.ctr.intermediateResults.res != r {
+				r = group.ctr.intermediateResults.res
+			}
+			if !next {
+				// intermediate result is ready
+				return r, nil
 			}
-			return r, nil
 		}
 	}
 }
@@ -412,7 +441,7 @@ func (group *Group) initCtxToGetIntermediateResult(
 	proc *process.Process) (*batch.Batch, error) {
-	r, err := group.ctr.result2.getResultBatch(
+	r, err := group.ctr.intermediateResults.getResultBatch(
 		proc, &group.ctr.groupByEvaluate, group.ctr.aggregateEvaluate, group.Aggs)
 	if err != nil {
 		return nil, err
@@ -426,7 +455,7 @@ func (group *Group) initCtxToGetIntermediateResult(
 	} else {
 		allocated := max(min(group.PreAllocSize, uint64(intermediateResultSendActionTrigger)), 0)
-		if err = group.ctr.hr.BuildHashTable(true, group.ctr.mtyp == HStr, group.ctr.keyNullable, allocated); err != nil {
+		if err = group.ctr.hashMap.BuildHashTable(true, group.ctr.mtyp == HStr, group.ctr.keyNullable, allocated); err != nil {
 			return nil, err
 		}
 		err = preExtendAggExecs(r.Aggs, allocated)
@@ -462,12 +491,12 @@ func (group *Group) consumeBatchToRes(
 			n = hashmap.UnitLimit
 		}
 
-		originGroupCount := group.ctr.hr.Hash.GroupCount()
-		vals, _, err1 := group.ctr.hr.Itr.Insert(i, n, group.ctr.groupByEvaluate.Vec)
+		originGroupCount := group.ctr.hashMap.Hash.GroupCount()
+		vals, _, err1 := group.ctr.hashMap.Iter.Insert(i, n, group.ctr.groupByEvaluate.Vec)
 		if err1 != nil {
 			return false, err1
 		}
-		insertList, more := group.ctr.hr.GetBinaryInsertList(vals, originGroupCount)
+		insertList, more := group.ctr.hashMap.GetBinaryInsertList(vals, originGroupCount)
 
 		cnt := int(more)
 		if cnt > 0 {
@@ -492,6 +521,364 @@ func (group *Group) consumeBatchToRes(
 			}
 		}
 
+		// check for spill
+		currentSize := group.getSize()
+		if currentSize > group.ctr.spillThreshold {
+			if err := group.spillCurrentState(proc); err != nil {
+				return false, err
+			}
+			// The spill freed the working result batch; re-initialize it and
+			// continue with the fresh one so the freed res is never read.
+			newRes, err := group.initCtxToGetIntermediateResult(proc)
+			if err != nil {
+				return false, err
+			}
+			res = newRes
+		}
+
 		return res.RowCount() < intermediateResultSendActionTrigger, nil
 	}
 }
+
+func (group *Group) spillCurrentState(proc *process.Process) (err error) {
+	// hashmap
+	var hashmapData []byte
+	if group.ctr.hashMap.Hash != nil && group.ctr.hashMap.Hash.GroupCount() > 0 {
+		hashmapData, err = group.ctr.hashMap.Hash.MarshalBinary()
+		if err != nil {
+			return err
+		}
+	}
+
+	// aggregation states
+	var aggStatesData [][]byte
+	if group.NeedEval {
+		aggStatesData = make([][]byte, len(group.ctr.finalResults.AggList))
+		for i, agg := range group.ctr.finalResults.AggList {
+			aggStatesData[i], err = aggexec.MarshalAggFuncExec(agg)
+			if err != nil {
+				return err
+			}
+		}
+	} else {
+		aggStatesData = make([][]byte, len(group.ctr.intermediateResults.res.Aggs))
+		for i, agg := range group.ctr.intermediateResults.res.Aggs {
+			aggStatesData[i], err = aggexec.MarshalAggFuncExec(agg)
+			if err != nil {
+				return err
+			}
+		}
+	}
+
+	// group-by batches (for NeedEval=true, multiple batches in ToPopped)
+	var groupByBatchesData [][]byte
+	if group.NeedEval {
+		if len(group.ctr.finalResults.ToPopped) > 0 {
+			for _, b := range group.ctr.finalResults.ToPopped {
+				if b.IsEmpty() {
+					continue // Skip empty batches
+				}
+				batData, marshalErr := b.MarshalBinary()
+				if marshalErr != nil {
+					return marshalErr
+				}
+				groupByBatchesData = append(groupByBatchesData, batData)
+			}
+			if len(groupByBatchesData) == 0 {
+				// All batches were empty, so there is nothing to spill for group-by keys.
+				return nil
+			}
+		}
+
+	} else {
+		// An empty intermediateResults.res simply marshals to an empty payload.
+		groupByBatchData, err := group.ctr.intermediateResults.res.MarshalBinary()
+		if err != nil {
+			return err
+		}
+		groupByBatchesData = [][]byte{groupByBatchData}
+	}
+
+	if err := group.ctr.spiller.spillState(hashmapData, aggStatesData, groupByBatchesData); err != nil {
+		return err
+	}
+	// Record that spilled state now exists on disk so the recall path runs later.
+	group.ctr.spilled = true
+
+	// clear in-memory state
+	group.ctr.hashMap.Free0()
+	if group.NeedEval {
+		group.ctr.finalResults.Free0(proc.Mp())
+		group.ctr.finalResults = GroupResultBuffer{}
+	} else {
+		group.ctr.intermediateResults.Free0(proc.Mp())
+		group.ctr.intermediateResults = GroupResultNoneBlock{}
+	}
+	return nil
+}
+
+func (group *Group) recallAndMergeSpilledData(proc *process.Process) error {
+	group.ctr.recalling = true
+	defer func() {
+		group.ctr.recalling = false
+		group.ctr.spilled = false
+	}()
+
+	if group.ctr.hashMap.Hash == nil {
+		if err := group.ctr.hashMap.BuildHashTable(false, group.ctr.mtyp == HStr, group.ctr.keyNullable, 0); err != nil {
+			return err
+		}
+	}
+
+	var currentAggs []aggexec.AggFuncExec
+	var currentGroupByBatch *batch.Batch
+
+	if group.NeedEval {
+		if group.ctr.finalResults.IsEmpty() {
+			aggs, err := group.generateAggExec(proc)
+			if err != nil {
+				return err
+			}
+			if err = group.ctr.finalResults.InitWithGroupBy(
+				proc.Mp(),
+				aggexec.GetMinAggregatorsChunkSize(group.ctr.groupByEvaluate.Vec, aggs), aggs, group.ctr.groupByEvaluate.Vec, 0); err != nil {
+				return err
+			}
+		}
+		currentAggs = group.ctr.finalResults.AggList
+
+	} else {
+		if group.ctr.intermediateResults.res == nil {
+			r, err := group.ctr.intermediateResults.getResultBatch(
+				proc, &group.ctr.groupByEvaluate, group.ctr.aggregateEvaluate, group.Aggs)
+			if err != nil {
+				return err
+			}
+			currentGroupByBatch = r
+		} else {
+			currentGroupByBatch = group.ctr.intermediateResults.res
+		}
+		currentAggs = currentGroupByBatch.Aggs
+	}
+
+	spillFiles := slices.Clone(group.ctr.spiller.getSpillFiles())
+	for _, filePath := range spillFiles {
+		hashmapData, aggStatesData, groupByBatchData, err := group.ctr.spiller.recallState(filePath)
+		if err != nil {
+			return err
+		}
+
+		var recalledHM hashmap.HashMap
+		if group.ctr.mtyp == HStr {
+			recalledHM, err = hashmap.NewStrMap(group.ctr.keyNullable)
+		} else {
+			recalledHM, err = hashmap.NewIntHashMap(group.ctr.keyNullable)
+		}
+		if err != nil {
+			return err
+		}
+		// Pass the default hashmap allocator for deserialization.
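+		// Note: recalledHM is never probed afterwards; group membership is
+		// rebuilt by re-inserting the recalled keys into the live hashmap
+		// below, and recalledHM is freed once this file has been processed.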
+		if err := recalledHM.UnmarshalBinary(hashmapData, hashtable.DefaultAllocator()); err != nil {
+			recalledHM.Free()
+			return err
+		}
+
+		recalledAggs := make([]aggexec.AggFuncExec, len(aggStatesData))
+		aggMemoryManager := aggexec.NewSimpleAggMemoryManager(proc.Mp())
+		for i, aggData := range aggStatesData {
+			recalledAggs[i], err = aggexec.UnmarshalAggFuncExec(aggMemoryManager, aggData)
+			if err != nil {
+				for j := 0; j < i; j++ {
+					recalledAggs[j].Free()
+				}
+				recalledHM.Free()
+				return err
+			}
+		}
+
+		recalledGroupByBatches := make([]*batch.Batch, len(groupByBatchData))
+		for i, batData := range groupByBatchData {
+			recalledBat := batch.NewOffHeapWithSize(len(group.ctr.groupByEvaluate.Vec))
+			// Attributes are consistent, no need to unmarshal them.
+			if err := recalledBat.UnmarshalBinaryWithAnyMp(batData, proc.Mp()); err != nil {
+				for _, agg := range recalledAggs {
+					agg.Free()
+				}
+				recalledHM.Free()
+				for j := 0; j < i; j++ { // Clean up already unmarshaled batches if error occurs
+					recalledGroupByBatches[j].Clean(proc.Mp())
+				}
+				recalledBat.Clean(proc.Mp())
+				return err
+			}
+			recalledGroupByBatches[i] = recalledBat
+		}
+
+		// Process each recalled group-by batch
+		for batIdx, recalledBat := range recalledGroupByBatches {
+			if recalledBat.IsEmpty() {
+				recalledBat.Clean(proc.Mp())
+				continue
+			}
+
+			// Bulk insert recalled group-by keys into the current hashmap
+			oldCurrentGroupCount := group.ctr.hashMap.Hash.GroupCount()
+			vals, _, err := group.ctr.hashMap.Iter.Insert(0, recalledBat.RowCount(), recalledBat.Vecs)
+			if err != nil {
+				for _, agg := range recalledAggs {
+					agg.Free()
+				}
+				recalledHM.Free()
+				recalledBat.Clean(proc.Mp()) // Clean current batch
+				// Clean only the batches that have not been processed yet;
+				// earlier ones were already cleaned at the end of their iteration.
+				for _, bat := range recalledGroupByBatches[batIdx+1:] {
+					bat.Clean(proc.Mp())
+				}
+				return err
+			}
+
+			// Identify newly added groups and map recalled group IDs to current group IDs.
+			// This map is specific to the current recalled batch's original group IDs.
+			recalledToCurrentGroupMap := make([]uint64, recalledBat.RowCount()+1) // +1 for 1-based indexing
+			currentBatchNewGroupsSels := make([]int32, 0)
+
+			for i := 0; i < recalledBat.RowCount(); i++ {
+				newID := vals[i]
+				recalledToCurrentGroupMap[uint64(i+1)] = newID
+
+				if newID > oldCurrentGroupCount {
+					currentBatchNewGroupsSels = append(currentBatchNewGroupsSels, int32(i))
+				}
+			}
+
+			if len(currentBatchNewGroupsSels) > 0 {
+				// If there are new groups, append their data to the current group-by batch and grow aggregators
+				if group.NeedEval {
+					// Create a new batch containing only the newly added group-by keys from this recalled batch
+					newGroupByKeysBatch := batch.NewOffHeapWithSize(len(recalledBat.Vecs))
+					for i, vec := range recalledBat.Vecs {
+						newVec, err := vec.CloneWindow(0, recalledBat.RowCount(), proc.Mp())
+						if err != nil {
+							newGroupByKeysBatch.Clean(proc.Mp())
+							recalledHM.Free()
+							for _, agg := range recalledAggs {
+								agg.Free()
+							}
+							recalledBat.Clean(proc.Mp())
+							return err
+						}
+						newVec.Shrink(int32SliceToInt64(currentBatchNewGroupsSels), false)
+						newGroupByKeysBatch.SetVector(int32(i), newVec)
+					}
+					newGroupByKeysBatch.SetRowCount(len(currentBatchNewGroupsSels))
+
+					// Push the batch containing only new group-by keys to finalResults.ToPopped
+					if err := group.ctr.finalResults.AppendRecalledBatches(proc.Mp(), []*batch.Batch{newGroupByKeysBatch}); err != nil {
+						recalledHM.Free()
+						for _, agg := range recalledAggs {
+							agg.Free()
+						}
+						newGroupByKeysBatch.Clean(proc.Mp())
+						recalledBat.Clean(proc.Mp())
+						return err
+					}
+
+					// Grow aggregators for newly added groups
+					for _, agg := range currentAggs {
+						if err := agg.GroupGrow(len(currentBatchNewGroupsSels)); err != nil {
+							recalledHM.Free()
+							for _, ra := range recalledAggs {
+								ra.Free()
+							}
+							// newGroupByKeysBatch is already owned by finalResults
+							// at this point, so it must not be cleaned again here.
+							recalledBat.Clean(proc.Mp())
+							return err
+						}
+					}
+				} else { // !group.NeedEval
+					// For intermediate results, append the new group-by keys to the current batch.
+					// Union appends the selected rows to currentGroupByBatch.Vecs,
+					// so the vectors in currentGroupByBatch must be extensible.
+					if err := currentGroupByBatch.Union(recalledBat, int32SliceToInt64(currentBatchNewGroupsSels), proc.Mp()); err != nil {
+						recalledHM.Free()
+						for _, agg := range recalledAggs {
+							agg.Free()
+						}
+						recalledBat.Clean(proc.Mp())
+						return err
+					}
+					for _, agg := range currentAggs {
+						if err := agg.GroupGrow(len(currentBatchNewGroupsSels)); err != nil {
+							recalledHM.Free()
+							for _, ra := range recalledAggs {
+								ra.Free()
+							}
+							recalledBat.Clean(proc.Mp())
+							return err
+						}
+					}
+				}
+			}
+
+			// Merge aggregation states for the current recalled batch.
+			// BatchMerge expects the `groups` parameter to hold the *new* group IDs;
+			// recalledToCurrentGroupMap[1:] provides this mapping for the original recalled group IDs.
+			for aggIdx := range currentAggs {
+				recalledAgg := recalledAggs[aggIdx]
+				if err := currentAggs[aggIdx].BatchMerge(recalledAgg, 0, recalledToCurrentGroupMap[1:]); err != nil {
+					// Free every recalled aggregator; the already-merged ones are still owned here.
+					for _, ra := range recalledAggs {
+						ra.Free()
+					}
+					recalledHM.Free()
+					recalledBat.Clean(proc.Mp())
+					return err
+				}
+			}
+			recalledBat.Clean(proc.Mp()) // Clean the current recalled batch after processing
+		}
+
+		recalledHM.Free()
+		for _, agg := range recalledAggs {
+			agg.Free()
+		}
+		if err := group.ctr.spiller.DeleteFile(filePath); err != nil {
+			return err
+		}
+
+	}
+
+	return nil
+}
+
+func (group *Group) getSize() int64 {
+	var size int64
+
+	// Hash table size
+	if group.ctr.hashMap.Hash != nil {
+		size += group.ctr.hashMap.Hash.Size()
+	}
+
+	// Aggregation results size
+	if group.NeedEval {
+		for _, bat := range group.ctr.finalResults.ToPopped {
+			if bat != nil {
+				size += int64(bat.Allocated())
+			}
+		}
+		for _, agg := range group.ctr.finalResults.AggList {
+			if agg != nil {
+				size += agg.Size()
+			}
+		}
+
+	} else {
+		if group.ctr.intermediateResults.res != nil {
+			size += int64(group.ctr.intermediateResults.res.Allocated())
+			for _, agg := range group.ctr.intermediateResults.res.Aggs {
+				if agg != nil {
+					size += agg.Size()
+				}
+			}
+		}
+	}
+
+	return size
+}
diff --git a/pkg/sql/colexec/group/exec_test.go b/pkg/sql/colexec/group/exec_test.go
index 5e53f82e8aef2..b7d3c3cbc7b83 100644
--- a/pkg/sql/colexec/group/exec_test.go
+++ b/pkg/sql/colexec/group/exec_test.go
@@ -15,6 +15,8 @@ package group
 import (
+	"testing"
+
 	"github.com/matrixorigin/matrixone/pkg/common/mpool"
 	"github.com/matrixorigin/matrixone/pkg/container/batch"
 	"github.com/matrixorigin/matrixone/pkg/container/types"
@@ -25,7 +27,6 @@ import (
 	"github.com/matrixorigin/matrixone/pkg/testutil"
 	"github.com/matrixorigin/matrixone/pkg/vm"
 	"github.com/stretchr/testify/require"
-	"testing"
 )
 
 // hackAggExecToTest is an AggExec that carries no logic of its own; it is used in unit tests to check how many times each interface method is called.
@@ -42,6 +43,12 @@ type hackAggExecToTest struct {
 	isFree bool
 }
 
+var _ aggexec.AggFuncExec = new(hackAggExecToTest)
+
+func (h *hackAggExecToTest)
Size() int64 { + return 0 +} + func (h *hackAggExecToTest) GetOptResult() aggexec.SplitResult { return nil } diff --git a/pkg/sql/colexec/group/execctx.go b/pkg/sql/colexec/group/execctx.go index 91f1824c30aa0..1f0b48630ced1 100644 --- a/pkg/sql/colexec/group/execctx.go +++ b/pkg/sql/colexec/group/execctx.go @@ -23,17 +23,17 @@ import ( "github.com/matrixorigin/matrixone/pkg/vm/process" ) -type ResHashRelated struct { +type HashMap struct { Hash hashmap.HashMap - Itr hashmap.Iterator + Iter hashmap.Iterator inserted []uint8 } -func (hr *ResHashRelated) IsEmpty() bool { - return hr.Hash == nil || hr.Itr == nil +func (hr *HashMap) IsEmpty() bool { + return hr.Hash == nil || hr.Iter == nil } -func (hr *ResHashRelated) BuildHashTable( +func (hr *HashMap) BuildHashTable( rebuild bool, isStrHash bool, keyNullable bool, preAllocated uint64) error { @@ -55,10 +55,10 @@ func (hr *ResHashRelated) BuildHashTable( } hr.Hash = h - if hr.Itr == nil { - hr.Itr = h.NewIterator() + if hr.Iter == nil { + hr.Iter = h.NewIterator() } else { - hashmap.IteratorChangeOwner(hr.Itr, hr.Hash) + hashmap.IteratorChangeOwner(hr.Iter, hr.Hash) } if preAllocated > 0 { if err = h.PreAlloc(preAllocated); err != nil { @@ -74,10 +74,10 @@ func (hr *ResHashRelated) BuildHashTable( } hr.Hash = h - if hr.Itr == nil { - hr.Itr = h.NewIterator() + if hr.Iter == nil { + hr.Iter = h.NewIterator() } else { - hashmap.IteratorChangeOwner(hr.Itr, hr.Hash) + hashmap.IteratorChangeOwner(hr.Iter, hr.Hash) } if preAllocated > 0 { if err = h.PreAlloc(preAllocated); err != nil { @@ -87,7 +87,7 @@ func (hr *ResHashRelated) BuildHashTable( return nil } -func (hr *ResHashRelated) GetBinaryInsertList(vals []uint64, before uint64) (insertList []uint8, insertCount uint64) { +func (hr *HashMap) GetBinaryInsertList(vals []uint64, before uint64) (insertList []uint8, insertCount uint64) { if cap(hr.inserted) < len(vals) { hr.inserted = make([]uint8, len(vals)) } else { @@ -108,7 +108,7 @@ func (hr *ResHashRelated) GetBinaryInsertList(vals []uint64, before uint64) (ins return hr.inserted, insertCount } -func (hr *ResHashRelated) Free0() { +func (hr *HashMap) Free0() { if hr.Hash != nil { hr.Hash.Free() hr.Hash = nil @@ -169,6 +169,50 @@ func (buf *GroupResultBuffer) InitWithGroupBy( return preExtendAggExecs(buf.AggList, preAllocated) } +// distributeBatchData is a helper function to append data from inBatch to the buffer, +// potentially creating new batches if the current last batch is full. +// It does not take ownership of inBatch (i.e., it does not clean inBatch). +func (buf *GroupResultBuffer) distributeBatchData(mpool *mpool.MPool, inBatch *batch.Batch) error { + currentInputRow := 0 + inputRowCount := inBatch.RowCount() + + for currentInputRow < inputRowCount { + var targetBat *batch.Batch + lastBatIdx := len(buf.ToPopped) - 1 + + if buf.ToPopped[lastBatIdx].RowCount() < buf.ChunkSize { + targetBat = buf.ToPopped[lastBatIdx] + } else { + targetBat = getInitialBatchWithSameTypeVecs(inBatch.Vecs) + buf.ToPopped = append(buf.ToPopped, targetBat) + } + rowsToFill := min(inputRowCount-currentInputRow, buf.ChunkSize-targetBat.RowCount()) + if err := targetBat.UnionWindow(inBatch, currentInputRow, rowsToFill, mpool); err != nil { + return err + } + currentInputRow += rowsToFill + } + return nil +} + +// AppendRecalledBatches appends a list of recalled batches to the GroupResultBuffer, +// attempting to fill the last existing batch first. +// It takes ownership of the input recalledBatches and cleans them after processing. 
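+// The first batch may be copied piecemeal into the last open chunk and cleaned
+// immediately; the remaining batches are handed to the buffer wholesale and
+// released later by Free0.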
+func (buf *GroupResultBuffer) AppendRecalledBatches(mpool *mpool.MPool, recalledBatches []*batch.Batch) error {
+	if len(recalledBatches) == 0 {
+		return nil
+	}
+	// If the last batch in ToPopped is not full, fill it from the first recalled batch.
+	if len(buf.ToPopped) > 0 && buf.ToPopped[len(buf.ToPopped)-1].RowCount() < buf.ChunkSize {
+		if err := buf.distributeBatchData(mpool, recalledBatches[0]); err != nil {
+			return err
+		}
+		// distributeBatchData copies rows out of the first batch, so release it here.
+		recalledBatches[0].Clean(mpool)
+		recalledBatches = recalledBatches[1:]
+	}
+	buf.ToPopped = append(buf.ToPopped, recalledBatches...)
+	return nil
+}
+
 func (buf *GroupResultBuffer) InitWithBatch(chunkSize int, aggList []aggexec.AggFuncExec, vecExampleBatch *batch.Batch) {
 	aggexec.SyncAggregatorsToChunkSize(aggList, chunkSize)
diff --git a/pkg/sql/colexec/group/group_test.go b/pkg/sql/colexec/group/group_test.go
new file mode 100644
index 0000000000000..377b47f6c987d
--- /dev/null
+++ b/pkg/sql/colexec/group/group_test.go
@@ -0,0 +1,124 @@
+// Copyright 2025 Matrix Origin
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package group
+
+import (
+	"math/rand/v2"
+	"testing"
+
+	"github.com/matrixorigin/matrixone/pkg/container/batch"
+	"github.com/matrixorigin/matrixone/pkg/container/types"
+	"github.com/matrixorigin/matrixone/pkg/container/vector"
+	"github.com/matrixorigin/matrixone/pkg/pb/plan"
+	"github.com/matrixorigin/matrixone/pkg/sql/colexec/aggexec"
+	"github.com/matrixorigin/matrixone/pkg/sql/colexec/value_scan"
+	"github.com/matrixorigin/matrixone/pkg/testutil"
+	"github.com/matrixorigin/matrixone/pkg/vm"
+	"github.com/stretchr/testify/require"
+)
+
+func TestGroup_CountStar(t *testing.T) {
+	proc := testutil.NewProcess(t)
+	defer proc.Free()
+	mp := proc.Mp()
+	before := mp.CurrNB()
+
+	numRows := 1024
+	numBatches := 1024
+
+	var allBatches []*batch.Batch
+	for i := 0; i < numBatches; i++ {
+		inputValues := make([]int64, numRows)
+		for j := 0; j < numRows; j++ {
+			inputValues[j] = int64(j + 1)
+		}
+		rand.Shuffle(len(inputValues), func(i, j int) {
+			inputValues[i], inputValues[j] = inputValues[j], inputValues[i]
+		})
+
+		inputVec := testutil.NewInt64Vector(numRows, types.T_int64.ToType(), mp, false, inputValues)
+		inputBatch := batch.NewWithSize(1)
+		inputBatch.Vecs[0] = inputVec
+		inputBatch.SetRowCount(numRows)
+
+		allBatches = append(allBatches, inputBatch)
+	}
+
+	vscan := value_scan.NewArgument()
+	vscan.Batchs = allBatches
+
+	g := &Group{
+		OperatorBase: vm.OperatorBase{},
+		NeedEval:     true,
+		Exprs:        []*plan.Expr{newColumnExpression(0)},
+		GroupingFlag: []bool{true},
+		Aggs: []aggexec.AggFuncExecExpression{
+			aggexec.MakeAggFunctionExpression(
+				aggexec.AggIdOfCountStar,
+				false,
+				[]*plan.Expr{
+					newColumnExpression(0),
+				},
+				nil,
+			),
+		},
+	}
+	g.AppendChild(vscan)
+
+	require.NoError(t, vscan.Prepare(proc))
+	require.NoError(t, g.Prepare(proc))
+
+	g.ctr.spillThreshold = 8 * 1024
+
+	var outputBatch *batch.Batch
+	outputCount := 0
+	for {
+		r, err := g.Call(proc)
+		require.NoError(t, err)
+		if r.Batch == nil {
+			break
+		}
+		outputCount++
+		outputBatch, err = r.Batch.Dup(proc.Mp())
+		require.NoError(t, err)
+		require.Equal(t, 1, outputCount)
+	}
+
+	require.NotNil(t, outputBatch)
+	require.Equal(t, 2, len(outputBatch.Vecs))
+	require.Equal(t, numRows, outputBatch.RowCount())
+	require.Equal(t, 0, len(outputBatch.Aggs))
+
+	outputValues := vector.MustFixedColNoTypeCheck[int64](outputBatch.Vecs[0])
+	countValues := vector.MustFixedColNoTypeCheck[int64](outputBatch.Vecs[1])
+	require.Equal(t, numRows, len(outputValues))
+	require.Equal(t, numRows, len(countValues))
+
+	for _, v := range countValues {
+		require.Equal(t, int64(numBatches), v, "each unique group should be counted once per input batch")
+	}
+
+	outputMap := make(map[int64]int64)
+	for i, v := range outputValues {
+		outputMap[v] = countValues[i]
+	}
+	require.Equal(t, numRows, len(outputMap), "all group keys should be distinct")
+
+	outputBatch.Clean(proc.Mp())
+	g.Free(proc, false, nil)
+	vscan.Free(proc, false, nil)
+	require.Equal(t, before, mp.CurrNB())
+}
diff --git a/pkg/sql/colexec/group/spill.go b/pkg/sql/colexec/group/spill.go
new file mode 100644
index 0000000000000..c88115da9248f
--- /dev/null
+++ b/pkg/sql/colexec/group/spill.go
@@ -0,0 +1,227 @@
+// Copyright 2025 Matrix Origin
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package group
+
+import (
+	"bytes"
+	"encoding/binary"
+	"fmt"
+	"io"
+	"slices"
+	"sync/atomic"
+
+	"github.com/matrixorigin/matrixone/pkg/common/moerr"
+	"github.com/matrixorigin/matrixone/pkg/container/batch"
+	"github.com/matrixorigin/matrixone/pkg/fileservice"
+	"github.com/matrixorigin/matrixone/pkg/vm/process"
+)
+
+var (
+	// spillFileCounter is used to generate unique file names for spilled data.
+	spillFileCounter atomic.Uint64
+)
+
+// Spiller manages spilling data to disk for the group operator.
+type Spiller struct {
+	fs         fileservice.MutableFileService
+	spillFiles []string
+	proc       *process.Process
+}
+
+func NewSpiller(proc *process.Process) (*Spiller, error) {
+	fs, err := proc.GetSpillFileService()
+	if err != nil {
+		return nil, err
+	}
+	return &Spiller{
+		fs:   fs,
+		proc: proc,
+	}, nil
+}
+
+// spillBatch writes a batch to a new temporary spill file.
+func (s *Spiller) spillBatch(bat *batch.Batch) error {
+	if bat.IsEmpty() {
+		return nil
+	}
+
+	filePath := fmt.Sprintf("group_spill_%s_%d.bin", s.proc.QueryId(), spillFileCounter.Add(1))
+
+	data, err := bat.MarshalBinary()
+	if err != nil {
+		return moerr.NewInternalErrorf(s.proc.Ctx, "failed to marshal batch for spilling: %v", err)
+	}
+
+	vec := fileservice.IOVector{
+		FilePath: filePath,
+		Entries: []fileservice.IOEntry{
+			{
+				Offset: 0,
+				Size:   int64(len(data)),
+				Data:   data,
+			},
+		},
+	}
+
+	if err = s.fs.Write(s.proc.Ctx, vec); err != nil {
+		return moerr.NewInternalErrorf(s.proc.Ctx, "failed to write spill file %s: %v", filePath, err)
+	}
+
+	s.spillFiles = append(s.spillFiles, filePath)
+	return nil
+}
+
+// getSpillFiles returns the paths of all spilled files.
+func (s *Spiller) getSpillFiles() []string {
+	return s.spillFiles
+}
+
+// clean deletes all temporary spill files.
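+// It keeps going after a failed delete so that one bad file does not strand
+// the rest, and it reports only the last error.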
+func (s *Spiller) clean() error { + var lastErr error + for _, filePath := range s.spillFiles { + if err := s.fs.Delete(s.proc.Ctx, filePath); err != nil { + lastErr = moerr.NewInternalErrorf(s.proc.Ctx, "failed to delete spill file %s: %v", filePath, err) + } + } + s.spillFiles = nil + return lastErr +} + +// spillState writes the serialized hashmap, aggregation states, and group-by batch to a new temporary spill file. +func (s *Spiller) spillState(hashmapData []byte, aggStates [][]byte, groupByBatchesData [][]byte) error { + filePath := fmt.Sprintf("group_spill_state_%s_%d.bin", s.proc.QueryId(), spillFileCounter.Add(1)) + + var buffer bytes.Buffer + // Write lengths of each component + if err := binary.Write(&buffer, binary.LittleEndian, uint64(len(hashmapData))); err != nil { + return moerr.NewInternalErrorf(s.proc.Ctx, "failed to write hashmapData length for spilling: %v", err) + } + if err := binary.Write(&buffer, binary.LittleEndian, uint64(len(aggStates))); err != nil { + return moerr.NewInternalErrorf(s.proc.Ctx, "failed to write aggStates count for spilling: %v", err) + } + for _, aggData := range aggStates { + if err := binary.Write(&buffer, binary.LittleEndian, uint64(len(aggData))); err != nil { + return moerr.NewInternalErrorf(s.proc.Ctx, "failed to write aggData length for spilling: %v", err) + } + } + if err := binary.Write(&buffer, binary.LittleEndian, uint64(len(groupByBatchesData))); err != nil { + return moerr.NewInternalErrorf(s.proc.Ctx, "failed to write groupByBatchesData count for spilling: %v", err) + } + for _, batchData := range groupByBatchesData { + if err := binary.Write(&buffer, binary.LittleEndian, uint64(len(batchData))); err != nil { + return moerr.NewInternalErrorf(s.proc.Ctx, "failed to write groupByBatchData length for spilling: %v", err) + } + } + + // Write data of each component + buffer.Write(hashmapData) + for _, aggData := range aggStates { + buffer.Write(aggData) + } + for _, batchData := range groupByBatchesData { + buffer.Write(batchData) + } + + vec := fileservice.IOVector{ + FilePath: filePath, + Entries: []fileservice.IOEntry{ + { + Offset: 0, + Size: int64(buffer.Len()), + Data: buffer.Bytes(), + }, + }, + } + + if err := s.fs.Write(s.proc.Ctx, vec); err != nil { + return moerr.NewInternalErrorf(s.proc.Ctx, "failed to write spill file %s: %v", filePath, err) + } + + s.spillFiles = append(s.spillFiles, filePath) + return nil +} + +// recallState reads a spill file and returns the serialized hashmap, aggregation states, and group-by batch. 
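+// The file layout mirrors spillState: a little-endian header of uint64 fields
+// (hashmap byte length, agg-state count, one length per agg state, batch
+// count, one length per batch) followed by the raw payloads in the same order.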
+func (s *Spiller) recallState(filePath string) ([]byte, [][]byte, [][]byte, error) { + vec := fileservice.IOVector{ + FilePath: filePath, + Entries: []fileservice.IOEntry{{Offset: 0, Size: -1}}, // Read entire file + } + + if err := s.fs.Read(s.proc.Ctx, &vec); err != nil { + return nil, nil, nil, moerr.NewInternalErrorf(s.proc.Ctx, "failed to read spill file %s: %v", filePath, err) + } + + reader := bytes.NewReader(vec.Entries[0].Data) + + var hashmapLen, aggStatesCount, groupByBatchesCount uint64 + if err := binary.Read(reader, binary.LittleEndian, &hashmapLen); err != nil { + return nil, nil, nil, moerr.NewInternalErrorf(s.proc.Ctx, "failed to read hashmap length from spill file: %v", err) + } + if err := binary.Read(reader, binary.LittleEndian, &aggStatesCount); err != nil { + return nil, nil, nil, moerr.NewInternalErrorf(s.proc.Ctx, "failed to read aggStates count from spill file: %v", err) + } + + aggLens := make([]uint64, aggStatesCount) + for i := range aggLens { + if err := binary.Read(reader, binary.LittleEndian, &aggLens[i]); err != nil { + return nil, nil, nil, moerr.NewInternalErrorf(s.proc.Ctx, "failed to read aggData length from spill file: %v", err) + } + } + if err := binary.Read(reader, binary.LittleEndian, &groupByBatchesCount); err != nil { + return nil, nil, nil, moerr.NewInternalErrorf(s.proc.Ctx, "failed to read groupByBatches count from spill file: %v", err) + } + groupByBatchLens := make([]uint64, groupByBatchesCount) + for i := range groupByBatchLens { + if err := binary.Read(reader, binary.LittleEndian, &groupByBatchLens[i]); err != nil { + return nil, nil, nil, moerr.NewInternalErrorf(s.proc.Ctx, "failed to read groupByBatchData length from spill file: %v", err) + } + } + + hashmapData := make([]byte, hashmapLen) + if _, err := io.ReadFull(reader, hashmapData); err != nil { + return nil, nil, nil, moerr.NewInternalErrorf(s.proc.Ctx, "failed to read hashmap data from spill file: %v", err) + } + + aggStates := make([][]byte, aggStatesCount) + for i, l := range aggLens { + aggStates[i] = make([]byte, l) + if _, err := io.ReadFull(reader, aggStates[i]); err != nil { + return nil, nil, nil, moerr.NewInternalErrorf(s.proc.Ctx, "failed to read aggStates data from spill file: %v", err) + } + } + + groupByBatchesData := make([][]byte, groupByBatchesCount) + for i, l := range groupByBatchLens { + groupByBatchesData[i] = make([]byte, l) + if _, err := io.ReadFull(reader, groupByBatchesData[i]); err != nil { + return nil, nil, nil, moerr.NewInternalErrorf(s.proc.Ctx, "failed to read groupByBatch data from spill file: %v", err) + } + } + return hashmapData, aggStates, groupByBatchesData, nil +} + +// DeleteFile deletes a specific temporary spill file. 
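+// recallAndMergeSpilledData calls it once a file's contents have been fully merged back.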
+func (s *Spiller) DeleteFile(filePath string) error { + if err := s.fs.Delete(s.proc.Ctx, filePath); err != nil { + return moerr.NewInternalErrorf(s.proc.Ctx, "failed to delete spill file %s: %v", filePath, err) + } + // Remove the file from the list of spilled files + s.spillFiles = slices.DeleteFunc(s.spillFiles, func(f string) bool { + return f == filePath + }) + return nil +} diff --git a/pkg/sql/colexec/group/types.go b/pkg/sql/colexec/group/types.go index 0243e4cda69d2..9e83fd59953eb 100644 --- a/pkg/sql/colexec/group/types.go +++ b/pkg/sql/colexec/group/types.go @@ -20,11 +20,13 @@ import ( "github.com/matrixorigin/matrixone/pkg/container/batch" "github.com/matrixorigin/matrixone/pkg/container/types" "github.com/matrixorigin/matrixone/pkg/container/vector" + "github.com/matrixorigin/matrixone/pkg/logutil" "github.com/matrixorigin/matrixone/pkg/pb/plan" "github.com/matrixorigin/matrixone/pkg/sql/colexec" "github.com/matrixorigin/matrixone/pkg/sql/colexec/aggexec" "github.com/matrixorigin/matrixone/pkg/vm" "github.com/matrixorigin/matrixone/pkg/vm/process" + "go.uber.org/zap" ) const ( @@ -147,7 +149,7 @@ type container struct { dataSourceIsEmpty bool // hash. - hr ResHashRelated + hashMap HashMap mtyp int keyWidth int keyNullable bool @@ -158,9 +160,16 @@ type container struct { aggregateEvaluate []ExprEvalVector // result if NeedEval is true. - result1 GroupResultBuffer + finalResults GroupResultBuffer // result if NeedEval is false. - result2 GroupResultNoneBlock + intermediateResults GroupResultNoneBlock + + spiller *Spiller + spilled bool + spillThreshold int64 + // recalling indicates that the operator is currently recalling spilled data. + // During this phase, no new data should be consumed from the child operator. + recalling bool } func (ctr *container) isDataSourceEmpty() bool { @@ -178,6 +187,18 @@ func (group *Group) Free(proc *process.Process, _ bool, _ error) { func (group *Group) Reset(proc *process.Process, pipelineFailed bool, err error) { group.freeCannotReuse(proc.Mp()) + // clean up spill files + if group.ctr.spiller != nil { + if spillErr := group.ctr.spiller.clean(); spillErr != nil { + logutil.Error("failed to clean up spill files during reset", zap.Error(spillErr)) + } + // After cleaning, the spiller is no longer needed for this operator instance. + // It will be re-initialized if the operator is prepared again. 
+ group.ctr.spiller = nil + } + group.ctr.spilled = false + group.ctr.recalling = false + group.ctr.groupByEvaluate.ResetForNextQuery() for i := range group.ctr.aggregateEvaluate { group.ctr.aggregateEvaluate[i].ResetForNextQuery() @@ -186,9 +207,24 @@ func (group *Group) Reset(proc *process.Process, pipelineFailed bool, err error) } func (group *Group) freeCannotReuse(mp *mpool.MPool) { - group.ctr.hr.Free0() - group.ctr.result1.Free0(mp) - group.ctr.result2.Free0(mp) + group.ctr.hashMap.Free0() + group.ctr.finalResults.Free0(mp) + group.ctr.intermediateResults.Free0(mp) + if group.ctr.spiller != nil { + group.ctr.spiller.clean() + } + group.ctr.spilled = false + group.ctr.spiller = nil + group.ctr.recalling = false +} + +func (group *Group) initSpiller(proc *process.Process) (err error) { + group.ctr.spiller, err = NewSpiller(proc) + if err != nil { + return err + } + group.ctr.spillThreshold = proc.Mp().Cap() / 2 //TODO configurable + return nil } func (ctr *container) freeAggEvaluate() { diff --git a/pkg/sql/colexec/group/utils.go b/pkg/sql/colexec/group/utils.go new file mode 100644 index 0000000000000..894215b58db82 --- /dev/null +++ b/pkg/sql/colexec/group/utils.go @@ -0,0 +1,23 @@ +// Copyright 2025 Matrix Origin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package group + +func int32SliceToInt64(s []int32) []int64 { + ret := make([]int64, len(s)) + for i, v := range s { + ret[i] = int64(v) + } + return ret +} diff --git a/pkg/sql/colexec/mergegroup/exec.go b/pkg/sql/colexec/mergegroup/exec.go index cd91238966e3a..0d6954b81d135 100644 --- a/pkg/sql/colexec/mergegroup/exec.go +++ b/pkg/sql/colexec/mergegroup/exec.go @@ -178,7 +178,7 @@ func (mergeGroup *MergeGroup) consumeBatch(proc *process.Process, b *batch.Batch } origin := mergeGroup.ctr.hr.Hash.GroupCount() - vals, _, err := mergeGroup.ctr.hr.Itr.Insert(i, n, b.Vecs) + vals, _, err := mergeGroup.ctr.hr.Iter.Insert(i, n, b.Vecs) if err != nil { return err } diff --git a/pkg/sql/colexec/mergegroup/types.go b/pkg/sql/colexec/mergegroup/types.go index 68490a86f5f8e..4781ba4981533 100644 --- a/pkg/sql/colexec/mergegroup/types.go +++ b/pkg/sql/colexec/mergegroup/types.go @@ -49,7 +49,7 @@ type container struct { state vm.CtrState // hash. - hr group.ResHashRelated + hr group.HashMap // res. 
	result group.GroupResultBuffer
 }
diff --git a/pkg/testutil/util_compare.go b/pkg/testutil/util_compare.go
index 52db2ae175b03..dc4e47b968d44 100644
--- a/pkg/testutil/util_compare.go
+++ b/pkg/testutil/util_compare.go
@@ -17,9 +17,13 @@ package testutil
 import (
 	"bytes"
 	"reflect"
+	"testing"
 
+	"github.com/matrixorigin/matrixone/pkg/container/batch"
 	"github.com/matrixorigin/matrixone/pkg/container/nulls"
 	"github.com/matrixorigin/matrixone/pkg/container/vector"
+	"github.com/matrixorigin/matrixone/pkg/sql/colexec/aggexec"
+	"github.com/stretchr/testify/require"
 )
 
 func CompareVectors(expected *vector.Vector, got *vector.Vector) bool {
@@ -85,3 +89,36 @@ func CompareVectors(expected *vector.Vector, got *vector.Vector) bool {
 		}
 	}
 }
+
+// CompareBatches compares two batches for deep equality.
+func CompareBatches(t *testing.T, expected, actual *batch.Batch) {
+	if expected == nil && actual == nil {
+		return
+	}
+	if expected == nil || actual == nil {
+		t.Fatalf("one batch is nil, the other is not. Expected: %v, Actual: %v", expected, actual)
+	}
+
+	require.Equal(t, expected.RowCount(), actual.RowCount(), "row count mismatch")
+
+	require.Equal(t, len(expected.Vecs), len(actual.Vecs), "vector count mismatch")
+
+	require.Equal(t, len(expected.Attrs), len(actual.Attrs), "attribute count mismatch")
+
+	for i := range expected.Attrs {
+		require.Equal(t, expected.Attrs[i], actual.Attrs[i], "attribute name mismatch at index %d", i)
+	}
+
+	for i := range expected.Vecs {
+		require.True(t, CompareVectors(expected.Vecs[i], actual.Vecs[i]), "vector content mismatch at index %d", i)
+	}
+
+	require.Equal(t, len(expected.Aggs), len(actual.Aggs), "aggregator count mismatch")
+	for i := range expected.Aggs {
+		expectedBytes, err := aggexec.MarshalAggFuncExec(expected.Aggs[i])
+		require.NoError(t, err)
+		actualBytes, err := aggexec.MarshalAggFuncExec(actual.Aggs[i])
+		require.NoError(t, err)
+		require.Equal(t, expectedBytes, actualBytes, "aggregator state mismatch at index %d", i)
+	}
+}
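For reference, the length-prefixed framing that spillState writes and recallState parses can be exercised in isolation with a small standalone program. The sketch below is illustrative only: encodeSpillFrame, decodeSpillFrame, and the sample payloads are invented names, and the real operator goes through fileservice rather than a plain byte slice.

package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"io"
)

// encodeSpillFrame writes the header of uint64 lengths followed by the raw
// payloads, in the same order as Spiller.spillState.
func encodeSpillFrame(hashmapData []byte, aggStates, batches [][]byte) []byte {
	var buf bytes.Buffer
	writeU64 := func(v uint64) {
		// binary.Write into a bytes.Buffer cannot fail for fixed-size values.
		_ = binary.Write(&buf, binary.LittleEndian, v)
	}
	writeU64(uint64(len(hashmapData)))
	writeU64(uint64(len(aggStates)))
	for _, a := range aggStates {
		writeU64(uint64(len(a)))
	}
	writeU64(uint64(len(batches)))
	for _, b := range batches {
		writeU64(uint64(len(b)))
	}
	buf.Write(hashmapData)
	for _, a := range aggStates {
		buf.Write(a)
	}
	for _, b := range batches {
		buf.Write(b)
	}
	return buf.Bytes()
}

// decodeSpillFrame is the inverse, mirroring Spiller.recallState.
func decodeSpillFrame(data []byte) (hm []byte, aggs, batches [][]byte, err error) {
	r := bytes.NewReader(data)
	readU64 := func() (v uint64, e error) {
		e = binary.Read(r, binary.LittleEndian, &v)
		return
	}
	hmLen, err := readU64()
	if err != nil {
		return nil, nil, nil, err
	}
	aggCount, err := readU64()
	if err != nil {
		return nil, nil, nil, err
	}
	aggLens := make([]uint64, aggCount)
	for i := range aggLens {
		if aggLens[i], err = readU64(); err != nil {
			return nil, nil, nil, err
		}
	}
	batCount, err := readU64()
	if err != nil {
		return nil, nil, nil, err
	}
	batLens := make([]uint64, batCount)
	for i := range batLens {
		if batLens[i], err = readU64(); err != nil {
			return nil, nil, nil, err
		}
	}
	hm = make([]byte, hmLen)
	if _, err = io.ReadFull(r, hm); err != nil {
		return nil, nil, nil, err
	}
	aggs = make([][]byte, aggCount)
	for i, l := range aggLens {
		aggs[i] = make([]byte, l)
		if _, err = io.ReadFull(r, aggs[i]); err != nil {
			return nil, nil, nil, err
		}
	}
	batches = make([][]byte, batCount)
	for i, l := range batLens {
		batches[i] = make([]byte, l)
		if _, err = io.ReadFull(r, batches[i]); err != nil {
			return nil, nil, nil, err
		}
	}
	return hm, aggs, batches, nil
}

func main() {
	frame := encodeSpillFrame(
		[]byte("hashmap-bytes"),
		[][]byte{[]byte("agg-0"), []byte("agg-1")},
		[][]byte{[]byte("batch-0")},
	)
	hm, aggs, batches, err := decodeSpillFrame(frame)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(hm), len(aggs), len(batches)) // hashmap-bytes 2 1
}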