diff --git a/.gitmodules b/.gitmodules index a1a876d8b..1489ea541 100644 --- a/.gitmodules +++ b/.gitmodules @@ -20,3 +20,6 @@ [submodule "src/3rd_party/simple-websocket-server"] path = src/3rd_party/simple-websocket-server url = https://github.com/marian-nmt/Simple-WebSocket-Server +[submodule "src/3rd_party/cub"] + path = src/3rd_party/cub + url = https://github.com/NVIDIA/cub diff --git a/CHANGELOG.md b/CHANGELOG.md index b0d05e954..656dd8e1b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ## [Unreleased] ### Added + +- Includes cub as a dependency +- Replaces the topK implementation in nth_element.cu and topk.cu - Local/global sharding with MPI training via `--sharding local` - fp16 support for factors. - Correct training with fp16 via `--fp16`. diff --git a/LICENSE.md b/LICENSE.md index 878908214..269f4085b 100644 --- a/LICENSE.md +++ b/LICENSE.md @@ -3,6 +3,8 @@ MIT License Copyright (c) 2016 Marcin Junczys-Dowmunt, the University of Edinburgh, Adam Mickiewicz University +Copyright (c) 2020 NVIDIA Corporation + Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights diff --git a/src/3rd_party/cub b/src/3rd_party/cub new file mode 160000 index 000000000..52d58a889 --- /dev/null +++ b/src/3rd_party/cub @@ -0,0 +1 @@ +Subproject commit 52d58a88904da39c374e44a6a8ae0e4dcca5b71a diff --git a/src/3rd_party/topk.cuh b/src/3rd_party/topk.cuh new file mode 100644 index 000000000..d0a7d2416 --- /dev/null +++ b/src/3rd_party/topk.cuh @@ -0,0 +1,485 @@ +/* +* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +* This code is modified from the topk implementation in NVIDIA's faster +* transformer repository. The original source code files can be found here: +* +* https://github.com/NVIDIA/DeepLearningExamples/blob/master/FasterTransformer/v3.0/fastertransformer/cuda/topk_kernels.cu +* https://github.com/NVIDIA/DeepLearningExamples/blob/master/FasterTransformer/v3.0/fastertransformer/cuda/topk_kernels.cuh +*/ + +#pragma once +#include +#include +#include +#if CUDA_VERSION >= 11000 +#include +#else +#include "cub/cub/cub.cuh" +#endif + +#define MAX_BLOCKS_PER_ITEM 8 + +/// A struct to access the infinity constant on device based on type. +template +struct FpInfinity; + +/// Specialization of FpInfinity for float +template <> +struct FpInfinity { + static __host__ __device__ __forceinline__ float infinity() { + return INFINITY; + } +}; + +/// Specialization of FpInfinity for half +template <> +struct FpInfinity<__half> { + static __host__ __device__ __forceinline__ __half infinity() { + return __float2half(INFINITY); + } +}; + +/// A struct used to track the largest value along with the index at which +/// it occurs at when performing the topk reduction. + +/// It is assumed that IndexType is an integral type and T is a floating point +/// type. +template +struct TopK { + // The index of the largest/smallest value in the large + IndexType index = 0; + // The largest/smallest value encountered + T value = -FpInfinity::infinity(); + + /// Updates the value and index in the topk struct if elem is larger than the current + /// value field of the struct. This is intended to be used during the initial reduction + /// before we reduce across a block to ensure all threads in the block have the largest + /// values within the block's range. + __device__ __forceinline__ void updateIfLarger(T elem, IndexType elem_id) { + if (elem > value) { + value = elem; + index = elem_id; + } + } + + /// Updates the value and index in the topk struct if elem is smaller than the current + /// value field of the struct. This is intended to be used during the initial reduction + /// before we reduce across a block to ensure all threads in the block have the smallest + /// values within the block's range. + __device__ __forceinline__ void updateIfSmaller(T elem, IndexType elem_id) { + if (elem < value) { + value = elem; + index = elem_id; + } + } + + /// Initializes the value and index fields of the topk struct before starting a reduction. + /// If the descendingOrder flag is true, the value starts at negative infinity so that we + /// store the max values. We do the opposite if descendingOrder is false. + __device__ __forceinline__ void init(bool descendingOrder) { + value = descendingOrder ? -FpInfinity::infinity() : FpInfinity::infinity(); + index = 0; + } +}; + +/// A binary reduction functor that CUB uses to perform the block reduce. This version is used when we want to find the max +/// value in a given range. +template +__device__ __forceinline__ TopK reduce_topk_max(const TopK& a, const TopK& b) { + return a.value > b.value ? a : b; +} + +/// A binary reduction functor that CUB uses to perform the block reduce. This version is used when we want to find the min +/// value in a given range. +template +__device__ __forceinline__ TopK reduce_topk_min(const TopK& a, const TopK& b) { + return a.value < b.value ? a : b; +} + + + +/** + This function performs the first phase of the topk. It finds the k largest elements within each BLOCKS_PER_ITEM_ blocks + for each item in a row and writes the indices to the topk_tmp_id_buf buffer and the values to the topk_tmp_val_buf. + + These intermediate values are further reduced in the second phase of this kernel. + + Note - This implementation modifies the input array in place but fixes it afterwards. This could be an issue if multiple + devices/streams need to read the same array but this shouldn't be an issue. + + Template params: + IndexType - the type of the array used to store the index values. This must be an integral type + + T - the type of input_array + + BLOCK_SIZE_ - The number of threads in a block + + BLOCKS_PER_ITEM_ - The number of blocks launches for each item in a row. eg. For a beam search topk, this could be the number + of blocks needed to process a beam. + + getRowOffsets - A boolean indicating whether we want the indices returned to be relative to the start of a row or relative to + the base pointer of the array. + + Function params: + input_array - The input matrix to the topk operator + + topk_tmp_id_buf - Stores the intermediate indicies for the next phase of the kernel. Must be at least #rows * items_per_row * k * BLOCKS_PER_ITEM_ + + topk_tmp_val_buf - Stores the intermediate values for the next phase of the kernel. Must be at least #rows * items_per_row * k * BLOCKS_PER_ITEM_ + + Note: The shape for the above items comes from the fact that each block launched finds the top k items for the given item in its row. + So, for each item in a row, all the BLOCKS_PER_ITEM_ blocks write out k values. These are further reduced in stage 2. + + k - the k in top k. Specifies that in each row, the k largest/smallest elements should be retrieved + + values_per_item - the number of values within each item in a given row. Eg. If a row has two items, and the row length is 100, the values_per_item would be 50. + For beam search, this would be the vocabulary size and the items_per_row would be the number of hypotheses in the beam. + + descendingOrder - If true, the k largest elements are returned. Otherwise, the k smallest elements are returned. +*/ +template +__global__ void topk_stage_1(T* input_array, + IndexType* topk_tmp_id_buf, + T* topk_tmp_val_buf, + const int k, + const int values_per_item, + const int descendingOrder) { + + // Set up shared memory needed for CUB reductions + typedef cub::BlockReduce, BLOCK_SIZE_> BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + // Compute constants used within each block. + const IndexType tid = threadIdx.x; + const IndexType bid = blockIdx.x; + + // The row in the tmp array to write the topk elements for each BLOCKS_PER_ITEM_ blocks + // The tmp array is interpreted as have one row per item in this function. + const IndexType row_id = bid / BLOCKS_PER_ITEM_; + const IndexType block_lane = bid % BLOCKS_PER_ITEM_; // block id for an item + const T minimal = descendingOrder? -FpInfinity::infinity() : FpInfinity::infinity();; + + // Computes the index offset from the temp id and value arrays that each block should write its tmp output + const IndexType tmp_log_buf_index = row_id * values_per_item; + const IndexType tmp_topk_buf_index = row_id * BLOCKS_PER_ITEM_ * k + block_lane * k; + + // The partial topk result (reductions from global memory before performing block reduction) + TopK partial; + + // Finds the largest/smallest element within a block's given range then blanks out that value in the input array + // before starting the next iteration. The blanked out values are fixed after this loop. + for (int ite = 0; ite < k; ite++) { + partial.init(descendingOrder); + const IndexType threadStart = tid + block_lane * BLOCK_SIZE_; + + // This is needed to ensure the indices for the threads in each valid block starts in a valid range for that block. + if (threadStart < values_per_item) + partial.index = threadStart; + + // Each block constructs its partial result before performing a block level reduce. + #pragma unroll + for (IndexType elem_id = threadStart; elem_id < values_per_item; elem_id += BLOCK_SIZE_ * BLOCKS_PER_ITEM_) { + IndexType index = elem_id + tmp_log_buf_index; + descendingOrder? partial.updateIfLarger(input_array[index], index) : partial.updateIfSmaller(input_array[index], index); + } + + // Invoke CUB to perform the block level reduce. The reduction function depends on the descending order flag. + TopK total = BlockReduce(temp_storage).Reduce(partial, descendingOrder? reduce_topk_max: reduce_topk_min); + + // Wrtie the index and value to global memory and blank out the max/min value found from the input_array so it is not considered in the next + // iteration. + if (tid == 0) { + const int index = tmp_topk_buf_index + ite; + topk_tmp_id_buf[index] = getRowOffsets? total.index - tmp_log_buf_index : total.index; + topk_tmp_val_buf[index] = total.value; + // If we found a max, blank out the value in the log prob array before starting the next iteration. + // Otherwise, we don't need to issue a write since all prob values must have been T::min() + if (total.value != minimal) + input_array[total.index] = minimal; + } + __syncthreads(); + } + + // Update prob array with original values. + for (int beam = tid; beam < k; beam += BLOCK_SIZE_) { + const IndexType index = tmp_topk_buf_index + beam; + T val = topk_tmp_val_buf[index]; + // We only want to replace the value in the log prob array if a value was blanked out (we found a max). + // When a max isn't found, topk_tmp_val_buf[index] will be T::min() + if (val != minimal) { + IndexType k_idx = getRowOffsets? topk_tmp_id_buf[index] + tmp_log_buf_index : topk_tmp_id_buf[index]; + input_array[k_idx] = (T)topk_tmp_val_buf[index]; + } + } +} + +/** + Reduces the values in the tmp arrays to find the topk values in each row. The tmp arrays have items_per_row * k * BLOCKS_PER_ITEM_ entries for each + row of the input array. This kernel finds the k largest entries amount those values and writes the outputs to the top array, outVals array and outIndices + array as specified below. + + Template params: same meaning as topk_stage_1 + + Function params: + topk_tmp_id_buf - Stores the intermediate indicies for the next phase of the kernel. Must be at least #rows * items_per_row * k * BLOCKS_PER_ITEM_ + + topk_tmp_val_buf - Stores the intermediate values for the next phase of the kernel. Must be at least #rows * items_per_row * k * BLOCKS_PER_ITEM_ + + top (OPTIONAL: cam be NULL) - An array of structs of TopK objects. This exists so that the topk elements can be read back using one call to + cudaMemcpy. However, it exposes the user to some implementation details so it is optional. If the pointer is NULL, + it is ignored. + + outIndices (OPTIONAL: can be NULL) - An array of IndexType containing the index locations where the topk items are found. It is ignored if the pointer + is NULL + + outVals (OPTIONAL: can be NULL) - An array of IndexType containing the index locations where the topk items are found. It is ignored if the pointer + is NULL + + items_per_row - The number of items to be processed within each row of the input. (eg. For beam search, this would be the hypotheses per beam) + + k - the k in top k. Specifies that in each row, the k largest/smallest elements should be retrieved + + descendingOrder - If true, the k largest elements are returned. Otherwise, the k smallest elements are returned. +*/ +template +__global__ void topk_stage_2(const IndexType* __restrict topk_tmp_id_buf, + T* topk_tmp_val_buf, + TopK* top, + IndexType* outIndices, + T* outVals, + const int items_per_row, + const int k, + bool descendingOrder) { + + // Size of one row in the tmp array. + const int size = items_per_row * k * BLOCKS_PER_ITEM_; + + // Some constants needed for eahc block + const int tid = threadIdx.x; + const int batch_id = blockIdx.x; + const T minimal = descendingOrder? -FpInfinity::infinity() : FpInfinity::infinity();; + + // CUB reduction declarations + typedef cub::BlockReduce, BLOCK_SIZE_> BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + // Shared memory used to keep the topk items found by each block on the SM until the final write. + extern __shared__ char array[]; + T *s_val = topk_tmp_val_buf + batch_id * size; + TopK *topks = (TopK*)(array); + + // Partial reduction for topk (reducing across global before the blockwide reduce.) + TopK partial; + + // Finds the largest/smallest element within a block's given range then blanks out that value in the input array + // before starting the next iteration. The blanked out values are not fixed since we no longer care about them. + for (int ite = 0; ite < k; ite++) { + partial.init(descendingOrder); + + // First reduce across global mem if needed so that the max/min values per row reside in the block. + #pragma unroll + for (IndexType i = tid; i < size; i+= BLOCK_SIZE_) { + descendingOrder? partial.updateIfLarger(s_val[i], i) : partial.updateIfSmaller(s_val[i], i); + } + + // Use CUB to perform the blockwise reduction + TopK total = BlockReduce(temp_storage).Reduce(partial, descendingOrder? reduce_topk_max: reduce_topk_min); + + // Store the kth largest/smallest index, element pair in shared and blank out value stored in the global array before the next loop iteration. + if (tid == 0) { + topks[ite] = total; + s_val[total.index] = minimal; + } + __syncthreads(); + } + + // Now that we have all of the topk items in shared, write them out to global to the defined arrays. + for (int elt = tid; elt < k; elt += BLOCK_SIZE_) { + TopK beamOut; + IndexType indexInTmpValRow = topks[elt].index; + beamOut.index = topk_tmp_id_buf[batch_id * size + indexInTmpValRow]; + beamOut.value = topks[elt].value; + if (top) + top[batch_id * k + elt] = beamOut; + + if (outIndices) + outIndices[batch_id * k + elt] = beamOut.index; + + if (outVals) + outVals[batch_id * k + elt] = beamOut.value; + } +} + +// A helper for launching the kernels due to needed to set some template parameters +#define CASE_K(K,BLOCK_SIZE_1_, BLOCK_SIZE_2_, BLOCKS_PER_ITEM_) \ + case K: \ + topk_stage_1<<>>( \ + input_array, \ + topk_tmp_id_buf, \ + topk_tmp_val_buf, \ + k, values_per_item, descendingOrder); \ + topk_stage_2<<), stream>>>( \ + topk_tmp_id_buf, \ + topk_tmp_val_buf, \ + tops, \ + outIndices, \ + outVals, \ + items_per_row, \ + k, descendingOrder); \ + break; \ + +/** + This overload launches the cuda kernels needed to perform the topk operation. It returns the topk items in an array of structs in the tops array. + The benefit is that one call to cudaMemcpy is needed to read back the index, value pairs. However, it needs to expose the TopK struct externally to + achieve this. If possible, and if reading back to the host, it is recommended to use this version to reduce the host/device communication required. + + Template params: + IndexType - the type of the array used to store the index values. This must be an integral type + + T - the type of input_array + + getRowOffsets - A boolean indicating whether we want the indices returned to be relative to the start of a row or relative to + the base pointer of the array. + + The getRowOffsets template parameter is added so the topk implementation works with both nth_element.cu and the topk operator. + It is a template parameter since we know at compile time which version of topk we want to call. This flag can be removed whenever nth + element.cu is removed. When this flag is true, the indices returns are the offsets within the row. When the flag is false, the indices + returned are offset from the base pointer. + + Function params: + input_array - The input matrix to the topk operator + + topk_tmp_id_buf - Stores the intermediate indicies for the next phase of the kernel. Must be at least rows * items_per_row * k * MAX_BLOCKS_PER_ITEM + + topk_tmp_val_buf - Stores the intermediate values for the next phase of the kernel. Must be at least #rows * items_per_row * k * MAX_BLOCKS_PER_ITEM + + tops - An array on TopK structs where the final topk values will be written. This should be of shape (rows, k). + + rows - The number of rows in the input array. + + items_per_row - The number of items in a row of the input array. For a normal topk, this would be one and values per item would be #cols. For beam search + topk, this should be the number of hypotheses in one batch input assuming the number of rows is set to the batch_size. + + k - the k in top k. Specifies that in each row, the k largest/smallest elements should be retrieved + + values_per_item - the number of values within each item in a given row. Eg. For a normal topk, this is #cols. For beam search topk, this could be the vocab_size + provided that the items_per_row is set to the # of hypotheses and the rows is set to the current batch size. + + descendingOrder - If true, the k largest elements are returned. Otherwise, the k smallest elements are returned. + + stream - the stream in which this operation should be run. +*/ +template +void topK_kernelLauncher(T* input_array, + IndexType* topk_tmp_id_buf, + T* topk_tmp_val_buf, + TopK* tops, + const int rows, + const int items_per_row, + const int k, + const int values_per_item, + bool descendingOrder, + cudaStream_t stream) { + + IndexType* outIndices = nullptr; + T* outVals = nullptr; + switch(k) { + CASE_K(1,128,128,MAX_BLOCKS_PER_ITEM); + CASE_K(2,128,128,MAX_BLOCKS_PER_ITEM); + CASE_K(4,128,128,MAX_BLOCKS_PER_ITEM); + CASE_K(6,128,128,MAX_BLOCKS_PER_ITEM); + CASE_K(8,128,128,MAX_BLOCKS_PER_ITEM); + CASE_K(10,128,128,MAX_BLOCKS_PER_ITEM); + CASE_K(16,128,128,5); + CASE_K(32,256,128,1); + CASE_K(64,256,256,1); + default: + topk_stage_1<<>>(input_array, + topk_tmp_id_buf, + topk_tmp_val_buf, + k, + values_per_item, + descendingOrder); + + topk_stage_2<<), stream>>>(topk_tmp_id_buf, + topk_tmp_val_buf, + tops, + outIndices, + outVals, + items_per_row, + k, + descendingOrder); + break; + } +} + +/** + This overload launches the cuda kernels needed to perform the topk operation. It returns the topk items in the outIndices and outVals arrays. This + needs two cudaMemcpy calls to read the topk values from the device to the host but it does not expose the topk struct. It is recommened to use the + other overload if possible and reading back to host memory to reduce the communication needed between the host and device. + + Template params: same as previous overload. + + Function params: same as previous overload except: + + outIndices - The array to write the topk indices. This should be (rows, k) + + outVals - The array to write the topk values. This should be (rows, k) + +*/ +template +void topK_kernelLauncher(T* input_array, + IndexType* topk_tmp_id_buf, + T* topk_tmp_val_buf, + IndexType* outIndices, + T* outVals, + const int rows, + const int items_per_row, + const int k, + const int values_per_item, + bool descendingOrder, + cudaStream_t stream) { + + TopK* tops = nullptr; + switch(k) { + CASE_K(1,128,128,MAX_BLOCKS_PER_ITEM); + CASE_K(2,128,128,MAX_BLOCKS_PER_ITEM); + CASE_K(4,128,128,MAX_BLOCKS_PER_ITEM); + CASE_K(6,128,128,MAX_BLOCKS_PER_ITEM); + CASE_K(8,128,128,MAX_BLOCKS_PER_ITEM); + CASE_K(10,128,128,MAX_BLOCKS_PER_ITEM); + CASE_K(16,128,128,5); + CASE_K(32,256,128,1); + CASE_K(64,256,256,1); + default: + topk_stage_1<<>>(input_array, + topk_tmp_id_buf, + topk_tmp_val_buf, + k, + values_per_item, + descendingOrder); + + topk_stage_2<<), stream>>>(topk_tmp_id_buf, + topk_tmp_val_buf, + tops, + outIndices, + outVals, + items_per_row, + k, + descendingOrder); + break; + } +} \ No newline at end of file diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index b47663b4e..d0c8043cf 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,3 +1,5 @@ +add_definitions(-DCUB_IGNORE_DEPRECATED_CPP_DIALECT=1) +add_definitions(-DTHRUST_IGNORE_DEPRECATED_CPP_DIALECT=1) add_subdirectory(3rd_party) include_directories(.) diff --git a/src/tensors/gpu/topk.cu b/src/tensors/gpu/topk.cu index 94256fb7a..98f641ef2 100644 --- a/src/tensors/gpu/topk.cu +++ b/src/tensors/gpu/topk.cu @@ -1,8 +1,10 @@ +#include "common/definitions.h" #include "tensors/tensor_operators.h" #include "tensors/gpu/cuda_helpers.h" #include "tensors/allocator.h" #include +#include "3rd_party/topk.cuh" // GPU implementation of proper Marian top-k operator for TopkNodeOp // This file contains a lot of code-duplicaton with src/translator/nth_element.cu @@ -13,309 +15,6 @@ namespace marian { namespace gpu { -const int MAX_BINS = 500; -const int BLOCK_SIZE = 512; - -#define UNROLL_MAXARG_LOOP(n, max) \ - if(tid < (n) && tid + (n) < (max)) { \ - if(sharedValues[tid + (n)] > sharedValues[tid]) { \ - sharedIndices[tid] = sharedIndices[tid + (n)]; \ - sharedValues[tid] = sharedValues[tid + (n)]; \ - } \ - } - -// finds maximum element (first step) -template -__global__ void gMaxElement(IndexType* binIndices, // out: top-k positions - T* binValues, // out: top-k scores - const T* inValues, // this is the probs array, only one with type float or half - int rows, // we iterate over this many rows, row-major layout - int cols, // a row has that many columns, row-major layout - float minimal, // minimal is the smallest possible value. For simplicity we assume we look for the maxmimum. - bool descending) // This will be the largest possible value if the order is reversed (i.e. we look for the minimum). -{ - extern __shared__ float sharedValues[]; - __shared__ IndexType sharedIndices[BLOCK_SIZE]; - - // id of current thread within block - int tid = threadIdx.x; - - float flip = descending ? 1.f : -1.f; - - // Roll over every row in row-major 2D representation of the data - for(int rowIdx = 0; rowIdx < rows; ++rowIdx) { - int begin = rowIdx * cols; // start index of a row - int end = rowIdx * cols + cols; // end index of a row - - // We look at at most blockDim.x * 2 = 1024 values within a block, i.e. each thread reduces two values. - // Here we set the position to begin + blockId * 1024 + threadId. If a row has more values we - // partition the row according to blocks of 1024 values. - int i = begin + blockIdx.x * (blockDim.x * 2) + tid; - - // Initialize shared values to minimal value. - sharedValues[tid] = minimal; - - // Do first set of comparisons outside loop, saves one iteration. - if(i + blockDim.x < end) { // Are we in a position for which we can access and compare two values in a row partition (shifted by block size)? - // yes, hence compare: - float a = flip * (float)inValues[i]; // value from first half of row parition for this block - float b = flip * (float)inValues[i + blockDim.x]; // value from second half of row partition for this block - if(a > b) { // just a max - sharedIndices[tid] = i; - sharedValues[tid] = a; - } else { - sharedIndices[tid] = i + blockDim.x; - sharedValues[tid] = b; - } - } else if(i < end) { // Are we instead in a position that has access to one value in the row partition (shifting by block size would be out of bounds)? - // Yes, hence save the current value and index as new max, no need to compare. - sharedIndices[tid] = i; - sharedValues[tid] = flip * (float)inValues[i]; - } // nothing else to do here - - // We move to the next set of 1024 values shifted by block size times number of blocks - // and look at two of them according to thread id. - while(i + 2 * gridDim.x * blockDim.x < end) { - i += 2 * gridDim.x * blockDim.x; - - // Check if first value is larger than what we have seen so far - float a = flip * (float)inValues[i]; - if(a > sharedValues[tid]) { - // Yes, hence save index and value - sharedIndices[tid] = i; - sharedValues[tid] = a; - } - - // Check if second value is larger than what we have seen so far - if(i + blockDim.x < end) { - float b = flip * (float)inValues[i + blockDim.x]; - if(b > sharedValues[tid]) { - // Yes, hence save index and value - sharedIndices[tid] = i + blockDim.x; - sharedValues[tid] = b; - } - } - } - - // We are done with the first sweep and have populated shared memory, time to wait for the other threads and reduce it all - __syncthreads(); - - // Reduce over shared memory, here per loop until we hit the last 32 unreduced elements - for(int s = (blockDim.x >> 1); s > 32; s >>= 1) { - if(tid < s && tid + s < end) { - if(sharedValues[tid + s] > sharedValues[tid]) { - // keep the max - sharedIndices[tid] = sharedIndices[tid + s]; - sharedValues[tid] = sharedValues[tid + s]; - } - } - __syncthreads(); - } - - // Reduce over shared memory, here per unrolled code for powers of 2 lower equal 32. - // Because we are at 32 (warp size) the threads run in lock-step and we can abandon syncing. - UNROLL_MAXARG_LOOP(32, end); - UNROLL_MAXARG_LOOP(16, end); - UNROLL_MAXARG_LOOP(8, end); - UNROLL_MAXARG_LOOP(4, end); - UNROLL_MAXARG_LOOP(2, end); - UNROLL_MAXARG_LOOP(1, end); - - // OK, we are done with the reduction and in the first thread - if(tid == 0) { - // assign the final maximal value to the bin, one bin per row and block - binIndices[rowIdx * gridDim.x + blockIdx.x] = sharedIndices[0]; // [rows, num_blocks] - binValues[rowIdx * gridDim.x + blockIdx.x] = sharedValues[0]; // [rows, num_blocks] - } - __syncthreads(); - } -} - -// This runs after the function above, we now have the maximum value per row and block and can look further -// for the top-k results. As above we pretend this does only maximum search. -// This runs restricted to one row (one row per block) -template -__global__ void gMaxElementUpdate(IndexType* binIndices, // memory for bin indices - T* binValues, // memory for bin costs - IndexType* outIndices, // result indices - T* outValues, // result costs - T* inValues, // should work well enough with half, uses float everywhere else - const int cols, // size of continous memory we search over - const int K, // how many top-K elements? - int numBlocks, // number of blocks/bins used in above function (per row) - float minimal, // value for minimal element - bool descending) -{ - extern __shared__ float sharedValues[]; - __shared__ int sharedIndices[BLOCK_SIZE]; - __shared__ float bestBinCost; - __shared__ int bestBinCostIdx; - - const int tid = threadIdx.x; - - float flip = descending ? 1.f : -1.f; - - // we only look at one row in this kernel - const int rowIdx = blockIdx.x; // index of the row corresponds to block index - const int begin = rowIdx * cols; // start offset for this row relative to inValues tensor start - const int end = rowIdx * cols + cols; // end offset for this row relative to inValues tensor start - - int num_bins = numBlocks; // why not just use numBlocks? - - // iterate over top-k results - for(int k = 0; k < K; ++k) { - - int kthOutIdx = rowIdx * K + k; // offset into output tensor relative to outIndices/outValues tensor start - int i = tid; - - sharedValues[tid] = minimal; // initialize to smallest value, everything else will be larger - - // as in the function above, the code here does a tree reduction over shared memory to find the single maximum element - if(i + blockDim.x < num_bins) { - float a = binValues[rowIdx * numBlocks + i]; - float b = binValues[rowIdx * numBlocks + i + blockDim.x]; - if(a > b) { - sharedValues[tid] = a; - sharedIndices[tid] = i; - } else { - sharedValues[tid] = b; - sharedIndices[tid] = i + blockDim.x; - } - } else if(i < num_bins) { - sharedValues[tid] = binValues[rowIdx * numBlocks + i]; - sharedIndices[tid] = i; - } - - while(i + 2 * blockDim.x < num_bins) { - i += 2 * blockDim.x; - - float a = binValues[rowIdx * numBlocks + i]; - if(a > sharedValues[tid]) { - sharedValues[tid] = a; - sharedIndices[tid] = i; - } - - if(i + blockDim.x < num_bins) { - float b = binValues[rowIdx * numBlocks + i + blockDim.x]; - if(b > sharedValues[tid]) { - sharedValues[tid] = b; - sharedIndices[tid] = i + blockDim.x; - } - } - } - - __syncthreads(); - - for(int s = (blockDim.x >> 1); s > 32; s >>= 1) { - if(tid < s && tid + s < num_bins) { - if(sharedValues[tid + s] > sharedValues[tid]) { - sharedValues[tid] = sharedValues[tid + s]; - sharedIndices[tid] = sharedIndices[tid + s]; - } - } - __syncthreads(); - } - - UNROLL_MAXARG_LOOP(32, num_bins); - UNROLL_MAXARG_LOOP(16, num_bins); - UNROLL_MAXARG_LOOP(8, num_bins); - UNROLL_MAXARG_LOOP(4, num_bins); - UNROLL_MAXARG_LOOP(2, num_bins); - UNROLL_MAXARG_LOOP(1, num_bins); - - if(tid == 0) { - bestBinCost = sharedValues[0]; - bestBinCostIdx = rowIdx * numBlocks + sharedIndices[0]; - - inValues[binIndices[bestBinCostIdx]] = flip * minimal; // this is restored in the last lines of this function - - outIndices[kthOutIdx] = binIndices[bestBinCostIdx] - begin; // relative to beginning of row hence substract `begin` - outValues[kthOutIdx] = flip * bestBinCost; // undo flip by flipping again - } - - __syncthreads(); - - // Second part of the algorithm, why it that not replacing the first function call?? - // Also shouldn't we skip here if k == K - 1? - - // After marking the previously largest element with "flip * minimal" we populate again - // shared memory with the largest element as in gMaxElement(...) - - if(k < K - 1) { - i = begin + (bestBinCostIdx - rowIdx * numBlocks) * (blockDim.x * 2) + tid; - const int dist = num_bins * 2 * blockDim.x; - - sharedValues[tid] = minimal; - - if(i + blockDim.x < end) { - float a = flip * (float)inValues[i]; - float b = flip * (float)inValues[i + blockDim.x]; - if(a > b) { - sharedIndices[tid] = i; - sharedValues[tid] = a; - } else { - sharedIndices[tid] = i + blockDim.x; - sharedValues[tid] = b; - } - } else if(i < end) { - sharedIndices[tid] = i; - sharedValues[tid] = flip * (float)inValues[i]; - } - - while(i + dist < end) { - i += dist; - - float a = flip * (float)inValues[i]; - if(a > sharedValues[tid]) { - sharedIndices[tid] = i; - sharedValues[tid] = a; - } - - if(i + blockDim.x < end) { - float b = flip * (float)inValues[i + blockDim.x]; - if(b > sharedValues[tid]) { - sharedIndices[tid] = i + blockDim.x; - sharedValues[tid] = b; - } - } - } - - __syncthreads(); - - for(int s = (blockDim.x >> 1); s > 32; s >>= 1) { - if(tid < s && tid + s < end) { - if(sharedValues[tid + s] > sharedValues[tid]) { - sharedIndices[tid] = sharedIndices[tid + s]; - sharedValues[tid] = sharedValues[tid + s]; - } - } - __syncthreads(); - } - - UNROLL_MAXARG_LOOP(32, end); - UNROLL_MAXARG_LOOP(16, end); - UNROLL_MAXARG_LOOP(8, end); - UNROLL_MAXARG_LOOP(4, end); - UNROLL_MAXARG_LOOP(2, end); - UNROLL_MAXARG_LOOP(1, end); - - if(tid == 0) { - binIndices[bestBinCostIdx] = sharedIndices[0]; - binValues[bestBinCostIdx] = sharedValues[0]; - } - __syncthreads(); - } - } - - // final operation to restore blanked-out input values. They were blanked out for marking - // already found values. Since we want input values to be invariant we restore here. - // @TODO: The lack of constness here might be a problem for concurrent processing (which we currently don't have) - for(int k = tid; k < K; k += blockDim.x) { - int kthOutIdx = rowIdx * K + k; - inValues[begin + outIndices[kthOutIdx]] = outValues[kthOutIdx]; - } -} - void TopK(Tensor outVal, Tensor outInd, Ptr allocator, const Tensor in, int k, int axis, bool descending) { ABORT_IF(axis != in->shape().size() - 1, "Currently only works for last axis"); @@ -330,45 +29,39 @@ void TopK(Tensor outVal, Tensor outInd, Ptr allocator, const Tensor i ABORT_IF(k > cols, "Cannot select more than {} elements for axis {}", cols, axis); - float minimal = NumericLimits(in->type()).lowest; // lowest if looking for max + const int beams = 1; + const int tempElts = rows * beams * beams * MAX_BLOCKS_PER_ITEM; - const int numBlocks = std::min(MAX_BINS, int(cols / (2 * BLOCK_SIZE)) + int(cols % (2 * BLOCK_SIZE) != 0)); - auto tempMemInd = allocator->alloc(rows * numBlocks); + auto tempMemInd = allocator->alloc(tempElts); MemoryPiece::PtrType tempMemVal; if(in->type() == Type::float32) { - tempMemVal = allocator->alloc(rows * numBlocks); - // first find the maximum value per row and block and save indices and values to temporary memory - gMaxElement<<>>( - tempMemInd->data(), tempMemVal->data(), - in->data(), rows, cols, minimal, descending); - gMaxElementUpdate<<>>( - tempMemInd->data(), tempMemVal->data(), - outInd->data(), outVal->data(), - in->data(), cols, k, numBlocks, minimal, descending); + tempMemVal = allocator->alloc(tempElts); + topK_kernelLauncher(in->data(), + tempMemInd->data(), + tempMemVal->data(), + outInd->data(), + outVal->data(), + rows, + 1, // This is the beam size. This is set to 1 since we "trick" the existing implementation to treat a row as a beam + k, + cols, + descending, + cudaStreamPerThread); #if COMPILE_FP16 } else if(in->type() == Type::float16) { - tempMemVal = allocator->alloc<__half>(rows * numBlocks); - // first find the maximum value per row and block and save indices and values to temporary memory - gMaxElement<<>>( - tempMemInd->data(), tempMemVal->data<__half>(), - in->data<__half>(), rows, cols, minimal, descending); - gMaxElementUpdate<<>>( - tempMemInd->data(), tempMemVal->data<__half>(), - outInd->data(), outVal->data<__half>(), - in->data<__half>(), cols, k, numBlocks, minimal, descending); + tempMemVal = allocator->alloc<__half>(tempElts); + topK_kernelLauncher(in->data<__half>(), + tempMemInd->data(), + tempMemVal->data<__half>(), + outInd->data(), + outVal->data<__half>(), + rows, + 1, // This is the beam size. This is set to 1 since we "trick" the existing implementation to treat a row as a beam + k, + cols, + descending, + cudaStreamPerThread); #endif } else { ABORT("Topk not implemented for type {}", in->type()); diff --git a/src/translator/nth_element.cu b/src/translator/nth_element.cu index e8786ee79..27053c95e 100644 --- a/src/translator/nth_element.cu +++ b/src/translator/nth_element.cu @@ -5,276 +5,14 @@ #include +#include "common/definitions.h" #include "translator/nth_element.h" +#include "3rd_party/topk.cuh" #include #include "tensors/gpu/cuda_helpers.h" namespace marian { - -#define UNROLL_MAXARG_LOOP(n, max) \ - if(tid < (n) && tid + (n) < (max)) { \ - if(sdata[tid + (n)] > sdata[tid]) { \ - sdata[tid] = sdata[tid + (n)]; \ - indices[tid] = indices[tid + (n)]; \ - } \ - } - -template -__global__ void gMaxElement(float* d_out, - int* d_ind, - T* d_in, // this is the probs array, only one with type float or half - int numBatches, - int* batchFirstElementIdxs, - float disabledPathScore) // disabledPathScore is used to blank out found values, type-dependent -{ - extern __shared__ float sdata[]; - __shared__ int indices[512]; - - int tid = threadIdx.x; - - for(int batchIdx = 0; batchIdx < numBatches; ++batchIdx) { - int begin = batchFirstElementIdxs[batchIdx]; - int end = batchFirstElementIdxs[batchIdx + 1]; - - int i = begin + blockIdx.x * (blockDim.x * 2) + tid; - - sdata[tid] = disabledPathScore; - - if(i < end) { - sdata[tid] = (float)d_in[i]; - indices[tid] = i; - } - - if(i + blockDim.x < end) { - float a = (float)d_in[i]; - float b = (float)d_in[i + blockDim.x]; - if(a > b) { - sdata[tid] = a; - indices[tid] = i; - } else { - sdata[tid] = b; - indices[tid] = i + blockDim.x; - } - } - - while(i + 2 * gridDim.x * blockDim.x < end) { - i += 2 * gridDim.x * blockDim.x; - - float a = (float)d_in[i]; - if(a > sdata[tid]) { - sdata[tid] = a; - indices[tid] = i; - } - - if(i + blockDim.x < end) { - float b = (float)d_in[i + blockDim.x]; - if(b > sdata[tid]) { - sdata[tid] = b; - indices[tid] = i + blockDim.x; - } - } - } - - __syncthreads(); - - for(int s = (blockDim.x >> 1); s > 32; s >>= 1) { - if(tid < s && tid + s < end) { - if(sdata[tid + s] > sdata[tid]) { - sdata[tid] = sdata[tid + s]; - indices[tid] = indices[tid + s]; - } - } - __syncthreads(); - } - - UNROLL_MAXARG_LOOP(32, end); - UNROLL_MAXARG_LOOP(16, end); - UNROLL_MAXARG_LOOP(8, end); - UNROLL_MAXARG_LOOP(4, end); - UNROLL_MAXARG_LOOP(2, end); - UNROLL_MAXARG_LOOP(1, end); - - if(tid == 0) { - d_out[blockIdx.x + batchIdx * gridDim.x] = sdata[0]; - d_ind[blockIdx.x + batchIdx * gridDim.x] = indices[0]; - } - __syncthreads(); - } -} - -template -__global__ void gMaxElementUpdate(float* binCosts, - int* binIdxs, - T* probs, // should work well enough with half, uses float everywhere else - int* batchFirstElements, - float* outCosts, - int* outIdxs, - int* cumulativeBeamSizes, - int NUM_BLOCKS, - float disabledPathScore) { - extern __shared__ float sdata[]; - __shared__ int indices[512]; - __shared__ float bestBinCost; - __shared__ int bestBinCostIdx; - - const int tid = threadIdx.x; - const int batchIdx = blockIdx.x; - const int N = batchFirstElements[batchIdx + 1] - batchFirstElements[batchIdx]; - int num_bins = int(N / (2 * 512)) + int(N % (2 * 512) != 0); - if(num_bins > 500) { - num_bins = 500; - } - - for(int pos = cumulativeBeamSizes[batchIdx]; - pos < cumulativeBeamSizes[batchIdx + 1]; - ++pos) { - int i = tid; - - sdata[tid] = disabledPathScore; - - if(i < num_bins) { - sdata[tid] = binCosts[batchIdx * NUM_BLOCKS + i]; - indices[tid] = i; - } - - if(i + blockDim.x < num_bins) { - float a = binCosts[batchIdx * NUM_BLOCKS + i]; - float b = binCosts[batchIdx * NUM_BLOCKS + i + blockDim.x]; - if(a > b) { - sdata[tid] = a; - indices[tid] = i; - } else { - sdata[tid] = b; - indices[tid] = i + blockDim.x; - } - } - - while(i + 2 * blockDim.x < num_bins) { - i += 2 * blockDim.x; - - float a = binCosts[batchIdx * NUM_BLOCKS + i]; - if(a > sdata[tid]) { - sdata[tid] = a; - indices[tid] = i; - } - - if(i + blockDim.x < num_bins) { - float b = binCosts[batchIdx * NUM_BLOCKS + i + blockDim.x]; - if(b > sdata[tid]) { - sdata[tid] = b; - indices[tid] = i + blockDim.x; - } - } - } - - __syncthreads(); - - for(int s = (blockDim.x >> 1); s > 32; s >>= 1) { - if(tid < s && tid + s < num_bins) { - if(sdata[tid + s] > sdata[tid]) { - sdata[tid] = sdata[tid + s]; - indices[tid] = indices[tid + s]; - } - } - __syncthreads(); - } - - UNROLL_MAXARG_LOOP(32, num_bins); - UNROLL_MAXARG_LOOP(16, num_bins); - UNROLL_MAXARG_LOOP(8, num_bins); - UNROLL_MAXARG_LOOP(4, num_bins); - UNROLL_MAXARG_LOOP(2, num_bins); - UNROLL_MAXARG_LOOP(1, num_bins); - - if(tid == 0) { - bestBinCost = sdata[0]; - bestBinCostIdx = batchIdx * NUM_BLOCKS + indices[0]; - - probs[binIdxs[bestBinCostIdx]] = disabledPathScore; - - outIdxs[pos] = binIdxs[bestBinCostIdx]; - outCosts[pos] = bestBinCost; - } - - __syncthreads(); - - i = batchFirstElements[batchIdx] - + (bestBinCostIdx - batchIdx * NUM_BLOCKS) * (blockDim.x * 2) + tid; - const int dist = num_bins * 2 * blockDim.x; - - sdata[tid] = disabledPathScore; - - if(i < batchFirstElements[batchIdx + 1]) { - sdata[tid] = (float)probs[i]; - indices[tid] = i; - } - - if(i + blockDim.x < batchFirstElements[batchIdx + 1]) { - float a = (float)probs[i]; - float b = (float)probs[i + blockDim.x]; - if(a > b) { - sdata[tid] = a; - indices[tid] = i; - } else { - sdata[tid] = b; - indices[tid] = i + blockDim.x; - } - } - - while(i + dist < batchFirstElements[batchIdx + 1]) { - i += dist; - - float a = (float)probs[i]; - if(a > sdata[tid]) { - sdata[tid] = a; - indices[tid] = i; - } - - if(i + blockDim.x < batchFirstElements[batchIdx + 1]) { - float b = (float)probs[i + blockDim.x]; - if(b > sdata[tid]) { - sdata[tid] = b; - indices[tid] = i + blockDim.x; - } - } - } - - __syncthreads(); - - for(int s = (blockDim.x >> 1); s > 32; s >>= 1) { - if(tid < s && tid + s < batchFirstElements[batchIdx + 1]) { - if(sdata[tid + s] > sdata[tid]) { - sdata[tid] = sdata[tid + s]; - indices[tid] = indices[tid + s]; - } - } - __syncthreads(); - } - - UNROLL_MAXARG_LOOP(32, batchFirstElements[batchIdx + 1]); - UNROLL_MAXARG_LOOP(16, batchFirstElements[batchIdx + 1]); - UNROLL_MAXARG_LOOP(8, batchFirstElements[batchIdx + 1]); - UNROLL_MAXARG_LOOP(4, batchFirstElements[batchIdx + 1]); - UNROLL_MAXARG_LOOP(2, batchFirstElements[batchIdx + 1]); - UNROLL_MAXARG_LOOP(1, batchFirstElements[batchIdx + 1]); - - if(tid == 0) { - binCosts[bestBinCostIdx] = sdata[0]; - binIdxs[bestBinCostIdx] = indices[0]; - } - __syncthreads(); - } -} - -__global__ void gGetValueByKey(float* d_in, float* d_out, int* indeces, int n) { - int tid = threadIdx.x + blockDim.x * blockIdx.x; - if(tid < n) { - int index = indeces[tid]; - d_out[tid] = d_in[index]; - } -} - class NthElementGPU { public: NthElementGPU() = delete; @@ -284,82 +22,38 @@ public: size_t maxBatchSize, DeviceId deviceId) : deviceId_(deviceId), - maxBeamSize_(maxBeamSize), maxBatchSize_(maxBatchSize), - NUM_BLOCKS(std::min( - 500, - int(maxBeamSize* MAX_VOCAB_SIZE / (2 * BLOCK_SIZE)) - + int(maxBeamSize* MAX_VOCAB_SIZE % (2 * BLOCK_SIZE) != 0))) { + maxBeamSize_(maxBeamSize), maxBatchSize_(maxBatchSize) { // std::cerr << "NthElement::NthElement" << std::endl; cudaSetDevice(deviceId_.no); - CUDA_CHECK(cudaMalloc((void**)&d_ind, maxBatchSize * NUM_BLOCKS * sizeof(int))); - CUDA_CHECK(cudaMalloc((void**)&d_out, maxBatchSize * NUM_BLOCKS * sizeof(float))); - - CUDA_CHECK(cudaMalloc((void**)&d_res_idx, maxBatchSize * maxBeamSize * sizeof(int))); - CUDA_CHECK(cudaMalloc((void**)&d_res, maxBatchSize * maxBeamSize * sizeof(float))); - - CUDA_CHECK(cudaHostAlloc((void**)&h_res, maxBeamSize * maxBatchSize * sizeof(float), cudaHostAllocDefault)); - CUDA_CHECK(cudaHostAlloc((void**)&h_res_idx, maxBeamSize * maxBatchSize * sizeof(int), cudaHostAllocDefault)); - - CUDA_CHECK(cudaMalloc((void**)&d_breakdown, maxBeamSize * sizeof(float))); - CUDA_CHECK(cudaMalloc((void**)&d_batchPosition, (maxBatchSize + 1) * sizeof(int))); - CUDA_CHECK(cudaMalloc((void**)&d_cumBeamSizes, (maxBatchSize + 1) * sizeof(int))); + const int tempElts = maxBatchSize * maxBeamSize * maxBeamSize * MAX_BLOCKS_PER_ITEM; + CUDA_CHECK(cudaMalloc((void**)&topk_tmp_id_buf, tempElts * sizeof(IndexType))); + CUDA_CHECK(cudaMalloc((void**)&topk_tmp_val_buf, tempElts * sizeof(float))); + CUDA_CHECK(cudaMalloc((void**)&tops, maxBatchSize * maxBeamSize * sizeof(TopK))); + CUDA_CHECK(cudaHostAlloc((void**)&topsHost, maxBeamSize * maxBatchSize * sizeof(TopK), cudaHostAllocDefault)); } ~NthElementGPU() { // No CUDA error checking as this is a destructor and we cannot do anything about errors anyway. cudaSetDevice(deviceId_.no); - cudaFree(d_cumBeamSizes); - cudaFree(d_batchPosition); - cudaFree(d_breakdown); - cudaFreeHost(h_res_idx); - cudaFreeHost(h_res); - cudaFree(d_res); - cudaFree(d_res_idx); - cudaFree(d_out); - cudaFree(d_ind); + cudaFree(topk_tmp_id_buf); + cudaFree(topk_tmp_val_buf); + cudaFree(tops); + cudaFreeHost(topsHost); } private: template - void selectNBest(T* probs, - const std::vector& batchFirstElementIdxs, - const std::vector& cumulativeBeamSizes, - float disabledPathScore) { - + void selectNBest(T* probs, + const int batchSize, + const int hypsPerBeam, // current hypotheses in each beam + const int beamWidth, // k + const int vocabSize) { cudaSetDevice(deviceId_.no); - CUDA_CHECK(cudaMemcpyAsync(d_batchPosition, - batchFirstElementIdxs.data(), - batchFirstElementIdxs.size() * sizeof(int), - cudaMemcpyHostToDevice, - /* stream_ */ 0)); - CUDA_CHECK(cudaMemcpyAsync(d_cumBeamSizes, - cumulativeBeamSizes.data(), - cumulativeBeamSizes.size() * sizeof(int), - cudaMemcpyHostToDevice, - /* stream_ */ 0)); - const int numBatches = batchFirstElementIdxs.size() - 1; - - gMaxElement<<>>( - d_out, d_ind, probs, numBatches, d_batchPosition, disabledPathScore); - - gMaxElementUpdate<<>>(d_out, - d_ind, - probs, - d_batchPosition, - d_res, - d_res_idx, - d_cumBeamSizes, - NUM_BLOCKS, - disabledPathScore); + topK_kernelLauncher(probs, topk_tmp_id_buf, (T*)topk_tmp_val_buf, (TopK*)tops, + batchSize, hypsPerBeam, beamWidth, vocabSize, true, 0); } public: @@ -373,68 +67,47 @@ public: const auto vocabSize = scores->shape()[-1]; const auto inputN = scores->shape()[-2]; const auto dimBatch = scores->shape()[-4]; + ABORT_IF(inputN != (isFirst ? 1 : N), "Input tensor has wrong beam dim??"); // @TODO: Remove isFirst argument altogether ABORT_IF(vocabSize > MAX_VOCAB_SIZE, "GetNBestList(): actual vocab size {} exceeds MAX_VOCAB_SIZE of {}", vocabSize, MAX_VOCAB_SIZE); ABORT_IF(dimBatch > maxBatchSize_, "GetNBestList(): actual batch size {} exceeds initialization parameter {}", dimBatch, maxBatchSize_); ABORT_IF(std::max(N, (size_t)inputN) > maxBeamSize_, "GetNBestList(): actual beam size {} exceeds initialization parameter {}", N, maxBeamSize_); - const std::vector beamSizes(dimBatch, N); - std::vector cumulativeBeamSizes(beamSizes.size() + 1, 0); - std::vector batchFirstElementIdxs(beamSizes.size() + 1, 0); - - for(size_t batchIdx = 0; batchIdx < beamSizes.size(); ++batchIdx) { -#if 1 - cumulativeBeamSizes[batchIdx + 1] = (batchIdx + 1) * (int)N; - batchFirstElementIdxs[batchIdx + 1] += (batchIdx + 1) * inputN * vocabSize; - ABORT_IF(cumulativeBeamSizes[batchIdx + 1] != cumulativeBeamSizes[batchIdx] + (int)N, "cumulativeBeamSizes wrong??"); - ABORT_IF((isFirst ? batchIdx + 1 : cumulativeBeamSizes[batchIdx + 1]) != (batchIdx + 1) * inputN, "inputN wrong??"); -#else - cumulativeBeamSizes[batchIdx + 1] = cumulativeBeamSizes[batchIdx] + beamSizes[batchIdx]; - ABORT_IF(cumulativeBeamSizes[batchIdx + 1] != (batchIdx + 1) * N, "cumulativeBeamSizes wrong??"); - batchFirstElementIdxs[batchIdx + 1] - += ((isFirst) ? (batchIdx + 1) : cumulativeBeamSizes[batchIdx + 1]) * vocabSize; - ABORT_IF((isFirst ? batchIdx + 1 : cumulativeBeamSizes[batchIdx + 1]) != (batchIdx + 1) * inputN, "inputN wrong??"); -#endif - } - if(scores->type() == Type::float32) { - float disabledPathScore = NumericLimits(scores->type()).lowest; - selectNBest(scores->data(), batchFirstElementIdxs, cumulativeBeamSizes, disabledPathScore); + selectNBest(scores->data(), dimBatch, inputN, N, vocabSize); + getPairs(dimBatch * N, outKeys, outCosts); #if COMPILE_FP16 } else if(scores->type() == Type::float16) { - float disabledPathScore = NumericLimits(scores->type()).lowest; - selectNBest(scores->data(), batchFirstElementIdxs, cumulativeBeamSizes, disabledPathScore); + selectNBest(scores->data(), dimBatch, inputN, N, vocabSize); + getPairs(dimBatch * N, outKeys, outCosts); #endif } else { ABORT("getNBestList not implemented for type {}", scores->type()); } - getPairs(dimBatch * N, outKeys, outCosts); - ABORT_IF(cumulativeBeamSizes.back() != dimBatch * N, "cumulativeBeamSizes.back() wrong??"); + + ABORT_IF(outKeys.size() != dimBatch * N, "Expected {} but got {} values during topk", outKeys.size(), dimBatch * N); } private: - void getPairs(size_t number, + template + void getPairs(size_t numElts, std::vector& outKeys, std::vector& outValues) { cudaSetDevice(deviceId_.no); - CUDA_CHECK(cudaMemcpyAsync(h_res, - d_res, - number * sizeof(float), - cudaMemcpyDeviceToHost, - /* stream_ */ 0)); - CUDA_CHECK(cudaMemcpyAsync(h_res_idx, - d_res_idx, - number * sizeof(int), + TopK* topsHostCasted = (TopK*)topsHost; + + CUDA_CHECK(cudaMemcpyAsync(topsHostCasted, + tops, + numElts * sizeof(TopK), cudaMemcpyDeviceToHost, /* stream_ */ 0)); - cudaStreamSynchronize(/* stream_ */ 0); - for(size_t i = 0; i < number; ++i) { - outKeys.push_back(h_res_idx[i]); - outValues.push_back(h_res[i]); - } + CUDA_CHECK(cudaStreamSynchronize(/* stream_ */ 0)); - //lastN = number; + for(size_t i = 0; i < numElts; ++i) { + outKeys.push_back(topsHostCasted[i].index); + outValues.push_back((float)topsHostCasted[i].value); + } } DeviceId deviceId_; @@ -443,22 +116,10 @@ private: size_t maxBeamSize_; size_t maxBatchSize_; - const int BLOCK_SIZE = 512; - const int NUM_BLOCKS; - - int* d_ind; // [maxBatchSize * NUM_BLOCKS] - float* d_out; // [maxBatchSize * NUM_BLOCKS] - - int* d_res_idx; // [maxBatchSize * maxBeamSize] - float* d_res; // [maxBatchSize * maxBeamSize] - - int* h_res_idx; // [maxBeamSize * maxBatchSize] - float* h_res; // [maxBeamSize * maxBatchSize] - - float* d_breakdown; // [maxBeamSize] - int* d_batchPosition; // [maxBatchSize + 1] - int* d_cumBeamSizes; // [maxBatchSize + 1] - //size_t lastN; + IndexType* topk_tmp_id_buf; // [maxBatchSize * maxBeamSize, maxBeamSize * MAX_BLOCKS_PER_BEAM] + float* topk_tmp_val_buf; // [maxBatchSize * maxBeamSize, maxBeamSize * MAX_BLOCKS_PER_BEAM] + TopK* tops; // [maxBatchSize, maxBeamSize] + TopK* topsHost; // [maxBatchSize, maxBeamSize] }; // factory function